# sglang_v0.5.2/sglang/benchmark/hicache/bench_multiturn.py
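"""
Multi-turn benchmarking script for an SGLang server's /generate endpoint.

Simulates a pool of concurrent clients, each holding a multi-round conversation:
after every response, the generated text plus a newly sampled sub-question is
appended to the client's history, and the grown prompt is resubmitted, up to
--num-rounds turns per client. Per-request TTFT, latency, throughput, and cache
hit rate are printed and appended to a JSONL log file.
"""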


import argparse
import asyncio
import json
import queue
import random
import threading
import time
from datetime import datetime
from typing import Optional
import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from sglang.bench_serving import (
RequestFuncOutput,
get_tokenizer,
remove_prefix,
sample_random_requests,
)
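# Generous client-side timeout (20 hours) so that long, low-rate benchmark runs are not cut off.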
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
def parse_args():
parser = argparse.ArgumentParser(
description="Script to benchmark concurrent requests to a server."
)
parser.add_argument(
"--num-clients",
type=int,
default=256,
help="Number of concurrent clients",
)
parser.add_argument(
"--max-parallel",
type=int,
default=128,
help="Maximum number of parallel requests",
)
parser.add_argument(
"--request-length",
type=int,
default=512,
help="Length of each new request",
)
parser.add_argument(
"--output-length",
type=int,
default=64,
help="Length of each output",
)
parser.add_argument(
"--num-rounds",
type=int,
default=5,
help="Number of rounds per client",
)
parser.add_argument(
"--distribution",
type=str,
default="poisson",
choices=["poisson", "uniform"],
help="Distribution type for request intervals (poisson or uniform)",
)
parser.add_argument(
"--request-rate",
type=float,
default=1.0,
help="Average number of requests per second",
)
parser.add_argument(
"--host",
type=str,
default="localhost",
help="Server hostname or IP (default: localhost)",
)
parser.add_argument(
"--port",
type=int,
default=30000,
help="Server port (default: 30000)",
)
parser.add_argument(
"--model-path",
type=str,
default="meta-llama/Llama-3.1-8B-Instruct",
help="model path compatible with Hugging Face Transformers",
)
parser.add_argument(
"--dataset-path",
type=str,
default="",
help="local dataset to sample tokens from",
)
parser.add_argument(
"--log-file",
type=str,
default="performance_metrics.jsonl",
help="File to log performance metrics",
)
parser.add_argument(
"--disable-auto-run",
action="store_true",
help="If set, disable automatically testing with a range of request rates.",
)
parser.add_argument(
"--disable-random-sample",
action="store_true",
help="If set, disable random sampling of requests from the ShareGPT dataset.",
)
parser.add_argument(
"--sub-question-input-length",
type=int,
default=0,
help="Length of the sub question input for each request, if set 0 use request_length",
)
parser.add_argument(
"--ready-queue-policy",
type=str,
default="random",
help="Policy for popping requests from the ready queue (random or fifo)",
)
parser.add_argument(
"--tag",
type=str,
default="",
help="Tag of a certain run in the log file",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--lora-path",
type=str,
default="",
help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.",
)
return parser.parse_args()
async def async_request_sglang_generate(
payload,
url,
pbar: Optional[tqdm] = None,
):
"""
Sends a streaming request to the server. Gathers text token-by-token.
"""
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {}
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
output = RequestFuncOutput()
try:
async with session.post(url=url, json=payload, headers=headers) as response:
if response.status == 200:
prompt_tokens = 0
cached_tokens = 0
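                    # Each SSE chunk is a 'data: {...}' line; the handler below reads the
                    # cumulative "text" field and, on the first token, the prompt/cached token
                    # counts from "meta_info". The stream ends with a 'data: [DONE]' sentinel.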
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
if data["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
prompt_tokens = (data.get("meta_info") or {}).get(
"prompt_tokens", 0
)
cached_tokens = (data.get("meta_info") or {}).get(
"cached_tokens", 0
)
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text = data["text"]
output.generated_text = generated_text
output.success = True
output.latency = latency
output.prompt_len = prompt_tokens
output.cached_tokens = cached_tokens
output.generated_len = len(output.itl) + 1
else:
output.error = response.reason or ""
output.success = False
except Exception as e:
output.success = False
output.error = str(e)
print(f"Request failed: {e}")
if pbar:
pbar.update(1)
return output
def gen_payload(prompt, output_len, lora_path=""):
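    """Build a /generate payload: greedy sampling, fixed output length, ignore_eos, and streaming enabled."""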
payload = {
"text": prompt,
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": output_len,
"ignore_eos": True,
},
"stream": True,
"stream_options": {"include_usage": True},
"lora_path": lora_path,
"return_logprob": False,
"logprob_start_len": -1,
}
return payload
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl", tag=""):
"""Append the data with a timestamp and tag to the specified JSONL file."""
timestamped_data = {"timestamp": datetime.now().isoformat(), "tag": tag, **data}
try:
with open(file_path, "a") as file:
file.write(
json.dumps(timestamped_data) + "\n"
) # Write as a single line in JSONL format
except IOError as e:
print(f"Error writing to JSONL file: {e}")
class ReadyQueue:
"""
Thread-safe queue that can pop requests in different orders based on given policy.
"""
def __init__(self, init_requests=None, policy="random"):
self.lock = threading.Lock()
self.requests = init_requests or []
self.policy = policy
def append(self, item):
with self.lock:
self.requests.append(item)
def pop(self):
with self.lock:
if not self.requests:
return None
if self.policy == "random":
index = random.randrange(len(self.requests))
return self.requests.pop(index)
elif self.policy == "fifo":
return self.requests.pop(0)
            else:
                # TODO: support additional policies, e.g. modeling per-client think time
                raise ValueError(f"{self.policy} not implemented")
class WorkloadGenerator:
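    """
    Drives the multi-turn workload: a sender thread paces new requests according to
    the configured inter-arrival distribution (never exceeding --max-parallel in
    flight), while a handler thread collects responses, records metrics, extends
    each client's conversation history, and re-enqueues the next round until every
    client has completed --num-rounds turns.
    """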
def __init__(self, args):
# Construct the base URL for requests
self.url = f"http://{args.host}:{args.port}/generate"
self.tokenizer = get_tokenizer(args.model_path)
self.distribution = args.distribution
self.request_rate = args.request_rate
self.start_time = None
self.finished_time = None
self.sent_requests = 0
self.completed_requests = 0
self.candidate_inputs = sample_random_requests(
input_len=args.request_length,
output_len=args.output_length,
num_prompts=args.num_clients,
range_ratio=1.0,
tokenizer=self.tokenizer,
dataset_path=args.dataset_path,
random_sample=not args.disable_random_sample,
)
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
if args.sub_question_input_length != 0:
sub_question_input_length = args.sub_question_input_length
else:
sub_question_input_length = args.request_length
self.sub_question_inputs = sample_random_requests(
input_len=sub_question_input_length,
output_len=args.output_length,
num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
range_ratio=1.0,
tokenizer=self.tokenizer,
dataset_path=args.dataset_path,
random_sample=not args.disable_random_sample,
)
init_requests = [
(
i,
gen_payload(
self.candidate_inputs[i], args.output_length, args.lora_path
),
)
for i in range(args.num_clients)
]
self.client_records = {
i: {"round": 0, "history": init_requests[i][1]["text"]}
for i in range(args.num_clients)
}
self.ready_queue = ReadyQueue(
init_requests=init_requests, policy=args.ready_queue_policy
)
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
self.response_queue = queue.Queue()
self.pbar = tqdm(total=args.num_clients * args.num_rounds)
self.performance_metrics = {
"ttft": [],
"latency": [],
"prompt_len": [],
"cached_tokens": [],
"generated_len": [],
}
self.num_rounds = args.num_rounds
self.max_parallel = args.max_parallel
        self.output_length = args.output_length
        self.lora_path = args.lora_path
async def handle_request(self, item):
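        """Issue one streaming request and hand the (client_id, response) pair to the response handler."""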
try:
client_id, payload = item
response = await async_request_sglang_generate(payload, self.url, self.pbar)
if self.pbar.n == self.pbar.total:
self.finished_time = time.perf_counter()
self.response_queue.put((client_id, response))
except Exception as e:
print(f"Request failed: {e}")
def request_sender(self):
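        """
        Thread target: run an asyncio loop that pops ready requests and dispatches
        them at the configured request rate, never exceeding --max-parallel
        outstanding requests.
        """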
async def request_loop():
while True:
if self.sent_requests - self.completed_requests < self.max_parallel:
new_request = self.ready_queue.pop()
if new_request:
asyncio.create_task(self.handle_request(new_request))
self.sent_requests += 1
else:
await asyncio.sleep(0.05)
continue
if self.pbar.n == self.pbar.total:
break
                # Sample the inter-arrival time from the configured distribution
if self.distribution == "poisson":
sleep_time = random.expovariate(self.request_rate)
elif self.distribution == "uniform":
avg_interval = (
1.0 / self.request_rate if self.request_rate > 0 else 1.0
)
sleep_time = random.uniform(0, 2 * avg_interval)
else:
raise ValueError("Invalid distribution type")
await asyncio.sleep(sleep_time) # Wait before sending the next request
# Create and run the event loop for asynchronous requests
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(request_loop())
loop.close()
def response_handler(self):
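        """
        Thread target: consume completed responses, record per-request metrics,
        extend each client's history, and schedule the client's next round until
        --num-rounds is reached.
        """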
while True:
try:
client_id, response = self.response_queue.get(
timeout=10
) # Block until response is available
if not response.success:
raise ValueError(f"Request failed with error: {response.error}")
self.client_records[client_id]["history"] += response.generated_text
self.client_records[client_id]["round"] += 1
self.performance_metrics["ttft"].append(response.ttft)
self.performance_metrics["latency"].append(response.latency)
self.performance_metrics["prompt_len"].append(response.prompt_len)
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
self.performance_metrics["generated_len"].append(response.generated_len)
self.completed_requests += 1
if self.client_records[client_id]["round"] < self.num_rounds:
                    # Append the next sub-question to the history and enqueue the next round
self.client_records[client_id][
"history"
] += self.sub_question_inputs.pop().prompt
self.ready_queue.append(
(
client_id,
gen_payload(
self.client_records[client_id]["history"],
self.output_length,
                                self.lora_path,
),
)
)
except queue.Empty:
if self.pbar.n == self.pbar.total:
break
except ValueError as e:
print(f"Error processing response for client {client_id}: {e}")
continue
def run(self):
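        """Start the sender and handler threads, wait for all rounds to complete, then print and return a metrics summary."""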
request_thread = threading.Thread(target=self.request_sender, daemon=True)
response_thread = threading.Thread(target=self.response_handler, daemon=True)
self.start_time = time.perf_counter()
request_thread.start()
response_thread.start()
request_thread.join()
response_thread.join()
self.pbar.close()
duration = self.finished_time - self.start_time
performance_data = {
"summary": {
"total_requests": len(self.performance_metrics["ttft"]),
"request_rate": self.request_rate,
"average_ttft": sum(self.performance_metrics["ttft"])
/ len(self.performance_metrics["ttft"]),
"p90_ttft": sorted(self.performance_metrics["ttft"])[
int(0.9 * len(self.performance_metrics["ttft"]))
],
"median_ttft": sorted(self.performance_metrics["ttft"])[
len(self.performance_metrics["ttft"]) // 2
],
"average_latency": sum(self.performance_metrics["latency"])
/ len(self.performance_metrics["latency"]),
"p90_latency": sorted(self.performance_metrics["latency"])[
int(0.9 * len(self.performance_metrics["latency"]))
],
"median_latency": sorted(self.performance_metrics["latency"])[
len(self.performance_metrics["latency"]) // 2
],
"input_token_throughput": sum(self.performance_metrics["prompt_len"])
/ duration,
"output_token_throughput": sum(
self.performance_metrics["generated_len"]
)
/ duration,
"throughput": self.pbar.total / duration,
"cache_hit_rate": (
0
if sum(self.performance_metrics["prompt_len"]) == 0
else sum(self.performance_metrics["cached_tokens"])
/ sum(self.performance_metrics["prompt_len"])
),
},
}
print("All requests completed")
print("Performance metrics summary:")
print(
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
)
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
print(
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
)
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
print(
f" Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second"
)
print(
f" Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second"
)
print(
f" Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
)
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
return performance_data
if __name__ == "__main__":
args = parse_args()
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
random.seed(args.seed)
np.random.seed(args.seed)
if args.disable_auto_run:
print("Running with specified request rate...")
request_rates = [args.request_rate]
else:
print("Auto-running with different request rates...")
request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
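    # Sweep from high to low request rates; the server-side cache is flushed before each run
    # so later (lower) rates do not benefit from caches warmed by earlier runs.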
for rate in request_rates:
args.request_rate = rate
requests.post(flush_cache_url)
time.sleep(1)
performance_data = WorkloadGenerator(args).run()
log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
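
# Example invocation (values are illustrative; adjust host/port/model to your deployment):
#   python3 bench_multiturn.py --host localhost --port 30000 \
#       --model-path meta-llama/Llama-3.1-8B-Instruct \
#       --num-clients 64 --num-rounds 3 --request-rate 4 --disable-auto-run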