import json import random import time from argparse import ArgumentParser from pathlib import Path from tqdm import tqdm import sglang as sgl from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) from sglang.utils import dump_state_text def gen_prompt(tokenizer, token_num): all_available_tokens = list(tokenizer.get_vocab().values()) selected_tokens = random.choices(all_available_tokens, k=token_num) ret = tokenizer.decode(selected_tokens) return ret def get_cache_path(args): # Create cache directory under ~/.cache/sglang cache_dir = Path.home() / ".cache" / "sglang" # Create a unique cache filename based on the arguments that affect generation cache_key = f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json" return cache_dir / cache_key def gen_arguments(args, tokenizer): cache_path = get_cache_path(args) # Try to load from cache first if cache_path.exists(): print(f"Loading cached arguments from {cache_path}") with open(cache_path, "r") as f: return json.load(f) print("Generating new arguments...") # First progress bar for system prompts multi_qas = [] for _ in tqdm(range(args.num_qa), desc="Generating system prompts"): multi_qas.append( {"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []} ) # Nested progress bars for QA pairs for i in tqdm(range(args.num_qa), desc="Generating QA pairs"): qas = multi_qas[i]["qas"] for j in range(args.turns): qas.append( { "prompt": gen_prompt(tokenizer, args.len_q), "new_tokens": args.len_a, } ) # Save to cache cache_path.parent.mkdir(parents=True, exist_ok=True) with open(cache_path, "w") as f: json.dump(multi_qas, f) print(f"Cached arguments saved to {cache_path}") return multi_qas @sgl.function def multi_turns(s, system_prompt, qas): s += system_prompt for i, qa in enumerate(qas): s += qa["prompt"] s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True) def main(args): tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) multi_qas = gen_arguments(args, tokenizer) backend = select_sglang_backend(args) tic = time.time() states = multi_turns.run_batch( multi_qas, temperature=0, backend=backend, num_threads="auto", progress_bar=True, ) latency = time.time() - tic print(f"Latency: {latency:.3f}") dump_state_text(f"tmp_output_{args.backend}.txt", states) with open(args.result_file, "a") as fout: value = { "task": "multi_turn_system_prompt_chat", "backend": args.backend, "latency": round(latency, 3), "num_requests": args.num_qa, "num_turns": args.turns, "other": { "parallel": args.parallel, }, } fout.write(json.dumps(value) + "\n") if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--turns", type=int, default=8) parser.add_argument("--num-qa", type=int, default=128) parser.add_argument("--system-prompt-len", type=int, default=2048) parser.add_argument("--len-q", type=int, default=32) parser.add_argument("--len-a", type=int, default=128) parser.add_argument( "--tokenizer", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct" ) parser.add_argument("--trust-remote-code", action="store_true") args = add_common_sglang_args_and_parse(parser) print(args) main(args)