# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py

"""
Benchmark online serving with dynamic requests.

Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompts 10

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
"""

import argparse
import asyncio
import json
import os
import random
import sys
import time
import traceback
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import aiohttp
import numpy as np
import requests
from data_processing import MsgContent, SampleOutput, get_dataset
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

from sglang.bench_serving import get_tokenizer, remove_prefix, set_ulimit

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)

global args
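# `args` is assigned by run_benchmark() (and by the __main__ block below); the
# request functions read benchmark flags such as --disable-stream from it.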


@dataclass
class RequestFuncInput:
    prompts: List[Tuple[MsgContent, int, int]]
    api_url: str
    model: str
    lora_name: str
    extra_request_body: Dict[str, Any]

    # For multiturn chat, store the context
    prev_messages: List = field(default_factory=list)
    finished_prompts: int = 0
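    # `prompts` holds one (content, input_len, max_output_len) tuple per turn;
    # `finished_prompts` is the index of the next turn to send, and `prev_messages`
    # carries the accumulated chat history so a partially finished conversation can
    # be re-enqueued for its next turn.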


@dataclass
class RequestFuncOutput:
    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies

    success: bool = False
    error: str = ""
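    # Each of the list fields above gains one entry per completed turn, except `itl`,
    # which grows by one entry per streamed token after the first.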


# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    queue: asyncio.Queue,
    tokenizer: PreTrainedTokenizerBase,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model,
            "temperature": 0.0,
            "best_of": 1,
            "stream": not args.disable_stream,
            "stream_options": {"include_usage": True},
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()

        prompt_idx = request_func_input.finished_prompts
        messages = request_func_input.prev_messages
        prompt, input_len, max_tokens = request_func_input.prompts[prompt_idx]
        prompt_len = sum(
            prompt[1] + prompt[2]  # input_len + output_len
            for prompt in request_func_input.prompts[:prompt_idx]
        )
        prompt_len += input_len
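        # prompt_len now approximates the full context for this turn: the input and
        # output tokens of every previous turn plus the current user input.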

        # Messages
        messages.append(
            {
                "role": "user",
                "content": prompt,
            }
        )
        payload["messages"] = messages
        payload["max_tokens"] = max_tokens

        # output.prompt_len = request_func_input.prompt_len
        # print(payload)

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    actual_prompt_len = prompt_len - 1
                    actual_output_len = 0
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)
                            timestamp = time.perf_counter()
                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["usage"] is not None and len(data["usage"]) > 0:
                                actual_prompt_len = data["usage"]["prompt_tokens"]
                                actual_output_len = data["usage"]["completion_tokens"]
                                continue
                            delta = data["choices"][0]["delta"]

                            if delta.get("content", None):
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft.append(ttft)

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                generated_text += delta["content"]
                                most_recent_timestamp = timestamp

                    output.prompt_len.append(actual_prompt_len)  # truncate <s>
                    output.output_len.append(actual_output_len)
                    output.generated_text.append(generated_text)
                    output.success = True
                    output.latency.append(latency)

                    # Prepare for the new request
                    request_func_input.prompts[prompt_idx] = (
                        prompt,
                        input_len,
                        actual_output_len,  # changes from max_tokens to output_len
                    )
                    prompt_idx += 1
                    messages.append(
                        {
                            "role": "assistant",
                            "content": generated_text,
                        }
                    )

                    # Move the new request to the end of the queue
                    if prompt_idx < len(request_func_input.prompts):
                        request_func_input.finished_prompts = prompt_idx
                        request_func_input.prev_messages = messages
                        await queue.put(request_func_input)
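                        # Re-enqueueing the same RequestFuncInput drives the next
                        # turn: get_requests() pulls it from the queue again (subject
                        # to the request-rate pacing) and this function then sends
                        # turn `prompt_idx` with the updated message history.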
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_profile(api_url: str) -> RequestFuncOutput:
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output


ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
}
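# All of these backends expose an OpenAI-compatible chat completions endpoint, so a
# single request function covers them.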


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    total_output_retokenized: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    output_throughput_retokenized: float
    total_throughput: float
    total_throughput_retokenized: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p90_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p90_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p90_itl_ms: float
    p99_itl_ms: float
    mean_e2e_latency_ms: float
    median_e2e_latency_ms: float
    std_e2e_latency_ms: float
    p99_e2e_latency_ms: float
    concurrency: float


async def get_requests(
    input_requests_queue: asyncio.Queue,
    request_rate: float,
    num_actual_requests: int,
) -> AsyncGenerator[RequestFuncInput, None]:
    for _ in range(num_actual_requests):
        try:
            request = await asyncio.wait_for(
                input_requests_queue.get(), timeout=300
            )  # Wait for 5 minutes, then abort
        except Exception as e:
            print(f"exception: {e}")
            break

        yield request

        if request_rate == float("inf"):
            continue

        interval = np.random.exponential(1.0 / request_rate)
        await asyncio.sleep(interval)
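        # Exponentially distributed gaps make the arrivals a Poisson process with
        # rate `request_rate`.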


def calculate_metrics(
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
    output_lens: List[int] = []
    retokenized_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    e2e_latencies: List[float] = []
    output_success = 0
    for i in range(len(outputs)):
        if outputs[i].success:
            output_success += 1
            assert len(outputs[i].generated_text) == len(outputs[i].latency)
            assert len(outputs[i].generated_text) == len(outputs[i].ttft)
            for j in range(len(outputs[i].generated_text)):
                output_len = outputs[i].output_len[j]
                output_lens.append(output_len)
                retokenized_output_len = len(
                    tokenizer.encode(
                        outputs[i].generated_text[j], add_special_tokens=False
                    )
                )
                retokenized_output_lens.append(retokenized_output_len)
                total_input += outputs[i].prompt_len[j]
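                # TPOT (time per output token) excludes the first token: decode time
                # divided by the number of tokens generated after the first one.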
                if output_len > 1:
                    tpots.append(
                        (outputs[i].latency[j] - outputs[i].ttft[j]) / (output_len - 1)
                    )

                completed += 1
            itls += outputs[i].itl
            ttfts += outputs[i].ttft
            e2e_latencies += outputs[i].latency

        else:
            output_lens.append(0)
            retokenized_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(output_lens),
        total_output_retokenized=sum(retokenized_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(output_lens) / dur_s,
        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
        total_throughput=(total_input + sum(output_lens)) / dur_s,
        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
        / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0)
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p90_ttft_ms=np.percentile(ttfts or 0, 90) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p90_tpot_ms=np.percentile(tpots or 0, 90) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p90_itl_ms=np.percentile(itls or 0, 90) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
        mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
        median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
        std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
        p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
        concurrency=np.sum(e2e_latencies) / dur_s,
    )
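    # `concurrency` above is the time-averaged number of in-flight requests: the sum
    # of per-request latencies divided by the benchmark duration.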

    return metrics, output_lens


async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: SampleOutput,
    request_rate: float,
    max_concurrency: Optional[int],
    disable_tqdm: bool,
    lora_name: str,
    extra_request_body: Dict[str, Any],
    profile: bool,
    enable_shared_prefix: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    # Limit concurrency
    # From https://github.com/vllm-project/vllm/pull/9390
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, queue, tokenizer, pbar):
        if semaphore is None:
            return await request_func(
                request_func_input=request_func_input,
                queue=queue,
                tokenizer=tokenizer,
                pbar=pbar,
            )
        async with semaphore:
            return await request_func(
                request_func_input=request_func_input,
                queue=queue,
                tokenizer=tokenizer,
                pbar=pbar,
            )

    num_actual_requests = sum(len(r) for r in input_requests)
    print(f"Num of shared prefixes or conversations: {len(input_requests)}")
    print(f"Num of total requests: {num_actual_requests}")

    # flatten the requests for shared prefix
    if enable_shared_prefix:
        input_requests = [[r] for requests in input_requests for r in requests]
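        # Each request becomes its own single-turn "conversation", so prompts that
        # share a prefix are still issued as independent requests.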
    inputs_requests_queue = asyncio.Queue(maxsize=len(input_requests))
    print("Starting initial single prompt test run...")
    # NOTE: Just use the first request of the first conversation for warmup
    test_input = RequestFuncInput(
        model=model_id,
        prompts=input_requests[0][:1],
        api_url=api_url,
        lora_name=lora_name,
        extra_request_body=extra_request_body,
    )
    test_output = await request_func(
        request_func_input=test_input, queue=inputs_requests_queue, tokenizer=tokenizer
    )
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    # Check the states
    assert inputs_requests_queue.empty()

    # Flush cache
    if "sglang" in backend:
        requests.post(base_url + "/flush_cache")

    time.sleep(1.0)

    # Start profiler
    if profile:
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=base_url + "/start_profile"
        )
        if profile_output.success:
            print("Profiler started")

    for request in input_requests:
        request_func_input = RequestFuncInput(
            model=model_id,
            prompts=request,
            api_url=api_url,
            lora_name=lora_name,
            extra_request_body=extra_request_body,
        )
        inputs_requests_queue.put_nowait(request_func_input)
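    # The queue starts with one entry per conversation (each entry carrying all of
    # its turns); later turns are re-enqueued by the request function as earlier
    # turns complete.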
    if (
        not args.enable_multiturn
        and not args.enable_shared_prefix
        and not args.dataset_name == "generated-shared-prefix"
    ):
        assert len(input_requests) == num_actual_requests

    pbar = None if disable_tqdm else tqdm(total=num_actual_requests)

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_requests(
        inputs_requests_queue, request_rate, num_actual_requests
    ):
        tasks.append(
            asyncio.create_task(
                limited_request_func(
                    request_func_input=request,
                    queue=inputs_requests_queue,
                    tokenizer=tokenizer,
                    pbar=pbar,
                )
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    # Stop profiler
    if profile:
        print("Stopping profiler...")
        profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
        if profile_output.success:
            print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    # Compute metrics and print results
    benchmark_duration = time.perf_counter() - benchmark_start_time
    metrics, output_lens = calculate_metrics(
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
    print(
        "{:<40} {:<10}".format(
            "Max request concurrency:",
            max_concurrency if max_concurrency else "not set",
        )
    )
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", metrics.total_throughput
        )
    )
    print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P90 TTFT (ms):", metrics.p90_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P90 TPOT (ms):", metrics.p90_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P90 ITL (ms):", metrics.p90_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    if (
        metrics.median_ttft_ms is not None
        and metrics.mean_itl_ms is not None
        and metrics.output_throughput is not None
    ):
        result = {
            # Arguments
            "backend": args.backend,
            "dataset_name": args.dataset_name,
            "request_rate": request_rate,
            "max_concurrency": max_concurrency,
            "fixed_output_len": args.fixed_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            # Results
            "duration": benchmark_duration,
            "completed": metrics.completed,
            "total_input_tokens": metrics.total_input,
            "total_output_tokens": metrics.total_output,
            "total_output_tokens_retokenized": metrics.total_output_retokenized,
            "request_throughput": metrics.request_throughput,
            "input_throughput": metrics.input_throughput,
            "output_throughput": metrics.output_throughput,
            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
            "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
            "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
            "mean_ttft_ms": metrics.mean_ttft_ms,
            "median_ttft_ms": metrics.median_ttft_ms,
            "std_ttft_ms": metrics.std_ttft_ms,
            "p99_ttft_ms": metrics.p99_ttft_ms,
            "mean_tpot_ms": metrics.mean_tpot_ms,
            "median_tpot_ms": metrics.median_tpot_ms,
            "std_tpot_ms": metrics.std_tpot_ms,
            "p99_tpot_ms": metrics.p99_tpot_ms,
            "mean_itl_ms": metrics.mean_itl_ms,
            "median_itl_ms": metrics.median_itl_ms,
            "std_itl_ms": metrics.std_itl_ms,
            "p99_itl_ms": metrics.p99_itl_ms,
            "concurrency": metrics.concurrency,
        }
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
        print("-" * 30)

    # Determine output file name
    if args.output_file:
        output_file_name = args.output_file
    else:
        now = datetime.now().strftime("%m%d")
        if args.dataset_name == "random":
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
        else:
            output_file_name = (
                f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
            )

    # Append results to a JSONL file
    with open(output_file_name, "a") as file:
        file.write(json.dumps(result) + "\n")

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "total_output_tokens_retokenized": metrics.total_output_retokenized,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p90_ttft_ms": metrics.p90_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p90_tpot_ms": metrics.p90_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p90_itl_ms": metrics.p90_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
    }
    return result


def run_benchmark(args_: argparse.Namespace):
    global args
    args = args_

    # Set default value for max_concurrency if not present
    if not hasattr(args, "max_concurrency"):
        args.max_concurrency = None

    # Set global environments
    set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

    extra_request_body = {}
    if args.extra_request_body:
        extra_request_body = json.loads(args.extra_request_body)
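    # Example (values are illustrative): --extra-request-body '{"top_p": 0.9}'
    # merges {"top_p": 0.9} into every request payload sent to the server.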

    # Set url
    if args.port is None:
        args.port = {
            "sglang": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
        }.get(args.backend, 30000)

    model_url = (
        f"{args.base_url}/v1/models"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend in ["sglang", "vllm", "lmdeploy"]:
        api_url = (
            f"{args.base_url}/v1/chat/completions"
            if args.base_url
            else f"http://{args.host}:{args.port}/v1/chat/completions"
        )
        base_url = (
            f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
        )

    # Get model name
    if args.model is None:
        if args.backend == "truss":
            print(
                "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
            )
            sys.exit(1)
        try:
            response = requests.get(model_url)
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print(
                "Please specify the correct host and port using `--host` and `--port`."
            )
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    # Dataset compatibility check
    if args.enable_multiturn:
        # TODO: Support multiturn for random
        if args.dataset_name not in ["sharegpt", "ultrachat", "loogle", "nextqa"]:
            print(
                "Multiturn conversation is only supported for sharegpt, ultrachat, loogle, and nextqa datasets."
            )
            sys.exit(1)

    if args.enable_shared_prefix:
        if args.dataset_name not in ["loogle", "nextqa"]:
            print("Shared prefix is only supported for loogle and nextqa datasets.")
            sys.exit(1)

    print(f"{args}\n")

    # Read dataset
    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    tokenizer = get_tokenizer(tokenizer_id)

    input_requests = get_dataset(args, tokenizer)

    return asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            max_concurrency=args.max_concurrency,
            disable_tqdm=args.disable_tqdm,
            lora_name=args.lora_name,
            extra_request_body=extra_request_body,
            profile=args.profile,
            enable_shared_prefix=args.enable_shared_prefix,
        )
    )


if __name__ == "__main__":
    parser = ArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="Which LLM inference engine to benchmark against.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, a default port is chosen based on the backend.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=[
            "sharegpt",
            "random",
            "generated-shared-prefix",
            "ultrachat",
            "loogle",
            "nextqa",
        ],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path", type=str, default="", help="Path to the dataset."
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, the model name is fetched from the server's /v1/models endpoint.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer. If not set, the tokenizer of --model is used.",
    )
    parser.add_argument(
        "--chat-template",
        type=str,
        help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API servers.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process. Default is 1000.",
    )
    parser.add_argument(
        "--fixed-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length from the dataset.",
    )
    parser.add_argument(
        "--sharegpt-context-len",
        type=int,
        default=None,
        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        default=1024,
        type=int,
        help="Number of output tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use a Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=None,
        help="Maximum number of concurrent requests. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )
    parser.add_argument(
        "--multi",
        action="store_true",
        help="Use a request rate range rather than a single value.",
    )
    parser.add_argument(
        "--request-rate-range",
        type=str,
        default="2,34,2",
        help="Range of request rates in the format start,stop,step. Default is 2,34,2. "
        "An explicit comma-separated list of request rates is also accepted, as long as it does not contain exactly three values.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument(
        "--enable-multiturn",
        action="store_true",
        help="Enable multiturn chat for online serving benchmarking. "
        "This option is effective on the following datasets: "
        "sharegpt, ultrachat, loogle, nextqa",
    )
    parser.add_argument(
        "--enable-shared-prefix",
        action="store_true",
        help="Enable shared prefix for online serving benchmarking. "
        "This option is effective on the following datasets: "
        "loogle, nextqa",
    )

    parser.add_argument(
        "--disable-shuffle",
        action="store_true",
        help="Disable shuffling datasets. This is useful to generate stable output "
        "in benchmarking.",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--disable-stream",
        action="store_true",
        help="Disable streaming mode.",
    )
    parser.add_argument(
        "--return-logprob",
        action="store_true",
        help="Return logprob.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--disable-ignore-eos",
        action="store_true",
        help="Disable ignoring EOS.",
    )
    parser.add_argument(
        "--extra-request-body",
        metavar='{"key1": "value1", "key2": "value2"}',
        type=str,
        help="Append given JSON object to the request payload. You can use this to specify "
        "additional generate params like sampling params.",
    )
    parser.add_argument(
        "--apply-chat-template",
        action="store_true",
        help="Apply chat template",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--lora-name",
        type=str,
        default=None,
        help="The name of the LoRA adapter.",
    )

    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
    group.add_argument(
        "--gsp-num-groups",
        type=int,
        default=64,
        help="Number of system prompt groups for generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-prompts-per-group",
        type=int,
        default=16,
        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-system-prompt-len",
        type=int,
        default=2048,
        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-question-len",
        type=int,
        default=128,
        help="Target length in tokens for questions in generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-output-len",
        type=int,
        default=256,
        help="Target length in tokens for outputs in generated-shared-prefix dataset",
    )
    # videos specific
    parser.add_argument(
        "--max-frames",
        type=int,
        default=sys.maxsize,
        help="The maximum number of frames to extract from each video. "
        "This option is specific to the nextqa dataset (video benchmark).",
    )
    args = parser.parse_args()

    if args.enable_multiturn and args.enable_shared_prefix:
        parser.error(
            "--enable-multiturn and --enable-shared-prefix cannot be set at the same time."
        )

    run_benchmark(args)