sglang0.4.5.post1/benchmark/hicache/bench_multiturn.py

import argparse
import asyncio
import json
import queue
import random
import threading
import time
from datetime import datetime
from typing import Optional

import aiohttp
import requests
from tqdm.asyncio import tqdm

from sglang.bench_serving import (
    RequestFuncOutput,
    get_tokenizer,
    remove_prefix,
    sample_random_requests,
)
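
# Example invocation (a sketch; flag names and defaults mirror parse_args() below):
#   python bench_multiturn.py --host localhost --port 30000 \
#       --num-clients 256 --num-rounds 5 --request-length 512 --output-length 64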


def parse_args():
    parser = argparse.ArgumentParser(
        description="Script to benchmark concurrent requests to a server."
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        default=256,
        help="Number of concurrent clients",
    )
    parser.add_argument(
        "--max-parallel",
        type=int,
        default=128,
        help="Maximum number of parallel requests",
    )
    parser.add_argument(
        "--request-length",
        type=int,
        default=512,
        help="Length of each new request",
    )
    parser.add_argument(
        "--output-length",
        type=int,
        default=64,
        help="Length of each output",
    )
    parser.add_argument(
        "--num-rounds",
        type=int,
        default=5,
        help="Number of rounds per client",
    )
    parser.add_argument(
        "--distribution",
        type=str,
        default="poisson",
        choices=["poisson", "uniform"],
        help="Distribution type for request intervals (poisson or uniform)",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=1.0,
        help="Average number of requests per second",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server hostname or IP (default: localhost)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Server port (default: 30000)",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="model path compatible with Hugging Face Transformers",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default="",
        help="local dataset to sample tokens from",
    )
    parser.add_argument(
        "--log-file",
        type=str,
        default="performance_metrics.jsonl",
        help="File to log performance metrics",
    )
    return parser.parse_args()


async def async_request_sglang_generate(
    payload,
    url,
    pbar: Optional[tqdm] = None,
):
    """
    Sends a streaming request to the server. Gathers text token-by-token.
    """
    async with aiohttp.ClientSession() as session:
        headers = {}
        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        output = RequestFuncOutput()
        try:
            async with session.post(
                url=url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)
                            if data["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft
                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text = data["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception as e:
            output.success = False
            output.error = str(e)
            print(f"Request failed: {e}")

    if pbar:
        pbar.update(1)
    return output


def gen_payload(prompt, output_len):
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": output_len,
            "ignore_eos": True,
        },
        "stream": True,
        "lora_path": "",
        "return_logprob": False,
        "logprob_start_len": -1,
    }
    return payload
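
# For reference, gen_payload("Hello", 16) yields a dict roughly like (illustrative):
#   {"text": "Hello",
#    "sampling_params": {"temperature": 0.0, "max_new_tokens": 16, "ignore_eos": True},
#    "stream": True, "lora_path": "", "return_logprob": False, "logprob_start_len": -1}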


def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
    """Append the data with a timestamp to the specified JSONL file."""
    timestamped_data = {"timestamp": datetime.now().isoformat(), **data}
    try:
        with open(file_path, "a") as file:
            file.write(
                json.dumps(timestamped_data) + "\n"
            )  # Write as a single line in JSONL format
    except IOError as e:
        print(f"Error writing to JSONL file: {e}")


class ReadyQueue:
    """
    Thread-safe queue that can pop requests in different orders based on given policy.
    """

    def __init__(self, init_requests=None, policy="random"):
        self.lock = threading.Lock()
        self.requests = init_requests or []
        self.policy = policy

    def append(self, item):
        with self.lock:
            self.requests.append(item)

    def pop(self):
        with self.lock:
            if not self.requests:
                return None
            if self.policy == "random":
                index = random.randrange(len(self.requests))
                return self.requests.pop(index)
            elif self.policy == "fifo":
                return self.requests.pop(0)
            else:
                # todo, varying thinking time of clients
                raise ValueError(f"{self.policy} not implemented")
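
# Minimal ReadyQueue usage sketch (hypothetical prompts, not part of the benchmark run):
#   rq = ReadyQueue(init_requests=[(0, gen_payload("first prompt", 64))], policy="fifo")
#   rq.append((1, gen_payload("follow-up prompt", 64)))
#   client_id, payload = rq.pop()  # FIFO: the oldest queued (client_id, payload) tuple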


class WorkloadGenerator:
    def __init__(self, args):
        # Construct the base URL for requests
        self.url = f"http://{args.host}:{args.port}/generate"

        self.tokenizer = get_tokenizer(args.model_path)
        self.distribution = args.distribution
        self.request_rate = args.request_rate
        self.start_time = None
        self.finished_time = None

        self.sent_requests = 0
        self.completed_requests = 0

        self.candidate_inputs = sample_random_requests(
            input_len=args.request_length,
            output_len=args.output_length,
            num_prompts=args.num_clients * args.num_rounds,
            range_ratio=1.0,
            tokenizer=self.tokenizer,
            dataset_path=args.dataset_path,
        )
        self.candidate_inputs = [i[0] for i in self.candidate_inputs]

        init_requests = [
            (i, gen_payload(self.candidate_inputs[i], args.output_length))
            for i in range(args.num_clients)
        ]
        self.client_records = {
            i: {"round": 0, "history": init_requests[i][1]["text"]}
            for i in range(args.num_clients)
        }
        self.ready_queue = ReadyQueue(init_requests=init_requests)
        self.candidate_inputs = self.candidate_inputs[args.num_clients :]

        self.response_queue = queue.Queue()
        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
        self.performance_metrics = {"ttft": [], "latency": []}
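
        # Sketch of the bookkeeping above (values illustrative):
        #   self.client_records[3] == {"round": 0, "history": "<first prompt of client 3>"}
        # Each completed round appends the response text plus a fresh candidate input
        # to "history", so a client's next request shares a long common prefix with
        # its previous one (the multi-turn reuse pattern this benchmark targets).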

    async def handle_request(self, item):
        try:
            client_id, payload = item
            response = await async_request_sglang_generate(payload, self.url, self.pbar)
            if self.pbar.n == self.pbar.total:
                self.finished_time = time.time()
            self.response_queue.put((client_id, response))
        except Exception as e:
            print(f"Request failed: {e}")

    def request_sender(self):
        # Note: args here is the module-level argparse namespace created in __main__.
        async def request_loop():
            while True:
                if self.sent_requests - self.completed_requests < args.max_parallel:
                    new_request = self.ready_queue.pop()
                    if new_request:
                        asyncio.create_task(self.handle_request(new_request))
                        self.sent_requests += 1
                else:
                    await asyncio.sleep(0.05)
                    continue

                if self.pbar.n == self.pbar.total:
                    break

                # Draw the wait time before the next request from the configured
                # interval distribution (mean interval is 1 / request_rate for both
                # poisson and uniform).
                if self.distribution == "poisson":
                    sleep_time = random.expovariate(self.request_rate)
                elif self.distribution == "uniform":
                    avg_interval = (
                        1.0 / self.request_rate if self.request_rate > 0 else 1.0
                    )
                    sleep_time = random.uniform(0, 2 * avg_interval)
                else:
                    raise ValueError("Invalid distribution type")
                await asyncio.sleep(sleep_time)  # Wait before sending the next request

        # Create and run the event loop for asynchronous requests
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(request_loop())
        loop.close()

    def response_handler(self):
        while True:
            try:
                client_id, response = self.response_queue.get(
                    timeout=10
                )  # Block until a response is available
                if not response.success:
                    raise ValueError(f"Request failed with error: {response.error}")

                self.client_records[client_id]["history"] += response.generated_text
                self.client_records[client_id]["round"] += 1
                self.performance_metrics["ttft"].append(response.ttft)
                self.performance_metrics["latency"].append(response.latency)
                self.completed_requests += 1

                if self.client_records[client_id]["round"] < args.num_rounds:
                    self.client_records[client_id][
                        "history"
                    ] += self.candidate_inputs.pop()
                    self.ready_queue.append(
                        (
                            client_id,
                            gen_payload(
                                self.client_records[client_id]["history"],
                                args.output_length,
                            ),
                        )
                    )
            except queue.Empty:
                if self.pbar.n == self.pbar.total:
                    break

    def run(self):
        request_thread = threading.Thread(target=self.request_sender, daemon=True)
        response_thread = threading.Thread(target=self.response_handler, daemon=True)

        self.start_time = time.time()
        request_thread.start()
        response_thread.start()
        request_thread.join()
        response_thread.join()
        self.pbar.close()

        performance_data = {
            "summary": {
                "total_requests": len(self.performance_metrics["ttft"]),
                "request_rate": self.request_rate,
                "average_ttft": sum(self.performance_metrics["ttft"])
                / len(self.performance_metrics["ttft"]),
                "p90_ttft": sorted(self.performance_metrics["ttft"])[
                    int(0.9 * len(self.performance_metrics["ttft"]))
                ],
                "median_ttft": sorted(self.performance_metrics["ttft"])[
                    len(self.performance_metrics["ttft"]) // 2
                ],
                "average_latency": sum(self.performance_metrics["latency"])
                / len(self.performance_metrics["latency"]),
                "p90_latency": sorted(self.performance_metrics["latency"])[
                    int(0.9 * len(self.performance_metrics["latency"]))
                ],
                "median_latency": sorted(self.performance_metrics["latency"])[
                    len(self.performance_metrics["latency"]) // 2
                ],
                "throughput": self.pbar.total / (self.finished_time - self.start_time),
            },
        }

        print("All requests completed")
        print("Performance metrics summary:")
        print(
            f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
        )
        print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
        print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
        print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
        print(
            f" Average latency: {performance_data['summary']['average_latency']:.2f}"
        )
        print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
        print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
        print(
            f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
        )

        log_to_jsonl_file(performance_data, args.log_file)


if __name__ == "__main__":
    args = parse_args()
    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
        args.request_rate = request_rate
        requests.post(flush_cache_url)
        time.sleep(1)
        WorkloadGenerator(args).run()