477 lines
17 KiB
Python
477 lines
17 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import random
|
|
import resource
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from argparse import ArgumentParser
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import aiohttp
|
|
import numpy as np
|
|
from launch_server import LORA_PATH, NUM_LORAS
|
|
from tqdm.asyncio import tqdm
|
|
from transformers import PreTrainedTokenizerBase
|
|
|
|
from sglang.bench_serving import (
|
|
AIOHTTP_TIMEOUT,
|
|
RequestFuncInput,
|
|
RequestFuncOutput,
|
|
calculate_metrics,
|
|
get_request,
|
|
get_tokenizer,
|
|
remove_prefix,
|
|
sample_random_requests,
|
|
)
|
|
|
|
# Parsed CLI arguments. ``run_benchmark`` assigns this module-level name; the
# request and benchmark coroutines read it (e.g. ``args.base_only``,
# ``args.output_file``). A ``global`` statement at module scope is a no-op, so
# the name is bound explicitly here instead.
args: Optional[argparse.Namespace] = None
|
|
|
|
|
|
# set ignore_eos True by default
|
|
async def async_request_openai_completions(
|
|
request_func_input: RequestFuncInput,
|
|
pbar: Optional[tqdm] = None,
|
|
) -> RequestFuncOutput:
|
|
api_url = request_func_input.api_url
|
|
# assert api_url.endswith(
|
|
# "completions"
|
|
# ), "OpenAI Completions API URL must end with 'completions'."
|
|
|
|
prompt = request_func_input.prompt
|
|
|
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
|
# payload = {
|
|
# "model": request_func_input.model,
|
|
# "prompt": prompt,
|
|
# "temperature": 0.0,
|
|
# "best_of": 1,
|
|
# "max_tokens": request_func_input.output_len,
|
|
# "stream": not args.disable_stream,
|
|
# "ignore_eos": not args.disable_ignore_eos,
|
|
# **request_func_input.extra_request_body,
|
|
# }
|
|
# headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
|
if args.base_only:
|
|
payload = {
|
|
"text": prompt,
|
|
"sampling_params": {"max_new_tokens": request_func_input.output_len},
|
|
}
|
|
else:
|
|
payload = {
|
|
"text": prompt,
|
|
"sampling_params": {"max_new_tokens": request_func_input.output_len},
|
|
"lora_path": f"lora{random.randint(0, NUM_LORAS - 1)}",
|
|
}
|
|
headers = {"Authorization": ""}
|
|
|
|
output = RequestFuncOutput()
|
|
output.prompt_len = request_func_input.prompt_len
|
|
|
|
generated_text = ""
|
|
ttft = 0.0
|
|
st = time.perf_counter()
|
|
most_recent_timestamp = st
|
|
try:
|
|
async with session.post(
|
|
url=api_url, json=payload, headers=headers
|
|
) as response:
|
|
if response.status == 200:
|
|
async for chunk_bytes in response.content:
|
|
chunk_bytes = chunk_bytes.strip()
|
|
if not chunk_bytes:
|
|
continue
|
|
|
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
|
latency = time.perf_counter() - st
|
|
if chunk == "[DONE]":
|
|
pass
|
|
else:
|
|
data = json.loads(chunk)
|
|
|
|
# NOTE: Some completion API might have a last
|
|
# usage summary response without a token so we
|
|
# want to check a token was generated
|
|
if data["text"]:
|
|
# if data["choices"][0]["text"]:
|
|
timestamp = time.perf_counter()
|
|
# First token
|
|
if ttft == 0.0:
|
|
ttft = time.perf_counter() - st
|
|
output.ttft = ttft
|
|
|
|
# Decoding phase
|
|
else:
|
|
output.itl.append(timestamp - most_recent_timestamp)
|
|
|
|
most_recent_timestamp = timestamp
|
|
# generated_text += data["choices"][0]["text"]
|
|
generated_text += data["text"]
|
|
|
|
output.generated_text = generated_text
|
|
output.success = True
|
|
output.latency = latency
|
|
output.output_len = request_func_input.output_len
|
|
else:
|
|
output.error = response.reason or ""
|
|
output.success = False
|
|
except Exception:
|
|
output.success = False
|
|
exc_info = sys.exc_info()
|
|
output.error = "".join(traceback.format_exception(*exc_info))
|
|
|
|
if pbar:
|
|
pbar.update(1)
|
|
return output
|
|
|
|
|
|
# Registry mapping each ``--backend`` choice to the coroutine that issues a
# single request. Only the "sglang" backend is registered in this benchmark.
ASYNC_REQUEST_FUNCS = dict(sglang=async_request_openai_completions)
|
|
|
|
|
|
async def benchmark(
|
|
backend: str,
|
|
api_url: str,
|
|
model_id: str,
|
|
tokenizer: PreTrainedTokenizerBase,
|
|
input_requests: List[Tuple[str, int, int]],
|
|
request_rate: float,
|
|
disable_tqdm: bool,
|
|
extra_request_body: Dict[str, Any],
|
|
):
|
|
if backend in ASYNC_REQUEST_FUNCS:
|
|
request_func = ASYNC_REQUEST_FUNCS[backend]
|
|
else:
|
|
raise ValueError(f"Unknown backend: {backend}")
|
|
|
|
print("Starting initial single prompt test run...")
|
|
test_request = input_requests[0]
|
|
test_input = RequestFuncInput(
|
|
model=model_id,
|
|
prompt=test_request.prompt,
|
|
api_url=api_url,
|
|
prompt_len=test_request.prompt_len,
|
|
output_len=test_request.output_len,
|
|
lora_name="dummy", # the lora_name argument will not be used
|
|
image_data=None,
|
|
extra_request_body=extra_request_body,
|
|
)
|
|
test_output = await request_func(request_func_input=test_input)
|
|
if not test_output.success:
|
|
raise ValueError(
|
|
"Initial test run failed - Please make sure benchmark arguments "
|
|
f"are correctly specified. Error: {test_output.error}"
|
|
)
|
|
else:
|
|
print("Initial test run completed. Starting main benchmark run...")
|
|
|
|
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
|
|
|
benchmark_start_time = time.perf_counter()
|
|
tasks: List[asyncio.Task] = []
|
|
async for request in get_request(input_requests, request_rate):
|
|
request_func_input = RequestFuncInput(
|
|
model=model_id,
|
|
prompt=request.prompt,
|
|
api_url=api_url,
|
|
prompt_len=request.prompt_len,
|
|
output_len=request.output_len,
|
|
lora_name="dummy",
|
|
image_data=None,
|
|
extra_request_body=extra_request_body,
|
|
)
|
|
tasks.append(
|
|
asyncio.create_task(
|
|
request_func(request_func_input=request_func_input, pbar=pbar)
|
|
)
|
|
)
|
|
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
|
|
|
if pbar is not None:
|
|
pbar.close()
|
|
|
|
benchmark_duration = time.perf_counter() - benchmark_start_time
|
|
|
|
metrics, output_lens = calculate_metrics(
|
|
input_requests=input_requests,
|
|
outputs=outputs,
|
|
dur_s=benchmark_duration,
|
|
tokenizer=tokenizer,
|
|
backend=backend,
|
|
)
|
|
|
|
print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
|
print("{:<40} {:<10}".format("Backend:", backend))
|
|
print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
|
|
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
|
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
|
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
|
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
|
print(
|
|
"{:<40} {:<10}".format(
|
|
"Total generated tokens (retokenized):", metrics.total_output_retokenized
|
|
)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Request throughput (req/s):", metrics.request_throughput
|
|
)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Input token throughput (tok/s):", metrics.input_throughput
|
|
)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Output token throughput (tok/s):", metrics.output_throughput
|
|
)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput)
|
|
)
|
|
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
|
print(
|
|
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
|
)
|
|
print(
|
|
"{:<40} {:<10.2f}".format(
|
|
"Median E2E Latency (ms):", metrics.median_e2e_latency_ms
|
|
)
|
|
)
|
|
print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
|
|
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
|
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
|
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
|
print(
|
|
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
|
|
)
|
|
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
|
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
|
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
|
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
|
|
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
|
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
|
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
|
print("=" * 50)
|
|
|
|
if (
|
|
metrics.median_ttft_ms is not None
|
|
and metrics.mean_itl_ms is not None
|
|
and metrics.output_throughput is not None
|
|
):
|
|
result = {
|
|
"backend": args.backend,
|
|
"request_rate": request_rate,
|
|
"total_input_tokens": metrics.total_input,
|
|
"total_output_tokens": metrics.total_output,
|
|
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
|
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
|
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
|
"median_ttft_ms": metrics.median_ttft_ms,
|
|
"median_itl_ms": metrics.median_itl_ms,
|
|
"output_throughput": metrics.output_throughput,
|
|
"random_input_len": args.random_input_len,
|
|
"random_output_len": args.random_output_len,
|
|
"random_range_ratio": args.random_range_ratio,
|
|
"duration": benchmark_duration,
|
|
"completed": metrics.completed,
|
|
}
|
|
else:
|
|
print(f"Error running benchmark for request rate: {request_rate}")
|
|
print("-" * 30)
|
|
|
|
# Determine output file name
|
|
if args.output_file:
|
|
output_file_name = args.output_file
|
|
else:
|
|
now = datetime.now().strftime("%m%d")
|
|
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
|
|
|
# Append results to a JSONL file
|
|
with open(output_file_name, "a") as file:
|
|
file.write(json.dumps(result) + "\n")
|
|
|
|
result = {
|
|
"duration": benchmark_duration,
|
|
"completed": metrics.completed,
|
|
"total_input_tokens": metrics.total_input,
|
|
"total_output_tokens": metrics.total_output,
|
|
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
|
"request_throughput": metrics.request_throughput,
|
|
"input_throughput": metrics.input_throughput,
|
|
"output_throughput": metrics.output_throughput,
|
|
"mean_ttft_ms": metrics.mean_ttft_ms,
|
|
"median_ttft_ms": metrics.median_ttft_ms,
|
|
"std_ttft_ms": metrics.std_ttft_ms,
|
|
"p99_ttft_ms": metrics.p99_ttft_ms,
|
|
"mean_tpot_ms": metrics.mean_tpot_ms,
|
|
"median_tpot_ms": metrics.median_tpot_ms,
|
|
"std_tpot_ms": metrics.std_tpot_ms,
|
|
"p99_tpot_ms": metrics.p99_tpot_ms,
|
|
"mean_itl_ms": metrics.mean_itl_ms,
|
|
"median_itl_ms": metrics.median_itl_ms,
|
|
"std_itl_ms": metrics.std_itl_ms,
|
|
"p99_itl_ms": metrics.p99_itl_ms,
|
|
"input_lens": [output.prompt_len for output in outputs],
|
|
"output_lens": output_lens,
|
|
"ttfts": [output.ttft for output in outputs],
|
|
"itls": [output.itl for output in outputs],
|
|
"generated_texts": [output.generated_text for output in outputs],
|
|
"errors": [output.error for output in outputs],
|
|
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
|
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
|
}
|
|
return result
|
|
|
|
|
|
def run_benchmark(args_: argparse.Namespace):
|
|
global args
|
|
args = args_
|
|
|
|
# Set global environments
|
|
set_ulimit()
|
|
random.seed(args.seed)
|
|
np.random.seed(args.seed)
|
|
|
|
# Set url
|
|
if args.port is None:
|
|
args.port = {
|
|
"sglang": 30000,
|
|
}.get(args.backend, 30000)
|
|
|
|
# api_url = (
|
|
# f"{args.base_url}/v1/completions"
|
|
# if args.base_url
|
|
# else f"http://{args.host}:{args.port}/v1/completions"
|
|
# )
|
|
api_url = (
|
|
f"{args.base_url}/generate"
|
|
if args.base_url
|
|
else f"http://{args.host}:{args.port}/generate"
|
|
)
|
|
|
|
print(f"{args}\n")
|
|
|
|
# Read dataset
|
|
backend = args.backend
|
|
model_id = args.model = LORA_PATH["base"]
|
|
tokenizer_id = args.model
|
|
|
|
tokenizer = get_tokenizer(tokenizer_id)
|
|
|
|
input_requests = sample_random_requests(
|
|
input_len=args.random_input_len,
|
|
output_len=args.random_output_len,
|
|
num_prompts=args.num_prompts,
|
|
range_ratio=args.random_range_ratio,
|
|
tokenizer=tokenizer,
|
|
dataset_path="",
|
|
)
|
|
|
|
return asyncio.run(
|
|
benchmark(
|
|
backend=backend,
|
|
api_url=api_url,
|
|
model_id=model_id,
|
|
tokenizer=tokenizer,
|
|
input_requests=input_requests,
|
|
request_rate=args.request_rate,
|
|
disable_tqdm=False,
|
|
extra_request_body={},
|
|
)
|
|
)
|
|
|
|
|
|
def set_ulimit(target_soft_limit=65535):
|
|
resource_type = resource.RLIMIT_NOFILE
|
|
current_soft, current_hard = resource.getrlimit(resource_type)
|
|
|
|
if current_soft < target_soft_limit:
|
|
try:
|
|
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
|
|
except ValueError as e:
|
|
print(f"Fail to set RLIMIT_NOFILE: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = ArgumentParser(description="Benchmark the online lora serving throughput.")
|
|
parser.add_argument(
|
|
"--backend",
|
|
type=str,
|
|
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
|
default="sglang",
|
|
help="Must specify a backend, depending on the LLM Inference Engine.",
|
|
)
|
|
parser.add_argument(
|
|
"--base-url",
|
|
type=str,
|
|
default=None,
|
|
help="Server or API base url if not using http host and port.",
|
|
)
|
|
parser.add_argument(
|
|
"--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
|
|
)
|
|
parser.add_argument(
|
|
"--port",
|
|
type=int,
|
|
help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
|
|
)
|
|
parser.add_argument(
|
|
"--num-prompts",
|
|
type=int,
|
|
default=50,
|
|
help="Number of prompts to process. Default is 1000.",
|
|
)
|
|
parser.add_argument(
|
|
"--random-input-len",
|
|
type=int,
|
|
default=1024,
|
|
help="Number of input tokens per request, used only for random dataset.",
|
|
)
|
|
parser.add_argument(
|
|
"--random-output-len",
|
|
type=int,
|
|
default=128,
|
|
help="Number of output tokens per request, used only for random dataset.",
|
|
)
|
|
parser.add_argument(
|
|
"--random-range-ratio",
|
|
type=float,
|
|
default=0.0,
|
|
help="Range of sampled ratio of input/output length, "
|
|
"used only for random dataset.",
|
|
)
|
|
parser.add_argument(
|
|
"--request-rate",
|
|
type=float,
|
|
default=float("inf"),
|
|
help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
|
|
"Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
|
|
)
|
|
parser.add_argument(
|
|
"--base-only",
|
|
action="store_true",
|
|
)
|
|
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
|
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
|
args = parser.parse_args()
|
|
run_benchmark(args)
|