# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py

"""
Benchmark online serving with dynamic requests.

Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompts 10

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input-len 1024 --random-output-len 1024 --random-range-ratio 0.5
"""

import argparse
import asyncio
import json
import os
import pickle
import random
import resource
import sys
import time
import traceback
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
ASSISTANT_SUFFIX = "Assistant:"

global args


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    lora_name: str
    extra_request_body: Dict[str, Any]


@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    error: str = ""
    output_len: int = 0


def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text


def remove_suffix(text: str, suffix: str) -> str:
    return text[: -len(suffix)] if text.endswith(suffix) else text


def get_auth_headers() -> Dict[str, str]:
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        return {"Authorization": f"Bearer {api_key}"}
    else:
        return {}


# trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.000001,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            "min_length": request_func_input.output_len,
            "end_id": 1048576,
            **request_func_input.extra_request_body,
        }
        if args.disable_ignore_eos:
            del payload["min_length"]
            del payload["end_id"]
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    output.output_len = request_func_input.output_len

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


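# Streaming request against an OpenAI-compatible /v1/completions endpoint
# (used for the sglang-oai, vllm, and lmdeploy backends). TTFT is taken at the
# first streamed token; later chunks contribute to the inter-token latency list.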
# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    prompt = request_func_input.prompt

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]
                            output_len = (data.get("usage") or {}).get(
                                "completion_tokens", output_len
                            )

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


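# Truss backend: same OpenAI-style streaming protocol, but tokens arrive as
# chat-style deltas (choices[0].delta.content) rather than completion text.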
async def async_request_truss(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url

    prompt = request_func_input.prompt

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["delta"]["content"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["delta"]["content"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


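# Native SGLang /generate endpoint. A streamed chunk may carry several new
# tokens, so the elapsed time is split evenly across them when recording
# inter-token latencies.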
async def async_request_sglang_generate(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    prompt = request_func_input.prompt

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "text": prompt,
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": request_func_input.output_len,
                "ignore_eos": not args.disable_ignore_eos,
            },
            "stream": not args.disable_stream,
            "lora_path": request_func_input.lora_name,
            "return_logprob": args.return_logprob,
            "logprob_start_len": -1,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        last_output_len = 0
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        # print(chunk_bytes)

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["text"]:
                                timestamp = time.perf_counter()
                                generated_text = data["text"]
                                output_len = data["meta_info"]["completion_tokens"]

                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    num_new_tokens = output_len - last_output_len
                                    if num_new_tokens == 0:
                                        continue
                                    adjust_itl = (
                                        timestamp - most_recent_timestamp
                                    ) / num_new_tokens
                                    output.itl.extend([adjust_itl] * num_new_tokens)

                                most_recent_timestamp = timestamp
                                last_output_len = output_len

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))
            print(f"{output.error=}")

    if pbar:
        pbar.update(1)
    return output


async def async_request_gserver(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    raise NotImplementedError()


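# Fire a bare POST to a profiler control endpoint (/start_profile or
# /stop_profile) and report whether the server accepted it.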
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output


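# Resolve the model path, downloading from ModelScope instead of Hugging Face
# when SGLANG_USE_MODELSCOPE=true (weight files are skipped via
# ignore_file_pattern; only configs and tokenizer files are needed here).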
def get_model(pretrained_model_name_or_path: str) -> str:
    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() == "true":
        import huggingface_hub.constants
        from modelscope import snapshot_download

        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
        )

        return model_path
    return pretrained_model_name_or_path


def get_tokenizer(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    if pretrained_model_name_or_path.endswith(
        ".json"
    ) or pretrained_model_name_or_path.endswith(".model"):
        from sglang.srt.hf_transformers_utils import get_tokenizer

        return get_tokenizer(pretrained_model_name_or_path)

    if pretrained_model_name_or_path is not None and not os.path.exists(
        pretrained_model_name_or_path
    ):
        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
    return AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=True
    )


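# Dispatch to the sampler that matches --dataset-name and return a list of
# (prompt, prompt_len, output_len) tuples.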
def get_dataset(args, tokenizer):
    if args.dataset_name == "sharegpt":
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
            context_len=args.sharegpt_context_len,
            prompt_suffix=args.prompt_suffix,
            apply_chat_template=args.apply_chat_template,
        )
    elif args.dataset_name == "random":
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
        )
    elif args.dataset_name == "generated-shared-prefix":
        input_requests = sample_generated_shared_prefix_requests(
            num_groups=args.gsp_num_groups,
            prompts_per_group=args.gsp_prompts_per_group,
            system_prompt_len=args.gsp_system_prompt_len,
            question_len=args.gsp_question_len,
            output_len=args.gsp_output_len,
            tokenizer=tokenizer,
            args=args,
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    return input_requests


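# Map each supported --backend value to its request coroutine.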
ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_sglang_generate,
    "sglang-native": async_request_sglang_generate,
    "sglang-oai": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "trt": async_request_trt_llm,
    "gserver": async_request_gserver,
    "truss": async_request_truss,
}


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    total_output_retokenized: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    output_throughput_retokenized: float
    total_throughput: float
    total_throughput_retokenized: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p95_itl_ms: float
    p99_itl_ms: float
    max_itl_ms: float
    mean_e2e_latency_ms: float
    median_e2e_latency_ms: float
    std_e2e_latency_ms: float
    p99_e2e_latency_ms: float
    concurrency: float


SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"


def download_and_cache_file(url: str, filename: Optional[str] = None):
    """Read and cache a file from a url."""
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Check if the cache file already exists
    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")

    # Stream the response to show the progress bar
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for request errors

    # Total size of the file in bytes
    total_size = int(response.headers.get("content-length", 0))
    chunk_size = 1024  # Download in chunks of 1KB

    # Use tqdm to display the progress bar
    with open(filename, "wb") as f, tqdm(
        desc=filename,
        total=total_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            bar.update(len(chunk))

    return filename


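# Sample (prompt, prompt_len, output_len) tuples from the ShareGPT dataset,
# optionally applying the chat template and pruning sequences that are too
# short or that exceed the given context length.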
def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    context_len: Optional[int] = None,
    prompt_suffix: Optional[str] = "",
    apply_chat_template=False,
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path) and dataset_path == "":
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)

    # Filter out the conversations with less than 2 turns.
    dataset = [
        data
        for data in dataset
        if len(data.get("conversations", data.get("conversation", []))) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (
            data.get("conversations", data.get("conversation", []))[0]["value"],
            data.get("conversations", data.get("conversation", []))[1]["value"],
        )
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        if prompt_suffix:
            prompt = (
                remove_suffix(prompt, ASSISTANT_SUFFIX)
                + prompt_suffix
                + ASSISTANT_SUFFIX
            )

        if apply_chat_template:
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )
            prompt = prompt.replace(tokenizer.bos_token, "")

        prompt_token_ids = tokenizer.encode(prompt)
        completion = dataset[i][1]
        completion_token_ids = tokenizer.encode(completion)
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )

        if prompt_len < 2 or output_len < 2:
            # Prune too short sequences.
            continue

        if context_len and prompt_len + output_len > context_len:
            # Prune too long sequences.
            continue

        filtered_dataset.append((prompt, prompt_len, output_len))

    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
    return filtered_dataset


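# Build random-length requests: input and output lengths are drawn uniformly
# from [len * range_ratio, len], and prompt text is taken from ShareGPT token
# ids repeated or truncated to hit the sampled input length.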
def sample_random_requests(
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
) -> List[Tuple[str, int, int]]:

    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )

    if True:
        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

        # Download sharegpt if necessary
        if not os.path.isfile(dataset_path):
            dataset_path = download_and_cache_file(SHAREGPT_URL)

        # Load the dataset.
        with open(dataset_path) as f:
            dataset = json.load(f)
        # Filter out the conversations with less than 2 turns.
        dataset = [
            data
            for data in dataset
            if len(data.get("conversations", data.get("conversation", []))) >= 2
        ]
        # Only keep the first two turns of each conversation.
        dataset = [
            (
                data.get("conversations", data.get("conversation", []))[0]["value"],
                data.get("conversations", data.get("conversation", []))[1]["value"],
            )
            for data in dataset
        ]
        # Shuffle the dataset.
        random.shuffle(dataset)

        # Filter out sequences that are too long or too short
        input_requests: List[Tuple[str, int, int]] = []
        for data in dataset:
            i = len(input_requests)
            if i == num_prompts:
                break

            # Tokenize the prompts and completions.
            prompt = data[0]
            prompt_token_ids = tokenizer.encode(prompt)
            prompt_len = len(prompt_token_ids)

            # Skip empty prompt
            if prompt_len == 0:
                continue

            if prompt_len > input_lens[i]:
                input_ids = prompt_token_ids[: input_lens[i]]
            else:
                ratio = (input_lens[i] + prompt_len - 1) // prompt_len
                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
            prompt = tokenizer.decode(input_ids)
            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
    else:
        # Sample token ids from random integers. This can cause some NaN issues.
        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
        input_requests = []
        for i in range(num_prompts):
            prompt = tokenizer.decode(
                [
                    (offsets[i] + i + j) % tokenizer.vocab_size
                    for j in range(input_lens[i])
                ]
            )
            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

    print(f"#Input tokens: {np.sum(input_lens)}")
    print(f"#Output tokens: {np.sum(output_lens)}")
    return input_requests


def gen_prompt(tokenizer, token_num):
    """Generate a random prompt of specified token length using tokenizer vocabulary."""
    all_available_tokens = list(tokenizer.get_vocab().values())
    selected_tokens = random.choices(all_available_tokens, k=token_num)
    return tokenizer.decode(selected_tokens)


def get_gen_prefix_cache_path(args, tokenizer):
    """Create cache directory under ~/.cache/sglang/benchmark"""
    cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"

    # Create a unique cache filename based on the generation parameters
    cache_key = (
        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
        f"{tokenizer.__class__.__name__}.pkl"
    )
    return cache_dir / cache_key


def sample_generated_shared_prefix_requests(
    num_groups: int,
    prompts_per_group: int,
    system_prompt_len: int,
    question_len: int,
    output_len: int,
    tokenizer: PreTrainedTokenizerBase,
    args: argparse.Namespace,
) -> List[Tuple[str, int, int]]:
    """Generate benchmark requests with shared system prompts using random tokens and caching."""
    cache_path = get_gen_prefix_cache_path(args, tokenizer)

    # Try to load from cache first
    if cache_path.exists():
        print(f"\nLoading cached generated input data from {cache_path}")
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    print("\nGenerating new input data...")

    # Generate system prompts for each group
    system_prompts = []
    for _ in range(num_groups):
        system_prompt = gen_prompt(tokenizer, system_prompt_len)
        system_prompts.append(system_prompt)

    # Generate questions
    questions = []
    for _ in range(num_groups * prompts_per_group):
        question = gen_prompt(tokenizer, question_len)
        questions.append(question)

    # Combine system prompts with questions
    input_requests = []
    total_input_tokens = 0
    total_output_tokens = 0

    for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
        system_prompt = system_prompts[group_idx]
        for prompt_idx in tqdm(
            range(prompts_per_group), desc="Generating questions", leave=False
        ):
            question = questions[group_idx * prompts_per_group + prompt_idx]
            full_prompt = f"{system_prompt}\n\n{question}"
            prompt_len = len(tokenizer.encode(full_prompt))

            input_requests.append((full_prompt, prompt_len, output_len))
            total_input_tokens += prompt_len
            total_output_tokens += output_len

    # Shuffle questions
    random.shuffle(input_requests)

    # Print statistics
    print(f"\nGenerated shared prefix dataset statistics:")
    print(f"Number of groups: {num_groups}")
    print(f"Prompts per group: {prompts_per_group}")
    print(f"Total prompts: {len(input_requests)}")
    print(f"Total input tokens: {total_input_tokens}")
    print(f"Total output tokens: {total_output_tokens}")
    print(
        f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
    )
    print(
        f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
    )

    # Save to cache
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Caching generated input data to {cache_path}")
    with open(cache_path, "wb") as f:
        pickle.dump(input_requests, f)

    return input_requests


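# Yield requests one by one; with a finite request rate, the gap before the
# next request is drawn from an exponential distribution (Poisson arrivals).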
async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


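# Aggregate per-request outputs into BenchmarkMetrics: throughputs plus TTFT,
# TPOT, ITL, and end-to-end latency statistics. Failed requests count as
# zero-length outputs.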
def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
    output_lens: List[int] = []
    retokenized_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    e2e_latencies: List[float] = []
    for i in range(len(outputs)):
        if outputs[i].success:
            output_len = outputs[i].output_len
            output_lens.append(output_len)
            retokenized_output_len = len(
                tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
            )
            retokenized_output_lens.append(retokenized_output_len)
            total_input += input_requests[i][1]
            if output_len > 1:
                tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)

            e2e_latencies.append(outputs[i].latency)

            completed += 1
        else:
            output_lens.append(0)
            retokenized_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(output_lens),
        total_output_retokenized=sum(retokenized_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(output_lens) / dur_s,
        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
        total_throughput=(total_input + sum(output_lens)) / dur_s,
        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
        / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0)
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
        max_itl_ms=np.max(itls or 0) * 1000,
        mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
        median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
        std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
        p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
        concurrency=np.sum(e2e_latencies) / dur_s,
    )

    return metrics, output_lens


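# Main benchmark driver: warm up with a single request, flush the cache on
# SGLang backends, optionally start the profiler, dispatch all requests
# (rate- and concurrency-limited), then compute, print, and dump the metrics.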
async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    max_concurrency: Optional[int],
    disable_tqdm: bool,
    lora_names: List[str],
    extra_request_body: Dict[str, Any],
    profile: bool,
    pd_seperated: bool = False,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    # Limit concurrency
    # From https://github.com/vllm-project/vllm/pull/9390
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, pbar):
        if semaphore is None:
            return await request_func(request_func_input=request_func_input, pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input, pbar=pbar)

    # Warmup
    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    if lora_names is not None and len(lora_names) != 0:
        lora_name = lora_names[0]
    else:
        lora_name = None

    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=min(test_output_len, 32),
        lora_name=lora_name,
        extra_request_body=extra_request_body,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    # Flush cache
    if "sglang" in backend:
        requests.post(base_url + "/flush_cache", headers=get_auth_headers())

    time.sleep(1.0)

    # Start profiler
    if profile:
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=base_url + "/start_profile"
        )
        if profile_output.success:
            print("Profiler started")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    # Run all requests
    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        if lora_names is not None and len(lora_names) != 0:
            idx = random.randint(0, len(lora_names) - 1)
            lora_name = lora_names[idx]
        else:
            lora_name = None

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            lora_name=lora_name,
            extra_request_body=extra_request_body,
        )
        tasks.append(
            asyncio.create_task(
                limited_request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    # Stop profiler
    if profile:
        print("Stopping profiler...")
        profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
        if profile_output.success:
            print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    if "sglang" in backend:
        server_info = requests.get(base_url + "/get_server_info")
        if pd_seperated:
            accept_length = server_info.json()["decode"][0].get(
                "avg_spec_accept_length", None
            )
        else:
            accept_length = server_info.json().get("avg_spec_accept_length", None)
    else:
        accept_length = None

    # Compute metrics and print results
    benchmark_duration = time.perf_counter() - benchmark_start_time
    metrics, output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
    )

print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
|
print("{:<40} {:<10}".format("Backend:", backend))
|
|
print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
|
|
print(
|
|
"{:<40} {:<10}".format(
|
|
"Max reqeuest concurrency:",
|
|
max_concurrency if max_concurrency else "not set",
|
|
)
|
|
)
|
|
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
|
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
|
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
|
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
|
print(
|
|
"{:<40} {:<10}".format(
|
|
"Total generated tokens (retokenized):", metrics.total_output_retokenized
|
|
)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Request throughput (req/s):", metrics.request_throughput
|
|
)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Input token throughput (tok/s):", metrics.input_throughput
|
|
)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Output token throughput (tok/s):", metrics.output_throughput
|
|
)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Total token throughput (tok/s):", metrics.total_throughput
|
|
)
|
|
)
|
|
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
|
|
if accept_length:
|
|
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
|
|
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
|
print(
|
|
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Median E2E Latency (ms):", metrics.median_e2e_latency_ms
|
|
)
|
|
)
|
|
print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
|
|
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
|
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
|
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
|
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
|
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
|
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
|
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
|
|
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
|
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
|
|
print("=" * 50)
|
|
|
|
if (
|
|
metrics.median_ttft_ms is not None
|
|
and metrics.mean_itl_ms is not None
|
|
and metrics.output_throughput is not None
|
|
):
|
|
result = {
|
|
# Arguments
|
|
"backend": args.backend,
|
|
"dataset_name": args.dataset_name,
|
|
"request_rate": request_rate,
|
|
"max_concurrency": max_concurrency,
|
|
"sharegpt_output_len": args.sharegpt_output_len,
|
|
"random_input_len": args.random_input_len,
|
|
"random_output_len": args.random_output_len,
|
|
"random_range_ratio": args.random_range_ratio,
|
|
# Results
|
|
"duration": benchmark_duration,
|
|
"completed": metrics.completed,
|
|
"total_input_tokens": metrics.total_input,
|
|
"total_output_tokens": metrics.total_output,
|
|
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
|
"request_throughput": metrics.request_throughput,
|
|
"input_throughput": metrics.input_throughput,
|
|
"output_throughput": metrics.output_throughput,
|
|
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
|
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
|
"std_e2e_latency_ms": metrics.std_e2e_latency_ms,
|
|
"p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
|
|
"mean_ttft_ms": metrics.mean_ttft_ms,
|
|
"median_ttft_ms": metrics.median_ttft_ms,
|
|
"std_ttft_ms": metrics.std_ttft_ms,
|
|
"p99_ttft_ms": metrics.p99_ttft_ms,
|
|
"mean_tpot_ms": metrics.mean_tpot_ms,
|
|
"median_tpot_ms": metrics.median_tpot_ms,
|
|
"std_tpot_ms": metrics.std_tpot_ms,
|
|
"p99_tpot_ms": metrics.p99_tpot_ms,
|
|
"mean_itl_ms": metrics.mean_itl_ms,
|
|
"median_itl_ms": metrics.median_itl_ms,
|
|
"std_itl_ms": metrics.std_itl_ms,
|
|
"p95_itl_ms": metrics.p95_itl_ms,
|
|
"p99_itl_ms": metrics.p99_itl_ms,
|
|
"concurrency": metrics.concurrency,
|
|
"accept_length": accept_length,
|
|
}
|
|
else:
|
|
print(f"Error running benchmark for request rate: {request_rate}")
|
|
print("-" * 30)
|
|
|
|
# Determine output file name
|
|
if args.output_file:
|
|
output_file_name = args.output_file
|
|
else:
|
|
now = datetime.now().strftime("%m%d")
|
|
if args.dataset_name == "random":
|
|
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
|
else:
|
|
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
|
|
|
|
# Append results to a JSONL file
|
|
with open(output_file_name, "a") as file:
|
|
file.write(json.dumps(result) + "\n")
|
|
|
|
result.update(
|
|
{
|
|
"input_lens": [output.prompt_len for output in outputs],
|
|
"output_lens": output_lens,
|
|
"ttfts": [output.ttft for output in outputs],
|
|
"itls": [output.itl for output in outputs],
|
|
"generated_texts": [output.generated_text for output in outputs],
|
|
"errors": [output.error for output in outputs],
|
|
}
|
|
)
|
|
return result
|
|
|
|
|
|
def check_chat_template(model_path):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        return "chat_template" in tokenizer.init_kwargs
    except Exception as e:
        print(f"Failed to load the tokenizer config. Error: {e}")
        return False


def set_global_args(args_: argparse.Namespace):
    """Set the global args."""
    global args
    args = args_


def run_benchmark(args_: argparse.Namespace):
    global args
    args = args_

    # Set default value for max_concurrency if not present
    if not hasattr(args, "max_concurrency"):
        args.max_concurrency = None

    print(f"benchmark_args={args}")

    # Set global environments
    set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

    extra_request_body = {}
    if args.extra_request_body:
        extra_request_body = json.loads(args.extra_request_body)

    # Set url
    if args.port is None:
        args.port = {
            "sglang": 30000,
            "sglang-native": 30000,
            "sglang-oai": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
            "trt": 8000,
            "gserver": 9988,
            "truss": 8080,
        }.get(args.backend, 30000)

    model_url = (
        f"{args.base_url}/v1/models"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend in ["sglang", "sglang-native"]:
        api_url = (
            f"{args.base_url}/generate"
            if args.base_url
            else f"http://{args.host}:{args.port}/generate"
        )
    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
        api_url = (
            f"{args.base_url}/v1/completions"
            if args.base_url
            else f"http://{args.host}:{args.port}/v1/completions"
        )
    elif args.backend == "trt":
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
            if args.base_url
            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
        )
        if args.model is None:
            print("Please provide a model using `--model` when using `trt` backend.")
            sys.exit(1)
    elif args.backend == "gserver":
        api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
        args.model = args.model or "default"
    elif args.backend == "truss":
        api_url = (
            f"{args.base_url}/v1/models/model:predict"
            if args.base_url
            else f"http://{args.host}:{args.port}/v1/models/model:predict"
        )
    base_url = (
        f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
    )

    # Get model name
    if args.model is None:
        if args.backend == "truss":
            print(
                "Please provide a model with `--model` when using the truss backend, e.g. --model meta-llama/Llama-3.1-8B-Instruct"
            )
            sys.exit(1)
        try:
            response = requests.get(model_url, headers=get_auth_headers())
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print(
                "Please specify the correct host and port using `--host` and `--port`."
            )
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    if not check_chat_template(args.model):
        print(
            "\nWARNING: It is recommended to use a `Chat` or `Instruct` model for benchmarking.\n"
            "If the model generates gibberish, the tokenizer may count the output tokens incorrectly.\n"
        )

    print(f"{args}\n")

    # Read dataset
    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
    tokenizer = get_tokenizer(tokenizer_id)
    input_requests = get_dataset(args, tokenizer)

    return asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            max_concurrency=args.max_concurrency,
            disable_tqdm=args.disable_tqdm,
            lora_names=args.lora_name,
            extra_request_body=extra_request_body,
            profile=args.profile,
            pd_seperated=args.pd_seperated,
        )
    )


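# Raise the soft RLIMIT_NOFILE so many concurrent connections do not exhaust
# the process's file descriptors.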
def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Failed to set RLIMIT_NOFILE: {e}")


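# argparse action that collects one or more LoRA adapter names into a list.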
class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, [])
        for lora_name in values:
            getattr(namespace, self.dest).append(lora_name)


if __name__ == "__main__":
    parser = ArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="Must specify a backend, depending on the LLM Inference Engine.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "random", "generated-shared-prefix"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path", type=str, default="", help="Path to the dataset."
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, it is fetched from the server's /v1/models endpoint.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer. If not set, the model's tokenizer is used.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process. Default is 1000.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
    )
    parser.add_argument(
        "--sharegpt-context-len",
        type=int,
        default=None,
        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        default=1024,
        type=int,
        help="Number of output tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use a Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=None,
        help="Maximum number of concurrent requests. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--disable-stream",
        action="store_true",
        help="Disable streaming mode.",
    )
    parser.add_argument(
        "--return-logprob",
        action="store_true",
        help="Return logprob.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--disable-ignore-eos",
        action="store_true",
        help="Disable ignoring EOS.",
    )
    parser.add_argument(
        "--extra-request-body",
        metavar='{"key1": "value1", "key2": "value2"}',
        type=str,
        help="Append the given JSON object to the request payload. You can use this to specify "
        "additional generate params like sampling params.",
    )
    parser.add_argument(
        "--apply-chat-template",
        action="store_true",
        help="Apply chat template",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--lora-name",
        type=str,
        nargs="*",
        default=None,
        action=LoRAPathAction,
        help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
    )
    parser.add_argument(
        "--prompt-suffix",
        type=str,
        default="",
        help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
    )
    parser.add_argument(
        "--pd-seperated",
        action="store_true",
        help="Benchmark PD disaggregation server",
    )

    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
    group.add_argument(
        "--gsp-num-groups",
        type=int,
        default=64,
        help="Number of system prompt groups for generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-prompts-per-group",
        type=int,
        default=16,
        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-system-prompt-len",
        type=int,
        default=2048,
        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-question-len",
        type=int,
        default=128,
        help="Target length in tokens for questions in generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-output-len",
        type=int,
        default=256,
        help="Target length in tokens for outputs in generated-shared-prefix dataset",
    )
    args = parser.parse_args()
    run_benchmark(args)