# Source: sglang 0.4.5.post1 - python/sglang/bench_offline_throughput.py
# (425 lines, 14 KiB, Python)

"""
Benchmark the throughput in the offline mode.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
# Usage
## ShareGPT dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
## Random dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input-len 1024 --random-output-len 1024
"""
import argparse
import dataclasses
import json
import logging
import os
import random
import time
from typing import Dict, List, Optional, Tuple
import numpy as np
from sglang.bench_serving import (
get_dataset,
get_tokenizer,
sample_random_requests,
set_ulimit,
)
from sglang.lang.backend.runtime_endpoint import Runtime
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.server_args import ServerArgs
@dataclasses.dataclass
class BenchArgs:
    """Benchmark-side options for the offline throughput benchmark.

    Every field has a matching ``--kebab-case`` CLI flag registered by
    :meth:`add_cli_args`, and :meth:`from_cli_args` rebuilds an instance from
    the parsed namespace. The dataclass defaults are the single source of
    truth for the CLI defaults.
    """

    # Backend to drive: "engine" or "runtime".
    backend: str = "engine"
    # When non-empty, the result dict is appended to this file as a JSON line.
    result_filename: str = ""
    dataset_name: str = "sharegpt"
    dataset_path: str = ""
    num_prompts: int = 1000
    sharegpt_output_len: Optional[int] = None
    sharegpt_context_len: Optional[int] = None
    random_input_len: int = 1024
    random_output_len: int = 1024
    random_range_ratio: float = 0.0
    gsp_num_groups: int = 64
    gsp_prompts_per_group: int = 16
    gsp_system_prompt_len: int = 2048
    gsp_question_len: int = 128
    gsp_output_len: int = 256
    seed: int = 1
    disable_ignore_eos: bool = False
    # JSON object (as a string) merged into every request's sampling params.
    extra_request_body: Optional[str] = None
    apply_chat_template: bool = False
    profile: bool = False
    skip_warmup: bool = False
    do_not_exit: bool = False
    prompt_suffix: str = ""

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register one CLI flag per field on ``parser``.

        Defaults are read from the dataclass so the two cannot drift apart.
        """
        parser.add_argument(
            "--backend",
            type=str,
            default=BenchArgs.backend,
            help='Backend to benchmark: "engine" or "runtime".',
        )
        parser.add_argument(
            "--result-filename",
            type=str,
            default=BenchArgs.result_filename,
            help="If set, append the result as a JSON line to this file.",
        )
        parser.add_argument(
            "--dataset-name",
            type=str,
            default=BenchArgs.dataset_name,
            choices=["sharegpt", "random", "generated-shared-prefix"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgs.dataset_path,
            help="Path to the dataset.",
        )
        parser.add_argument(
            "--num-prompts",
            type=int,
            default=BenchArgs.num_prompts,
            help="Number of prompts to process. Default is 1000.",
        )
        parser.add_argument(
            "--sharegpt-output-len",
            type=int,
            default=BenchArgs.sharegpt_output_len,
            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
        )
        parser.add_argument(
            "--sharegpt-context-len",
            type=int,
            default=BenchArgs.sharegpt_context_len,
            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
        )
        parser.add_argument(
            "--random-input-len",
            type=int,
            default=BenchArgs.random_input_len,
            help="Number of input tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-output-len",
            type=int,
            default=BenchArgs.random_output_len,
            help="Number of output tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-range-ratio",
            type=float,
            default=BenchArgs.random_range_ratio,
            help="Range of sampled ratio of input/output length, "
            "used only for random dataset.",
        )
        parser.add_argument(
            "--gsp-num-groups",
            type=int,
            default=BenchArgs.gsp_num_groups,
            help="Number of groups with shared prefix, used "
            "only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-prompts-per-group",
            type=int,
            default=BenchArgs.gsp_prompts_per_group,
            help="Number of prompts per group of shared prefix, used "
            "only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-system-prompt-len",
            type=int,
            default=BenchArgs.gsp_system_prompt_len,
            help="System prompt length, used "
            "only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-question-len",
            type=int,
            default=BenchArgs.gsp_question_len,
            help="Question length, used "
            "only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-output-len",
            type=int,
            default=BenchArgs.gsp_output_len,
            help="Target length in tokens for outputs in generated-shared-prefix dataset",
        )
        parser.add_argument(
            "--seed", type=int, default=BenchArgs.seed, help="The random seed."
        )
        parser.add_argument(
            "--disable-ignore-eos",
            action="store_true",
            help="Disable ignore EOS token",
        )
        parser.add_argument(
            "--extra-request-body",
            metavar='{"key1": "value1", "key2": "value2"}',
            type=str,
            default=BenchArgs.extra_request_body,
            help="Append given JSON object to the request payload. You can use this to specify "
            "additional generate params like sampling params.",
        )
        parser.add_argument(
            "--apply-chat-template",
            action="store_true",
            help="Apply chat template",
        )
        parser.add_argument(
            "--profile",
            action="store_true",
            help="Use Torch Profiler. The endpoint must be launched with "
            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
        )
        parser.add_argument(
            "--skip-warmup",
            action="store_true",
            help="Skip the warmup batches.",
        )
        parser.add_argument(
            "--do-not-exit",
            action="store_true",
            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
        )
        parser.add_argument(
            "--prompt-suffix",
            type=str,
            default=BenchArgs.prompt_suffix,
            help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance by reading every dataclass field from ``args``."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
def throughput_test_once(
    backend_name: str,
    backend,
    reqs: List[Tuple[str, int, int]],
    ignore_eos: bool,
    extra_request_body: Dict,
    profile: bool,
):
    """Send one batch of requests through ``backend`` and measure throughput.

    Args:
        backend_name: Recorded in the result. When it is ``"runtime"``, the
            output of ``backend.generate`` is a JSON string and is decoded.
        backend: Object providing ``generate(prompt=, sampling_params=)`` and
            ``get_server_info()``, plus ``start_profile()``/``stop_profile()``
            when ``profile`` is true.
        reqs: ``(prompt, input_len, output_len)`` tuples. ``input_len`` is
            summed into ``total_input_tokens``; ``output_len`` becomes
            ``max_new_tokens``.
        ignore_eos: Forwarded as ``ignore_eos`` in every sampling-params dict.
        extra_request_body: Merged into every sampling-params dict after the
            defaults, so its keys override them.
        profile: When true, wrap generation with start/stop_profile and block
            until a new trace file in ``$SGLANG_TORCH_PROFILER_DIR`` stops
            growing.

    Returns:
        Dict with latency, token totals, derived throughputs and the
        ``last_gen_throughput`` value from ``backend.get_server_info()``.
    """
    measurement_results = {
        "backend": backend_name,
        "successful_requests": len(reqs),
        "total_latency": -1,
        "total_input_tokens": sum(r[1] for r in reqs),
        "total_output_tokens": -1,
        "request_throughput": -1,
        "input_throughput": -1,
        "output_throughput": -1,
        "total_throughput": -1,
    }

    prompt = [r[0] for r in reqs]
    # extra_request_body is spread last so it can override these defaults.
    sampling_params = [
        {
            "temperature": 0,
            "max_new_tokens": r[2],
            "ignore_eos": ignore_eos,
            **extra_request_body,
        }
        for r in reqs
    ]

    profiler_dir = None
    if profile:
        assert (
            "SGLANG_TORCH_PROFILER_DIR" in os.environ
        ), "Please set SGLANG_TORCH_PROFILER_DIR."
        profiler_dir = os.environ["SGLANG_TORCH_PROFILER_DIR"]
        os.makedirs(profiler_dir, exist_ok=True)
        backend.start_profile()

    st = time.perf_counter()
    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
    latency = time.perf_counter() - st

    if profile:
        # Snapshot the directory BEFORE stopping the profiler. stop_profile()
        # may write the trace before it returns; a snapshot taken afterwards
        # would already contain that file and the monitor would wait forever.
        known_files = set(os.listdir(profiler_dir))
        backend.stop_profile()
        monitor_trace_file(profiler_dir, known_files=known_files)

    if backend_name == "runtime":
        gen_out = json.loads(gen_out)

    server_info = backend.get_server_info()

    total_input_tokens = measurement_results["total_input_tokens"]
    total_output_tokens = sum(o["meta_info"]["completion_tokens"] for o in gen_out)

    measurement_results["total_latency"] = latency
    measurement_results["total_output_tokens"] = total_output_tokens
    measurement_results["request_throughput"] = (
        measurement_results["successful_requests"] / latency
    )
    measurement_results["input_throughput"] = total_input_tokens / latency
    measurement_results["output_throughput"] = total_output_tokens / latency
    measurement_results["total_throughput"] = (
        total_input_tokens + total_output_tokens
    ) / latency
    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]

    return measurement_results


def monitor_trace_file(directory, interval=1, known_files: Optional[set] = None):
    """Block until the new trace files in ``directory`` stop growing.

    Every ``interval`` seconds, lists ``directory`` and looks for files not in
    ``known_files``. If ``known_files`` is None, the directory contents at call
    time are used. Each new file found in a poll is re-measured every
    ``interval`` seconds until its size stops increasing. The function returns
    once at least one file in that poll has settled. There is no timeout.

    Args:
        directory: Directory to watch.
        interval: Poll period in seconds.
        known_files: File names to treat as pre-existing. Pass a snapshot taken
            before the trace could have been written to avoid missing it.
    """
    print(f"Monitoring {directory} for new trace files...")

    if known_files is None:
        known_files = set(os.listdir(directory))
    else:
        known_files = set(known_files)

    while True:
        settled = False
        time.sleep(interval)
        new_files = set(os.listdir(directory)) - known_files

        for new_file in new_files:
            new_file_path = os.path.join(directory, new_file)
            print(f"New file detected: {new_file}")

            # Start below any real size so the file is always observed twice.
            # With 0, a file that was still empty on the first poll counted as
            # finished immediately.
            previous_size = -1
            while True:
                try:
                    current_size = os.path.getsize(new_file_path)
                except FileNotFoundError:
                    print(f"File {new_file} is no longer accessible.")
                    break

                if current_size > previous_size:
                    previous_size = current_size
                else:
                    settled = True
                    break

                time.sleep(interval)

        if settled:
            return
def throughput_test(
    server_args: ServerArgs,
    bench_args: BenchArgs,
):
    """Run the offline throughput benchmark end to end.

    Steps:
        1. Create the backend selected by ``bench_args.backend`` from
           ``server_args``.
        2. Load the dataset.
        3. Unless ``--skip-warmup`` is set, run a small warmup batch.
        4. Run the measured batch.
        5. Shut the backend down. This also happens if any earlier step raises.
        6. If ``bench_args.result_filename`` is set, append the result to it
           as a JSON line.
        7. Print a summary table and return the result dict produced by
           ``throughput_test_once``.

    Raises:
        ValueError: If ``bench_args.backend`` is neither "engine" nor "runtime".
    """
    if bench_args.backend == "engine":
        backend = Engine(**dataclasses.asdict(server_args))
    elif bench_args.backend == "runtime":
        backend = Runtime(**dataclasses.asdict(server_args))
    else:
        raise ValueError('Please set backend to either "engine" or "runtime"')

    # Always release the backend, even when dataset loading or generation fails.
    try:
        tokenizer_id = server_args.tokenizer_path or server_args.model_path
        tokenizer = get_tokenizer(tokenizer_id)

        # Set global environments
        set_ulimit()
        random.seed(bench_args.seed)
        np.random.seed(bench_args.seed)

        # Parse args. Read from bench_args: the module-level `args` only exists
        # when this file runs as a script.
        extra_request_body = {}
        if bench_args.extra_request_body:
            extra_request_body = json.loads(bench_args.extra_request_body)

        # Read dataset
        input_requests = get_dataset(bench_args, tokenizer)

        # Warm up with up to 16 short random requests (sampled only when used).
        if not bench_args.skip_warmup:
            warmup_requests = sample_random_requests(
                input_len=256,
                output_len=16,
                num_prompts=min(bench_args.num_prompts, 16),
                range_ratio=1.0,
                tokenizer=tokenizer,
                dataset_path=bench_args.dataset_path,
            )
            logging.info("\nWarmup...")
            throughput_test_once(
                backend_name=bench_args.backend,
                backend=backend,
                reqs=warmup_requests,
                ignore_eos=not bench_args.disable_ignore_eos,
                extra_request_body=extra_request_body,
                profile=False,
            )
            time.sleep(0.5)

        logging.info("\nBenchmark...")
        result = throughput_test_once(
            backend_name=bench_args.backend,
            backend=backend,
            reqs=input_requests,
            ignore_eos=not bench_args.disable_ignore_eos,
            extra_request_body=extra_request_body,
            profile=bench_args.profile,
        )
    finally:
        backend.shutdown()

    if bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    print(
        "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
    )
    print("{:<40} {:<10}".format("Backend:", result["backend"]))
    print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
    print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
    print(
        "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Last generation throughput (tok/s):", result["last_gen_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", result["request_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", result["input_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", result["output_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", result["total_throughput"]
        )
    )
    print("=" * 50)

    return result
if __name__ == "__main__":
    # Server and benchmark flags share one parser; each dataclass picks out
    # its own fields from the parsed namespace.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    throughput_test(server_args, bench_args)

    # --do-not-exit keeps the process alive (e.g. for nsys --duration/--delay).
    # Sleep instead of a bare `pass` loop so the wait does not pin a CPU core.
    while bench_args.do_not_exit:
        time.sleep(1)