import argparse
import asyncio
import json
import queue
import random
import threading
import time
from datetime import datetime
from typing import Optional

import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm

from sglang.bench_serving import (
    RequestFuncOutput,
    get_tokenizer,
    remove_prefix,
    sample_random_requests,
)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Script to benchmark concurrent requests to a server."
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        default=256,
        help="Number of concurrent clients",
    )
    parser.add_argument(
        "--max-parallel",
        type=int,
        default=128,
        help="Maximum number of parallel requests",
    )
    parser.add_argument(
        "--request-length",
        type=int,
        default=512,
        help="Length of each new request",
    )
    parser.add_argument(
        "--output-length",
        type=int,
        default=64,
        help="Length of each output",
    )
    parser.add_argument(
        "--num-rounds",
        type=int,
        default=5,
        help="Number of rounds per client",
    )
    parser.add_argument(
        "--distribution",
        type=str,
        default="poisson",
        choices=["poisson", "uniform"],
        help="Distribution type for request intervals (poisson or uniform)",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=1.0,
        help="Average number of requests per second",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server hostname or IP (default: localhost)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Server port (default: 30000)",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Model path compatible with Hugging Face Transformers",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default="",
        help="Local dataset to sample tokens from",
    )
    parser.add_argument(
        "--log-file",
        type=str,
        default="performance_metrics.jsonl",
        help="File to log performance metrics",
    )
    parser.add_argument(
        "--disable-auto-run",
        action="store_true",
        help="If set, disable automatically testing with a range of request rates.",
    )
    parser.add_argument(
        "--disable-random-sample",
        action="store_true",
        help="If set, disable random sampling of requests from the ShareGPT dataset.",
    )
    parser.add_argument(
        "--sub-question-input-length",
        type=int,
        default=0,
        help="Length of the sub-question input appended for each follow-up request; if 0, use --request-length",
    )
    parser.add_argument(
        "--ready-queue-policy",
        type=str,
        default="random",
        help="Policy for popping requests from the ready queue (random or fifo)",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="",
        help="Tag of a certain run in the log file",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--lora-path",
        type=str,
        default="",
        help="String of LoRA path. Currently only benchmarking on a single LoRA adapter is supported.",
    )
    return parser.parse_args()


# The /generate endpoint is expected to stream "data: "-prefixed JSON chunks
# (carrying "text" and "meta_info" fields) and a final "[DONE]" marker; see the
# parsing logic below.
async def async_request_sglang_generate(
    payload,
    url,
    pbar: Optional[tqdm] = None,
):
    """
    Sends a streaming request to the server. Gathers text token-by-token.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers = {}
        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        output = RequestFuncOutput()

        try:
            async with session.post(url=url, json=payload, headers=headers) as response:
                if response.status == 200:
                    prompt_tokens = 0
                    cached_tokens = 0
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)
                            if data["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft
                                    prompt_tokens = (data.get("meta_info") or {}).get(
                                        "prompt_tokens", 0
                                    )
                                    cached_tokens = (data.get("meta_info") or {}).get(
                                        "cached_tokens", 0
                                    )
                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text = data["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.prompt_len = prompt_tokens
                    output.cached_tokens = cached_tokens
                    output.generated_len = len(output.itl) + 1
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception as e:
            output.success = False
            output.error = str(e)
            print(f"Request failed: {e}")

    if pbar:
        pbar.update(1)
    return output


def gen_payload(prompt, output_len, lora_path=""):
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": output_len,
            "ignore_eos": True,
        },
        "stream": True,
        "stream_options": {"include_usage": True},
        "lora_path": lora_path,
        "return_logprob": False,
        "logprob_start_len": -1,
    }
    return payload


def log_to_jsonl_file(data, file_path="performance_metrics.jsonl", tag=""):
    """Append the data with a timestamp and tag to the specified JSONL file."""
    timestamped_data = {"timestamp": datetime.now().isoformat(), "tag": tag, **data}
    try:
        with open(file_path, "a") as file:
            # Write as a single line in JSONL format
            file.write(json.dumps(timestamped_data) + "\n")
    except IOError as e:
        print(f"Error writing to JSONL file: {e}")


class ReadyQueue:
    """
    Thread-safe queue that can pop requests in different orders based on the given policy.
    """

    def __init__(self, init_requests=None, policy="random"):
        self.lock = threading.Lock()
        self.requests = init_requests or []
        self.policy = policy

    def append(self, item):
        with self.lock:
            self.requests.append(item)

    def pop(self):
        with self.lock:
            if not self.requests:
                return None
            if self.policy == "random":
                index = random.randrange(len(self.requests))
                return self.requests.pop(index)
            elif self.policy == "fifo":
                return self.requests.pop(0)
            else:
                # TODO: varying thinking time of clients
                raise ValueError(f"{self.policy} not implemented")


class WorkloadGenerator:
    def __init__(self, args):
        # Construct the base URL for requests
        self.url = f"http://{args.host}:{args.port}/generate"
        self.tokenizer = get_tokenizer(args.model_path)

        self.distribution = args.distribution
        self.request_rate = args.request_rate
        self.start_time = None
        self.finished_time = None

        self.sent_requests = 0
        self.completed_requests = 0

        # First-round prompts, one per client
        self.candidate_inputs = sample_random_requests(
            input_len=args.request_length,
            output_len=args.output_length,
            num_prompts=args.num_clients,
            range_ratio=1.0,
            tokenizer=self.tokenizer,
            dataset_path=args.dataset_path,
            random_sample=not args.disable_random_sample,
        )
        self.candidate_inputs = [i.prompt for i in self.candidate_inputs]

        # Follow-up (sub-question) prompts for the remaining rounds
        if args.sub_question_input_length != 0:
            sub_question_input_length = args.sub_question_input_length
        else:
            sub_question_input_length = args.request_length
        self.sub_question_inputs = sample_random_requests(
            input_len=sub_question_input_length,
            output_len=args.output_length,
            num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
            range_ratio=1.0,
            tokenizer=self.tokenizer,
            dataset_path=args.dataset_path,
            random_sample=not args.disable_random_sample,
        )

        init_requests = [
            (
                i,
                gen_payload(
                    self.candidate_inputs[i], args.output_length, args.lora_path
                ),
            )
            for i in range(args.num_clients)
        ]
        self.client_records = {
            i: {"round": 0, "history": init_requests[i][1]["text"]}
            for i in range(args.num_clients)
        }
        self.ready_queue = ReadyQueue(
            init_requests=init_requests, policy=args.ready_queue_policy
        )
        self.candidate_inputs = self.candidate_inputs[args.num_clients :]

        self.response_queue = queue.Queue()
        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
        self.performance_metrics = {
            "ttft": [],
            "latency": [],
            "prompt_len": [],
            "cached_tokens": [],
            "generated_len": [],
        }

        self.num_rounds = args.num_rounds
        self.max_parallel = args.max_parallel
        self.output_length = args.output_length
        self.lora_path = args.lora_path

    async def handle_request(self, item):
        try:
            client_id, payload = item
            response = await async_request_sglang_generate(payload, self.url, self.pbar)
            if self.pbar.n == self.pbar.total:
                self.finished_time = time.perf_counter()
            self.response_queue.put((client_id, response))
        except Exception as e:
            print(f"Request failed: {e}")
    def request_sender(self):
        async def request_loop():
            while True:
                if self.sent_requests - self.completed_requests < self.max_parallel:
                    new_request = self.ready_queue.pop()
                    if new_request:
                        asyncio.create_task(self.handle_request(new_request))
                        self.sent_requests += 1
                    else:
                        await asyncio.sleep(0.05)
                        continue

                if self.pbar.n == self.pbar.total:
                    break

                # Calculate the wait time before the next request (poisson or uniform)
                if self.distribution == "poisson":
                    sleep_time = random.expovariate(self.request_rate)
                elif self.distribution == "uniform":
                    avg_interval = (
                        1.0 / self.request_rate if self.request_rate > 0 else 1.0
                    )
                    sleep_time = random.uniform(0, 2 * avg_interval)
                else:
                    raise ValueError("Invalid distribution type")
                await asyncio.sleep(sleep_time)  # Wait before sending the next request

        # Create and run the event loop for asynchronous requests
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(request_loop())
        loop.close()

    def response_handler(self):
        while True:
            try:
                # Block until a response is available
                client_id, response = self.response_queue.get(timeout=10)
                if not response.success:
                    raise ValueError(f"Request failed with error: {response.error}")

                self.client_records[client_id]["history"] += response.generated_text
                self.client_records[client_id]["round"] += 1
                self.performance_metrics["ttft"].append(response.ttft)
                self.performance_metrics["latency"].append(response.latency)
                self.performance_metrics["prompt_len"].append(response.prompt_len)
                self.performance_metrics["cached_tokens"].append(response.cached_tokens)
                self.performance_metrics["generated_len"].append(response.generated_len)
                self.completed_requests += 1

                if self.client_records[client_id]["round"] < self.num_rounds:
                    # Append a new sub-question to the client's history
                    self.client_records[client_id][
                        "history"
                    ] += self.sub_question_inputs.pop().prompt
                    self.ready_queue.append(
                        (
                            client_id,
                            gen_payload(
                                self.client_records[client_id]["history"],
                                self.output_length,
                                self.lora_path,
                            ),
                        )
                    )
            except queue.Empty:
                if self.pbar.n == self.pbar.total:
                    break
            except ValueError as e:
                print(f"Error processing response for client {client_id}: {e}")
                continue
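    # run() below drives the benchmark with two daemon threads: request_sender
    # (producer) paces requests from ready_queue according to --distribution and
    # --request-rate, capped at --max-parallel in-flight requests, while
    # response_handler (consumer) drains response_queue, records per-request
    # metrics, and re-enqueues each client's next round until --num-rounds is
    # reached.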
    def run(self):
        request_thread = threading.Thread(target=self.request_sender, daemon=True)
        response_thread = threading.Thread(target=self.response_handler, daemon=True)

        self.start_time = time.perf_counter()
        request_thread.start()
        response_thread.start()

        request_thread.join()
        response_thread.join()

        self.pbar.close()
        duration = self.finished_time - self.start_time

        performance_data = {
            "summary": {
                "total_requests": len(self.performance_metrics["ttft"]),
                "request_rate": self.request_rate,
                "average_ttft": sum(self.performance_metrics["ttft"])
                / len(self.performance_metrics["ttft"]),
                "p90_ttft": sorted(self.performance_metrics["ttft"])[
                    int(0.9 * len(self.performance_metrics["ttft"]))
                ],
                "median_ttft": sorted(self.performance_metrics["ttft"])[
                    len(self.performance_metrics["ttft"]) // 2
                ],
                "average_latency": sum(self.performance_metrics["latency"])
                / len(self.performance_metrics["latency"]),
                "p90_latency": sorted(self.performance_metrics["latency"])[
                    int(0.9 * len(self.performance_metrics["latency"]))
                ],
                "median_latency": sorted(self.performance_metrics["latency"])[
                    len(self.performance_metrics["latency"]) // 2
                ],
                "input_token_throughput": sum(self.performance_metrics["prompt_len"])
                / duration,
                "output_token_throughput": sum(
                    self.performance_metrics["generated_len"]
                )
                / duration,
                "throughput": self.pbar.total / duration,
                "cache_hit_rate": (
                    0
                    if sum(self.performance_metrics["prompt_len"]) == 0
                    else sum(self.performance_metrics["cached_tokens"])
                    / sum(self.performance_metrics["prompt_len"])
                ),
            },
        }
        print("All requests completed")
        print("Performance metrics summary:")
        print(
            f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
        )
        print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
        print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
        print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
        print(
            f" Average latency: {performance_data['summary']['average_latency']:.2f}"
        )
        print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
        print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
        print(
            f" Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second"
        )
        print(
            f" Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second"
        )
        print(
            f" Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
        )
        print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")

        return performance_data


if __name__ == "__main__":
    args = parse_args()
    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.disable_auto_run:
        print("Running with the specified request rate...")
        request_rates = [args.request_rate]
    else:
        print("Auto-running with different request rates...")
        request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]

    for rate in request_rates:
        args.request_rate = rate
        # Flush the server cache between runs so results are comparable
        requests.post(flush_cache_url)
        time.sleep(1)
        performance_data = WorkloadGenerator(args).run()
        log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
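# Example invocation (a sketch; the script filename and server address are
# assumptions, adjust them to your deployment):
#
#   python3 bench_multiturn.py --host localhost --port 30000 \
#       --model-path meta-llama/Llama-3.1-8B-Instruct \
#       --num-clients 64 --num-rounds 3 --request-rate 8 --disable-auto-run
#
# Each run flushes the server cache via /flush_cache, replays the multi-turn
# workload, and appends one summary line per request rate to --log-file.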