"""
|
|
Usage:
|
|
python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
|
|
"""

import argparse
import json
import os
import time

from sglang.test.simple_eval_common import (
    ChatCompletionSampler,
    make_report,
    set_ulimit,
)


def run_eval(args):
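    # Raise the soft open-file limit; the evals below issue many concurrent
    # HTTP requests, and each connection consumes a file descriptor.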
    set_ulimit()

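    # The OpenAI client requires an API key even when the server (e.g. a local
    # SGLang instance) does not check it, so fall back to a placeholder.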
if "OPENAI_API_KEY" not in os.environ:
|
|
os.environ["OPENAI_API_KEY"] = "EMPTY"
|
|
|
|
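    # Resolve the OpenAI-compatible endpoint: an explicit --base-url takes
    # precedence; otherwise the URL is assembled from --host and --port.
    # Either way, "/v1" is appended for the OpenAI-compatible API.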
    base_url = (
        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
    )

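    # Select the benchmark. Each eval module is imported lazily so only the
    # chosen benchmark's dependencies are loaded; MMLU, MATH, and GPQA read
    # their datasets from OpenAI's public simple-evals storage.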
    if args.eval_name == "mmlu":
        from sglang.test.simple_eval_mmlu import MMLUEval

        filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
        eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
    elif args.eval_name == "math":
        from sglang.test.simple_eval_math import MathEval

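        # MATH answers are free-form, so grading uses a second model as an
        # equality checker (hard-coded to gpt-4-turbo here) rather than exact
        # string matching against the reference answer.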
        equality_checker = ChatCompletionSampler(model="gpt-4-turbo")

        filename = (
            "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
        )
        eval_obj = MathEval(
            filename, equality_checker, args.num_examples, args.num_threads
        )
    elif args.eval_name == "mgsm":
        from sglang.test.simple_eval_mgsm import MGSMEval

        eval_obj = MGSMEval(args.num_examples, args.num_threads)
    elif args.eval_name == "mgsm_en":
        from sglang.test.simple_eval_mgsm import MGSMEval

        eval_obj = MGSMEval(args.num_examples, args.num_threads, languages=["en"])
    elif args.eval_name == "gpqa":
        from sglang.test.simple_eval_gpqa import GPQAEval

        filename = (
            "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
        )
        eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
    elif args.eval_name == "humaneval":
        from sglang.test.simple_eval_humaneval import HumanEval

        eval_obj = HumanEval(args.num_examples, args.num_threads)
    else:
        raise ValueError(f"Invalid eval name: {args.eval_name}")

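    # One chat-completions sampler is shared by all eval worker threads;
    # max_tokens bounds each response and temperature defaults to 0.0
    # (greedy decoding).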
    sampler = ChatCompletionSampler(
        model=args.model,
        max_tokens=2048,
        base_url=base_url,
        temperature=getattr(args, "temperature", 0.0),
    )

    # Run eval
    tic = time.time()
    result = eval_obj(sampler)
    latency = time.time() - tic

    # Dump reports
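    # `|` is the dict-union operator (Python 3.9+): the aggregate score is
    # appended to the per-category metrics before they are written out.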
    metrics = result.metrics | {"score": result.score}
    file_stem = f"{args.eval_name}_{sampler.model.replace('/', '_')}"
    report_filename = f"/tmp/{file_stem}.html"
    print(f"Writing report to {report_filename}")
    with open(report_filename, "w") as fh:
        fh.write(make_report(result))
    print(metrics)
    result_filename = f"/tmp/{file_stem}.json"
    with open(result_filename, "w") as f:
        f.write(json.dumps(metrics, indent=2))
    print(f"Writing results to {result_filename}")

    # Print results
    print(f"Total latency: {latency:.3f} s")
    print(f"Score: {metrics['score']:.3f}")

    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base URL, used instead of --host and --port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host of the HTTP server."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port of the HTTP server. If not set, the default port of the target LLM inference engine is used.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, the model is discovered by querying the server's /v1/models endpoint.",
    )
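    # --num-examples defaults to None, which the eval classes appear to treat
    # as "run the full dataset"; --num-threads caps the number of concurrent
    # requests sent to the server.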
parser.add_argument("--eval-name", type=str, default="mmlu")
|
|
parser.add_argument("--num-examples", type=int)
|
|
parser.add_argument("--num-threads", type=int, default=512)
|
|
parser.add_argument("--temperature", type=float, default=0.0)
|
|
args = parser.parse_args()
|
|
|
|
run_eval(args)
|