import argparse import asyncio import json import logging import os import queue import random import threading import time from dataclasses import dataclass from functools import wraps import aiohttp from sglang.bench_serving import ( RequestFuncOutput, get_tokenizer, remove_prefix, sample_random_requests, ) # Set up logger logger = logging.getLogger(__name__) # Set up JSONL file for debug logging debug_log_file = None # Create a lock for thread-safe debug log writing debug_log_lock = threading.Lock() def write_debug_log(data): global debug_log_file """Write debug information to a JSONL file""" if debug_log_file is None: return # Acquire lock for thread-safe writing with debug_log_lock: # Write as JSONL (JSON Line format) debug_log_file.write(json.dumps(data) + "\n") debug_log_file.flush() def parse_args(): parser = argparse.ArgumentParser( description="Script to benchmark concurrent requests to a server." ) parser.add_argument( "--model-path", type=str, default="/data/models/Qwen3-0.6B", help="model path compatible with Hugging Face Transformers", ) parser.add_argument( "--dataset-path", type=str, default="/data/models/ShareGPT_V3_unfiltered_cleaned_split/ShareGPT_V3_unfiltered_cleaned_split.json", help="local dataset to sample tokens from", ) parser.add_argument( "--host", type=str, default="localhost", help="Server hostname or IP (default: localhost)", ) parser.add_argument( "--port", type=int, default=30000, help="Server port (default: 30000)", ) parser.add_argument( "--duration", type=int, default=600, help="Duration to run the benchmark in seconds (default: 300 seconds)", ) parser.add_argument( "--log-level", type=str, default="info", choices=["debug", "info"], help="Set the logging level (default: info)", ) parser.add_argument( "--debug-log-file", type=str, default="debug.log.jsonl", help="File to write debug logs in JSONL format", ) return parser.parse_args() def load_config(): config_path = os.getenv("CONFIG_PATH") if not config_path: raise ValueError("Environment variable 'CONFIG_PATH' is not set.") with open(config_path, "r") as f: config = json.load(f) required_keys = [ "num_rounds", "num_clients", "round_ratios", "mean_new_tokens_per_round", "mean_return_tokens_per_round", "mean_inter_round_interval", ] for key in required_keys: if key not in config: raise KeyError(f"Missing required configuration key: {key}") num_rounds = config["num_rounds"] assert len(config["round_ratios"]) == num_rounds assert len(config["mean_new_tokens_per_round"]) == num_rounds assert len(config["mean_return_tokens_per_round"]) == num_rounds assert len(config["mean_inter_round_interval"]) == num_rounds print(config) return config @dataclass class UserData: user_id: int current_round: int total_rounds: int prompt: str return_tokens: int start: int def synchronized(): def _decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): with self.lock: return func(self, *args, **kwargs) return wrapper return _decorator class UserGenerator: def __init__(self, config, model_path, dataset_path): self.tokenizer_path = model_path self.tokenizer = get_tokenizer(self.tokenizer_path) self.dataset_path = dataset_path self.user_id = 0 self.lock = threading.Lock() self.num_rounds = config["num_rounds"] self.cumulative_ratios = [ sum(config["round_ratios"][: i + 1]) for i in range(len(config["round_ratios"])) ] self.mean_new_tokens_per_round = config["mean_new_tokens_per_round"] self.mean_return_tokens_per_round = config["mean_return_tokens_per_round"] self.mean_inter_round_interval = config["mean_inter_round_interval"] self.sigma = 100 self.range_ratio = 0.8 assert self.range_ratio <= 1 self.candidate_inputs = [ [ r for r in sample_random_requests( input_len=( self.mean_new_tokens_per_round[i] * (2 - self.range_ratio) ), output_len=( self.mean_return_tokens_per_round[i] * (2 - self.range_ratio) ), num_prompts=config["num_clients"], range_ratio=self.range_ratio / (2 - self.range_ratio), tokenizer=self.tokenizer, dataset_path=self.dataset_path, random_sample=False, ) ] for i in range(self.num_rounds) ] self.multiturn_queue = [] self.user_stats = [0 for _ in range(self.num_rounds)] self.input_stats = [[0, 0] for _ in range(self.num_rounds)] self.output_stats = [[0, 0] for _ in range(self.num_rounds)] def gen(self): user_id = self.user_id self.user_id += 1 rand_ratio = random.randint(0, self.cumulative_ratios[-1]) i = len(self.cumulative_ratios) for idx, cumulative_ratio in enumerate(self.cumulative_ratios): if rand_ratio >= cumulative_ratio: continue else: i = idx + 1 break total_rounds = i current_round = 0 candidate_input = random.sample(self.candidate_inputs[current_round], 1)[0] self.input_stats[0][0] += candidate_input.prompt_len self.input_stats[0][1] += 1 prompt = f"{user_id} " + candidate_input.prompt return_tokens = int( random.gauss(self.mean_return_tokens_per_round[current_round], self.sigma) ) if return_tokens <= 0: return_tokens = self.mean_return_tokens_per_round[current_round] start = 0 user_data = UserData( user_id, current_round, total_rounds, prompt, return_tokens, start ) self.user_stats[total_rounds - 1] += 1 return user_data @synchronized() def push(self, user_data, generated_text, len_itl): self.output_stats[user_data.current_round][0] += len_itl + 1 self.output_stats[user_data.current_round][1] += 1 user_data.current_round += 1 if user_data.current_round >= user_data.total_rounds: return candidate_input = random.sample( self.candidate_inputs[user_data.current_round], 1 )[0] self.input_stats[user_data.current_round][0] += candidate_input.prompt_len self.input_stats[user_data.current_round][1] += 1 user_data.prompt += generated_text + candidate_input.prompt user_data.return_tokens = int( random.gauss( self.mean_return_tokens_per_round[user_data.current_round], self.sigma ) ) if user_data.return_tokens <= 0: user_data.return_tokens = self.mean_return_tokens_per_round[ user_data.current_round ] interval = random.gauss( self.mean_inter_round_interval[user_data.current_round], self.sigma ) if interval <= 0: interval = self.mean_inter_round_interval[user_data.current_round] user_data.start = time.perf_counter() + interval if len(self.multiturn_queue) == 0: self.multiturn_queue.append(user_data) else: i = len(self.multiturn_queue) for idx, d in enumerate(self.multiturn_queue): if user_data.start < d.start: i = idx break self.multiturn_queue.insert(idx, user_data) @synchronized() def pop(self): if ( len(self.multiturn_queue) and time.perf_counter() > self.multiturn_queue[0].start ): return self.multiturn_queue.pop(0) return self.gen() def gen_payload(prompt, output_len): payload = { "text": prompt, "sampling_params": { "temperature": 0.0, "max_new_tokens": output_len, "ignore_eos": True, }, "stream": True, "stream_options": {"include_usage": True}, "lora_path": "", "return_logprob": False, "logprob_start_len": -1, } return payload AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60) async def async_request_sglang_generate( user_data, url, atomic_counter, ): """ Sends a streaming request to the server. Gathers text token-by-token. """ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: headers = {} generated_text = "" ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st output = RequestFuncOutput() payload = gen_payload(user_data.prompt, user_data.return_tokens) write_debug_log({"timestamp": st, "user_data": user_data.__dict__}) try: async with session.post(url=url, json=payload, headers=headers) as response: if response.status == 200: prompt_tokens = 0 cached_tokens = 0 async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") latency = time.perf_counter() - st if chunk == "[DONE]": pass else: data = json.loads(chunk) if data.get("text"): timestamp = time.perf_counter() # First token if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft prompt_tokens = (data.get("meta_info") or {}).get( "prompt_tokens", 0 ) cached_tokens = (data.get("meta_info") or {}).get( "cached_tokens", 0 ) # Decoding phase else: output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp generated_text = data["text"] output.generated_text = generated_text output.success = True output.latency = latency output.prompt_len = prompt_tokens output.cached_tokens = cached_tokens else: output.error = response.reason or "" output.success = False except Exception as e: output.success = False output.error = str(e) print(f"Request failed: {e}") atomic_counter.increment(1) return output class AtomicCounter: def __init__(self, initial_value=0): self._value = initial_value self.lock = threading.Lock() @synchronized() def increment(self, amount=1): self._value += amount @synchronized() def get(self): return self._value class WorkloadGenerator: def __init__(self, args): config = load_config() user_generator = UserGenerator( config, args.model_path, args.dataset_path, ) self.url = f"http://{args.host}:{args.port}/generate" self.tokenizer = user_generator.tokenizer self.start_time = None self.finished_time = None self.duration = args.duration self.done = False self.sent_requests = 0 self.completed_requests = 0 self.user_generator = user_generator self.response_queue = queue.Queue() self.performance_metrics = { "ttft": [], "latency": [], "prompt_len": [], "cached_tokens": [], } self.max_parallel = config["num_clients"] self.atomic_counter = AtomicCounter() async def handle_request(self, user_data): try: response = await async_request_sglang_generate( user_data, self.url, self.atomic_counter ) self.response_queue.put((user_data, response)) except Exception as e: print(f"Request failed: {e}") self.completed_requests += 1 def request_sender(self): async def request_loop(): while True: if self.sent_requests - self.completed_requests < self.max_parallel: new_request = self.user_generator.pop() if new_request: asyncio.create_task(self.handle_request(new_request)) self.sent_requests += 1 else: await asyncio.sleep(0.05) continue if time.perf_counter() - self.start_time > self.duration: self.done = True break loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(request_loop()) loop.close() def response_handler(self): while True: try: user_data, response = self.response_queue.get(timeout=10) logger.info( f"{((time.perf_counter()-self.start_time)/self.duration*100):.2f}%" ) if not response.success: raise ValueError(f"Request failed with error: {response.error}") self.user_generator.push( user_data, response.generated_text, len(response.itl) ) self.performance_metrics["ttft"].append(response.ttft) self.performance_metrics["latency"].append(response.latency) self.performance_metrics["prompt_len"].append(response.prompt_len) self.performance_metrics["cached_tokens"].append(response.cached_tokens) self.completed_requests += 1 self.finished_time = time.perf_counter() except queue.Empty: if self.done: break except ValueError as e: print(f"Error processing response for client {user_data}: {e}") continue def run(self): request_thread = threading.Thread(target=self.request_sender, daemon=True) response_thread = threading.Thread(target=self.response_handler, daemon=True) self.start_time = time.perf_counter() request_thread.start() response_thread.start() request_thread.join() response_thread.join() performance_data = { "summary": { "total_requests": len(self.performance_metrics["ttft"]), "average_ttft": sum(self.performance_metrics["ttft"]) / len(self.performance_metrics["ttft"]), "p90_ttft": sorted(self.performance_metrics["ttft"])[ int(0.9 * len(self.performance_metrics["ttft"])) ], "median_ttft": sorted(self.performance_metrics["ttft"])[ len(self.performance_metrics["ttft"]) // 2 ], "average_latency": sum(self.performance_metrics["latency"]) / len(self.performance_metrics["latency"]), "p90_latency": sorted(self.performance_metrics["latency"])[ int(0.9 * len(self.performance_metrics["latency"])) ], "median_latency": sorted(self.performance_metrics["latency"])[ len(self.performance_metrics["latency"]) // 2 ], "throughput": self.atomic_counter.get() / (self.finished_time - self.start_time), "cache_hit_rate": ( 0 if sum(self.performance_metrics["prompt_len"]) == 0 else sum(self.performance_metrics["cached_tokens"]) / sum(self.performance_metrics["prompt_len"]) ), }, } print("All requests completed") print("Performance metrics summary:") print(f" Total requests: {performance_data['summary']['total_requests']}") print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}") print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}") print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}") print( f" Average latency: {performance_data['summary']['average_latency']:.2f}" ) print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}") print(f" Median latency: {performance_data['summary']['median_latency']:.2f}") print( f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second" ) print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}") user_stats = self.user_generator.user_stats input_stats = self.user_generator.input_stats output_stats = self.user_generator.output_stats print(f"round_ratios: {user_stats}") print( f"mean_new_tokens_per_round: {[int(a/b) if b > 0 else 0 for a, b in input_stats]}" ) print( f"mean_return_tokens_per_round: {[int(a/b) if b > 0 else 0 for a, b in output_stats]}" ) return performance_data def main(): global debug_log_file args = parse_args() if args.log_level == "debug": logging.basicConfig(level=logging.DEBUG) logger.info("use log_level debug") # Initialize debug log file debug_log_file = open(args.debug_log_file, "w") else: logging.basicConfig(level=logging.INFO) logger.info("use log_level info") performance_data = WorkloadGenerator(args).run() # Close debug log file if it was opened if debug_log_file: debug_log_file.close() if __name__ == "__main__": main()