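"""Benchmark multi-round (multi-turn) concurrent requests against an SGLang server.

Each simulated client keeps a growing chat history: after every response it
appends the generated text (plus a fresh sampled prompt) to its history and
issues the next round, up to --num-rounds rounds. Request arrivals follow a
Poisson or uniform distribution, and TTFT/latency metrics are printed and
appended to a JSONL log file.
"""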
import argparse
import asyncio
import json
import queue
import random
import threading
import time
from datetime import datetime
from typing import Optional

import aiohttp
import requests
from tqdm.asyncio import tqdm

from sglang.bench_serving import (
    RequestFuncOutput,
    get_tokenizer,
    remove_prefix,
    sample_random_requests,
)

def parse_args():
    parser = argparse.ArgumentParser(
        description="Script to benchmark concurrent requests to a server."
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        default=256,
        help="Number of concurrent clients",
    )
    parser.add_argument(
        "--max-parallel",
        type=int,
        default=128,
        help="Maximum number of parallel requests",
    )
    parser.add_argument(
        "--request-length",
        type=int,
        default=512,
        help="Length of each new request",
    )
    parser.add_argument(
        "--output-length",
        type=int,
        default=64,
        help="Length of each output",
    )
    parser.add_argument(
        "--num-rounds",
        type=int,
        default=5,
        help="Number of rounds per client",
    )
    parser.add_argument(
        "--distribution",
        type=str,
        default="poisson",
        choices=["poisson", "uniform"],
        help="Distribution type for request intervals (poisson or uniform)",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=1.0,
        help="Average number of requests per second",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server hostname or IP (default: localhost)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Server port (default: 30000)",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Model path compatible with Hugging Face Transformers",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default="",
        help="Local dataset to sample tokens from",
    )
    parser.add_argument(
        "--log-file",
        type=str,
        default="performance_metrics.jsonl",
        help="File to log performance metrics",
    )
    return parser.parse_args()

async def async_request_sglang_generate(
    payload,
    url,
    pbar: Optional[tqdm] = None,
):
    """
    Send a streaming request to the server and gather the generated text
    chunk by chunk, recording TTFT and inter-token latencies.
    """
    async with aiohttp.ClientSession() as session:
        headers = {}
        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        output = RequestFuncOutput()

        try:
            async with session.post(url=url, json=payload, headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        # Each SSE chunk is prefixed with "data: "; the stream
                        # ends with a literal "[DONE]" sentinel.
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            if data["text"]:
                                timestamp = time.perf_counter()
                                # First token: record time-to-first-token.
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase: record inter-token latency.
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text = data["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception as e:
            output.success = False
            output.error = str(e)
            print(f"Request failed: {e}")

    if pbar:
        pbar.update(1)
    return output
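
# A sketch of the stream the parser above expects (illustrative, not an exact
# server transcript): each chunk is an SSE-style line such as
#
#   data: {"text": "Hello", ...}
#   data: {"text": "Hello world", ...}
#   data: [DONE]
#
# Note that "text" carries the full generation so far rather than a delta,
# which is why generated_text is overwritten instead of concatenated.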

def gen_payload(prompt, output_len):
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": output_len,
            "ignore_eos": True,
        },
        "stream": True,
        "lora_path": "",
        "return_logprob": False,
        "logprob_start_len": -1,
    }
    return payload
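
# For example, gen_payload("Once upon a time", 64) builds a greedy
# (temperature 0.0) streaming request for exactly 64 new tokens; ignore_eos
# keeps the server decoding even if it emits an end-of-sequence token.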

def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
    """Append the data with a timestamp to the specified JSONL file."""
    timestamped_data = {"timestamp": datetime.now().isoformat(), **data}
    try:
        with open(file_path, "a") as file:
            # Write as a single line in JSONL format.
            file.write(json.dumps(timestamped_data) + "\n")
    except IOError as e:
        print(f"Error writing to JSONL file: {e}")
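
# Each appended line is one JSON object, e.g. (values illustrative):
#   {"timestamp": "2025-01-01T12:00:00.000000", "summary": {"total_requests": 1280, ...}}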

class ReadyQueue:
    """
    Thread-safe queue that pops requests in different orders based on the given policy.
    """

    def __init__(self, init_requests=None, policy="random"):
        self.lock = threading.Lock()
        self.requests = init_requests or []
        self.policy = policy

    def append(self, item):
        with self.lock:
            self.requests.append(item)

    def pop(self):
        with self.lock:
            if not self.requests:
                return None
            if self.policy == "random":
                index = random.randrange(len(self.requests))
                return self.requests.pop(index)
            elif self.policy == "fifo":
                return self.requests.pop(0)
            else:
                # TODO: support policies that model varying client think time.
                raise ValueError(f"{self.policy} not implemented")
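
# Minimal usage sketch (hypothetical payloads):
#   rq = ReadyQueue(init_requests=[(0, payload0)], policy="fifo")
#   rq.append((1, payload1))
#   item = rq.pop()  # (client_id, payload), or None when the queue is empty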

class WorkloadGenerator:
    def __init__(self, args):
        # Construct the base URL for requests
        self.url = f"http://{args.host}:{args.port}/generate"

        self.tokenizer = get_tokenizer(args.model_path)
        self.distribution = args.distribution
        self.request_rate = args.request_rate
        self.max_parallel = args.max_parallel
        self.num_rounds = args.num_rounds
        self.output_length = args.output_length
        self.log_file = args.log_file
        self.start_time = None
        self.finished_time = None

        self.sent_requests = 0
        self.completed_requests = 0

        # Pre-sample one prompt per client per round.
        self.candidate_inputs = sample_random_requests(
            input_len=args.request_length,
            output_len=args.output_length,
            num_prompts=args.num_clients * args.num_rounds,
            range_ratio=1.0,
            tokenizer=self.tokenizer,
            dataset_path=args.dataset_path,
        )
        self.candidate_inputs = [i[0] for i in self.candidate_inputs]

        # Round 0: every client starts with its own fresh prompt.
        init_requests = [
            (i, gen_payload(self.candidate_inputs[i], args.output_length))
            for i in range(args.num_clients)
        ]
        self.client_records = {
            i: {"round": 0, "history": init_requests[i][1]["text"]}
            for i in range(args.num_clients)
        }
        self.ready_queue = ReadyQueue(init_requests=init_requests)
        self.candidate_inputs = self.candidate_inputs[args.num_clients :]

        self.response_queue = queue.Queue()
        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
        self.performance_metrics = {"ttft": [], "latency": []}

    async def handle_request(self, item):
        try:
            client_id, payload = item
            response = await async_request_sglang_generate(payload, self.url, self.pbar)
            if self.pbar.n == self.pbar.total:
                self.finished_time = time.time()
            self.response_queue.put((client_id, response))
        except Exception as e:
            print(f"Request failed: {e}")

    def request_sender(self):
        async def request_loop():
            while True:
                if self.sent_requests - self.completed_requests < self.max_parallel:
                    new_request = self.ready_queue.pop()
                    if new_request:
                        asyncio.create_task(self.handle_request(new_request))
                        self.sent_requests += 1
                    else:
                        await asyncio.sleep(0.05)
                        continue

                if self.pbar.n == self.pbar.total:
                    break

                # Sample the inter-arrival time: exponential gaps yield a
                # Poisson arrival process; uniform(0, 2/rate) has the same mean.
                if self.distribution == "poisson":
                    sleep_time = random.expovariate(self.request_rate)
                elif self.distribution == "uniform":
                    avg_interval = (
                        1.0 / self.request_rate if self.request_rate > 0 else 1.0
                    )
                    sleep_time = random.uniform(0, 2 * avg_interval)
                else:
                    raise ValueError("Invalid distribution type")
                await asyncio.sleep(sleep_time)  # Wait before sending the next request

        # Create and run a dedicated event loop for the asynchronous requests.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(request_loop())
        loop.close()

    def response_handler(self):
        while True:
            try:
                # Block until a response is available.
                client_id, response = self.response_queue.get(timeout=10)
                if not response.success:
                    raise ValueError(f"Request failed with error: {response.error}")
                self.client_records[client_id]["history"] += response.generated_text
                self.client_records[client_id]["round"] += 1
                self.performance_metrics["ttft"].append(response.ttft)
                self.performance_metrics["latency"].append(response.latency)
                self.completed_requests += 1

                if self.client_records[client_id]["round"] < self.num_rounds:
                    # Start the client's next round: extend its history with a
                    # fresh prompt and re-enqueue it.
                    self.client_records[client_id][
                        "history"
                    ] += self.candidate_inputs.pop()
                    self.ready_queue.append(
                        (
                            client_id,
                            gen_payload(
                                self.client_records[client_id]["history"],
                                self.output_length,
                            ),
                        )
                    )
            except queue.Empty:
                if self.pbar.n == self.pbar.total:
                    break

    def run(self):
        request_thread = threading.Thread(target=self.request_sender, daemon=True)
        response_thread = threading.Thread(target=self.response_handler, daemon=True)

        self.start_time = time.time()
        request_thread.start()
        response_thread.start()

        request_thread.join()
        response_thread.join()
        self.pbar.close()

        # Sort once, then read off the mean, p90, and median (all in seconds).
        ttfts = sorted(self.performance_metrics["ttft"])
        latencies = sorted(self.performance_metrics["latency"])
        performance_data = {
            "summary": {
                "total_requests": len(ttfts),
                "request_rate": self.request_rate,
                "average_ttft": sum(ttfts) / len(ttfts),
                "p90_ttft": ttfts[int(0.9 * len(ttfts))],
                "median_ttft": ttfts[len(ttfts) // 2],
                "average_latency": sum(latencies) / len(latencies),
                "p90_latency": latencies[int(0.9 * len(latencies))],
                "median_latency": latencies[len(latencies) // 2],
                "throughput": self.pbar.total / (self.finished_time - self.start_time),
            },
        }
        print("All requests completed")
        print("Performance metrics summary:")
        print(
            f"  Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
        )
        print(f"  Average TTFT: {performance_data['summary']['average_ttft']:.2f} s")
        print(f"  P90 TTFT: {performance_data['summary']['p90_ttft']:.2f} s")
        print(f"  Median TTFT: {performance_data['summary']['median_ttft']:.2f} s")
        print(
            f"  Average latency: {performance_data['summary']['average_latency']:.2f} s"
        )
        print(f"  P90 latency: {performance_data['summary']['p90_latency']:.2f} s")
        print(f"  Median latency: {performance_data['summary']['median_latency']:.2f} s")
        print(
            f"  Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
        )
        log_to_jsonl_file(performance_data, self.log_file)

if __name__ == "__main__":
    args = parse_args()
    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    # Sweep request rates from high to low; flush the server cache between
    # runs so each rate starts from the same cold-cache state.
    for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
        args.request_rate = request_rate
        requests.post(flush_cache_url)
        time.sleep(1)
        WorkloadGenerator(args).run()
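
# Example invocation (the script name here is a placeholder for this file):
#   python3 bench_multiturn.py --host localhost --port 30000 \
#       --num-clients 64 --num-rounds 5 --request-length 512 --output-length 64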