"""
|
|
Benchmark the throughput in the offline mode.
|
|
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
|
|
|
|
# Usage
|
|
## Sharegpt dataset with default args
|
|
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
|
|
|
|
## Random dataset with default args
|
|
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
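
## Generated shared prefix dataset (illustrative: the flag values shown are just the defaults defined in BenchArgs below)
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix --gsp-num-groups 64 --gsp-prompts-per-group 16 --gsp-system-prompt-len 2048 --gsp-question-len 128 --gsp-output-len 256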
"""

import argparse
import dataclasses
import json
import logging
import os
import random
import time
from typing import Dict, List, Optional, Tuple

import numpy as np

from sglang.bench_serving import (
    get_dataset,
    get_tokenizer,
    sample_random_requests,
    set_ulimit,
)
from sglang.lang.backend.runtime_endpoint import Runtime
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.server_args import ServerArgs


@dataclasses.dataclass
class BenchArgs:
    backend: str = "engine"
    result_filename: str = ""
    dataset_name: str = "sharegpt"
    dataset_path: str = ""
    num_prompts: int = 1000
    sharegpt_output_len: Optional[int] = None
    sharegpt_context_len: Optional[int] = None
    random_input_len: int = 1024
    random_output_len: int = 1024
    random_range_ratio: float = 0.0
    gsp_num_groups: int = 64
    gsp_prompts_per_group: int = 16
    gsp_system_prompt_len: int = 2048
    gsp_question_len: int = 128
    gsp_output_len: int = 256
    seed: int = 1
    disable_ignore_eos: bool = False
    extra_request_body: Optional[str] = None
    apply_chat_template: bool = False
    profile: bool = False
    skip_warmup: bool = False
    do_not_exit: bool = False
    prompt_suffix: str = ""

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--backend", type=str, default=BenchArgs.backend)
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument(
            "--dataset-name",
            type=str,
            default="sharegpt",
            choices=["sharegpt", "random", "generated-shared-prefix"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument(
            "--dataset-path", type=str, default="", help="Path to the dataset."
        )
        parser.add_argument(
            "--num-prompts",
            type=int,
            default=BenchArgs.num_prompts,
            help="Number of prompts to process. Default is 1000.",
        )
        parser.add_argument(
            "--sharegpt-output-len",
            type=int,
            default=BenchArgs.sharegpt_output_len,
            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
        )
        parser.add_argument(
            "--sharegpt-context-len",
            type=int,
            default=BenchArgs.sharegpt_context_len,
            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
        )
        parser.add_argument(
            "--random-input-len",
            type=int,
            default=BenchArgs.random_input_len,
            help="Number of input tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-output-len",
            type=int,
            default=BenchArgs.random_output_len,
            help="Number of output tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-range-ratio",
            type=float,
            default=BenchArgs.random_range_ratio,
            help="Range of sampled ratio of input/output length, "
            "used only for random dataset.",
        )
        parser.add_argument(
            "--gsp-num-groups",
            type=int,
            default=BenchArgs.gsp_num_groups,
            help="Number of groups with a shared prefix, used only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-prompts-per-group",
            type=int,
            default=BenchArgs.gsp_prompts_per_group,
            help="Number of prompts per shared-prefix group, used only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-system-prompt-len",
            type=int,
            default=BenchArgs.gsp_system_prompt_len,
            help="System prompt length, used only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-question-len",
            type=int,
            default=BenchArgs.gsp_question_len,
            help="Question length, used only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-output-len",
            type=int,
            default=BenchArgs.gsp_output_len,
            help="Target length in tokens for outputs in the generated-shared-prefix dataset.",
        )
        parser.add_argument("--seed", type=int, default=1, help="The random seed.")
        parser.add_argument(
            "--disable-ignore-eos",
            action="store_true",
            help="Disable ignoring the EOS token (stop generation at EOS).",
        )
        parser.add_argument(
            "--extra-request-body",
            metavar='{"key1": "value1", "key2": "value2"}',
            type=str,
            default=BenchArgs.extra_request_body,
            help="Append the given JSON object to the request payload. You can use this to specify "
            "additional generate params like sampling params.",
        )
        parser.add_argument(
            "--apply-chat-template",
            action="store_true",
            help="Apply the chat template to the prompts.",
        )
        parser.add_argument(
            "--profile",
            action="store_true",
            help="Use Torch Profiler. The endpoint must be launched with "
            "SGLANG_TORCH_PROFILER_DIR to enable the profiler.",
        )
        parser.add_argument(
            "--skip-warmup",
            action="store_true",
            help="Skip the warmup batches.",
        )
        parser.add_argument(
            "--do-not-exit",
            action="store_true",
            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
        )
        parser.add_argument(
            "--prompt-suffix",
            type=str,
            default="",
            help="Suffix appended to the end of all user prompts, followed by the assistant prompt suffix.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})


def throughput_test_once(
    backend_name: str,
    backend,
    reqs: List[Tuple[str, int, int]],
    ignore_eos: bool,
    extra_request_body: Dict,
    profile: bool,
):
    measurement_results = {
        "backend": backend_name,
        "successful_requests": len(reqs),
        "total_latency": -1,
        "total_input_tokens": sum(r[1] for r in reqs),
        "total_output_tokens": -1,
        "request_throughput": -1,
        "input_throughput": -1,
        "output_throughput": -1,
        "total_throughput": -1,
    }

    prompt = [r[0] for r in reqs]
    sampling_params = [
        {
            "temperature": 0,
            "max_new_tokens": r[2],
            "ignore_eos": ignore_eos,
            **extra_request_body,
        }
        for r in reqs
    ]

    if profile:
        assert (
            "SGLANG_TORCH_PROFILER_DIR" in os.environ
        ), "Please set SGLANG_TORCH_PROFILER_DIR."
        os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
        backend.start_profile()

    # Run the whole batch once and measure the end-to-end latency.
    st = time.perf_counter()
    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
    latency = time.perf_counter() - st

    if profile:
        backend.stop_profile()
        monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))

    # The runtime backend returns its results as a JSON string.
    if backend_name == "runtime":
        gen_out = json.loads(gen_out)

    server_info = backend.get_server_info()

    measurement_results["total_latency"] = latency
    measurement_results["total_output_tokens"] = sum(
        o["meta_info"]["completion_tokens"] for o in gen_out
    )
    measurement_results["request_throughput"] = (
        measurement_results["successful_requests"] / latency
    )
    measurement_results["input_throughput"] = (
        measurement_results["total_input_tokens"] / latency
    )
    measurement_results["output_throughput"] = (
        measurement_results["total_output_tokens"] / latency
    )
    measurement_results["total_throughput"] = (
        measurement_results["total_input_tokens"]
        + measurement_results["total_output_tokens"]
    ) / latency
    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]

    return measurement_results


def monitor_trace_file(directory, interval=1):
    """Wait until a new trace file appears in `directory` and finishes being written."""
    print(f"Monitoring {directory} for new trace files...")

    known_files = set(os.listdir(directory))

    while True:
        flag = False
        time.sleep(interval)
        current_files = set(os.listdir(directory))

        new_files = current_files - known_files
        for new_file in new_files:
            new_file_path = os.path.join(directory, new_file)
            print(f"New file detected: {new_file}")

            # Poll the file size until it stops growing, which indicates that
            # the profiler has finished writing the trace.
            previous_size = 0
            while True:
                try:
                    current_size = os.path.getsize(new_file_path)
                except FileNotFoundError:
                    print(f"File {new_file} is no longer accessible.")
                    break

                if current_size > previous_size:
                    previous_size = current_size
                else:
                    flag = True
                    break

                time.sleep(interval)
        if flag:
            break


def throughput_test(
    server_args: ServerArgs,
    bench_args: BenchArgs,
):
    if bench_args.backend == "engine":
        backend = Engine(**dataclasses.asdict(server_args))
        if not backend:
            raise ValueError("Please provide valid engine arguments")
    elif bench_args.backend == "runtime":
        backend = Runtime(**dataclasses.asdict(server_args))
    else:
        raise ValueError('Please set backend to either "engine" or "runtime"')

    tokenizer_id = server_args.tokenizer_path or server_args.model_path
    tokenizer = get_tokenizer(tokenizer_id)

    # Set global environments
    set_ulimit()
    random.seed(bench_args.seed)
    np.random.seed(bench_args.seed)

    # Parse args
    extra_request_body = {}
    if bench_args.extra_request_body:
        extra_request_body = json.loads(bench_args.extra_request_body)

    # Read dataset
    input_requests = get_dataset(bench_args, tokenizer)

    warmup_requests = sample_random_requests(
        input_len=256,
        output_len=16,
        num_prompts=min(bench_args.num_prompts, 16),
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=bench_args.dataset_path,
    )

    # Warm up
    if not bench_args.skip_warmup:
        logging.info("\nWarmup...")
        throughput_test_once(
            backend_name=bench_args.backend,
            backend=backend,
            reqs=warmup_requests,
            ignore_eos=not bench_args.disable_ignore_eos,
            extra_request_body=extra_request_body,
            profile=False,
        )
        time.sleep(0.5)

    logging.info("\nBenchmark...")
    result = throughput_test_once(
        backend_name=bench_args.backend,
        backend=backend,
        reqs=input_requests,
        ignore_eos=not bench_args.disable_ignore_eos,
        extra_request_body=extra_request_body,
        profile=bench_args.profile,
    )
    backend.shutdown()

    if bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    print(
        "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
    )
    print("{:<40} {:<10}".format("Backend:", result["backend"]))
    print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
    print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
    print(
        "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Last generation throughput (tok/s):", result["last_gen_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", result["request_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", result["input_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", result["output_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", result["total_throughput"]
        )
    )
    print("=" * 50)

    return result


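# Illustrative programmatic usage (a sketch, not part of the CLI flow above):
# `throughput_test` can also be driven from Python. The model path below is just
# the example model from the module docstring; all other server/bench arguments
# fall back to the defaults defined in ServerArgs and BenchArgs.
#
#     from sglang.srt.server_args import ServerArgs
#     from sglang.bench_offline_throughput import BenchArgs, throughput_test
#
#     result = throughput_test(
#         server_args=ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct"),
#         bench_args=BenchArgs(num_prompts=10),
#     )
#     print(result["output_throughput"])

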
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    throughput_test(server_args, bench_args)

    # Keep the process alive; useful for nsys profile with --duration and --delay
    # (see --do-not-exit).
    while bench_args.do_not_exit:
        pass