sglang_v0.5.2/sglang/scripts/playground/replay_request_dump.py

152 lines
4.7 KiB
Python

"""
Usage:
# replay from a folder
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/
# replay from a single file
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
"""
import argparse
import glob
import json
import pickle
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from datetime import datetime
import requests
from sglang.bench_serving import set_ulimit
from sglang.utils import get_exception_traceback
def read_records(files):
records = []
for f in files:
tmp = pickle.load(open(f, "rb"))
if isinstance(tmp, dict) and "requests" in tmp:
records.extend(tmp["requests"])
else:
records.extend(tmp)
return records
def run_one_request_internal(record):
(req, output, replay_init_time, start_time, end_time, idx) = record
time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
if "completion_tokens" in output.get("meta_info", {}):
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
else:
recorded_completion_tokens = ""
json_data = asdict(req)
stream = json_data["stream"]
if args.ignore_eos:
json_data["sampling_params"]["ignore_eos"] = True
if recorded_completion_tokens:
json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens
response = requests.post(
f"http://{args.host}:{args.port}/generate",
json=json_data,
stream=stream,
)
if stream:
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
ret = json.loads(chunk[5:].strip("\n"))
else:
ret = response.json()
prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
print(
f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
f"{completion_tokens=}, {recorded_completion_tokens=}"
)
def run_one_request(record):
# global success_ct, error_ct
try:
run_one_request_internal(record)
# success_ct += 1
except Exception:
# error_ct += 1
traceback = get_exception_traceback()
print(f"Hit an exception: {traceback}")
def main(records):
if len(records) == 0:
return
base_time = records[0][-2]
base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
print(f"{base_time_str=}")
replay_init_time = time.time()
for i in range(len(records)):
req, output, start_time, end_time = records[i]
start_time -= base_time
records[i] = (req, output, replay_init_time, start_time, end_time, i)
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(run_one_request, records)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument(
"--input-folder", type=str, default=None, help="Folder containing pickle files"
)
parser.add_argument(
"--input-file", type=str, default=None, help="Single pickle file to process"
)
parser.add_argument("--file-number", type=int, default=1)
parser.add_argument("--req-number", type=int, default=1000000)
parser.add_argument("--req-start", type=int, default=0)
parser.add_argument("--parallel", type=int, default=512)
parser.add_argument("--idx", type=int, default=None)
parser.add_argument("--ignore-eos", action="store_true")
parser.add_argument("--speed", type=float, default=1)
args = parser.parse_args()
set_ulimit()
files = []
if args.input_file:
files = [args.input_file]
if args.file_number > 1:
print("Warning: --file-number is ignored when --input-file is provided.")
elif args.input_folder:
files = glob.glob(f"{args.input_folder}/*.pkl")
files = files[: args.file_number]
else:
print("Error: Either --input-folder or --input-file must be provided.")
exit(1)
print(f"{files=}")
records = read_records(files)
# Sort by the receive time, before filtering
records.sort(key=lambda x: x[-2])
records = records[args.req_start :]
if args.idx:
records = [records[args.idx]]
print(f"testing {args.idx=}")
print(f"{records[0]}")
print(f"{len(records)=}")
main(records)