152 lines
4.7 KiB
Python
152 lines
4.7 KiB
Python
"""
|
|
Usage:
|
|
# replay from a folder
|
|
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/
|
|
|
|
# replay from a single file
|
|
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
|
|
"""
|
|
|
|
import argparse
|
|
import glob
|
|
import json
|
|
import pickle
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dataclasses import asdict
|
|
from datetime import datetime
|
|
|
|
import requests
|
|
|
|
from sglang.bench_serving import set_ulimit
|
|
from sglang.utils import get_exception_traceback
|
|
|
|
|
|
def read_records(files):
|
|
records = []
|
|
for f in files:
|
|
tmp = pickle.load(open(f, "rb"))
|
|
if isinstance(tmp, dict) and "requests" in tmp:
|
|
records.extend(tmp["requests"])
|
|
else:
|
|
records.extend(tmp)
|
|
|
|
return records
|
|
|
|
|
|
def run_one_request_internal(record):
|
|
(req, output, replay_init_time, start_time, end_time, idx) = record
|
|
time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
|
|
|
|
if "completion_tokens" in output.get("meta_info", {}):
|
|
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
|
|
else:
|
|
recorded_completion_tokens = ""
|
|
|
|
json_data = asdict(req)
|
|
stream = json_data["stream"]
|
|
|
|
if args.ignore_eos:
|
|
json_data["sampling_params"]["ignore_eos"] = True
|
|
if recorded_completion_tokens:
|
|
json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens
|
|
|
|
response = requests.post(
|
|
f"http://{args.host}:{args.port}/generate",
|
|
json=json_data,
|
|
stream=stream,
|
|
)
|
|
|
|
if stream:
|
|
for chunk in response.iter_lines(decode_unicode=False):
|
|
chunk = chunk.decode("utf-8")
|
|
if chunk and chunk.startswith("data:"):
|
|
if chunk == "data: [DONE]":
|
|
break
|
|
ret = json.loads(chunk[5:].strip("\n"))
|
|
else:
|
|
ret = response.json()
|
|
|
|
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
|
completion_tokens = ret["meta_info"]["completion_tokens"]
|
|
print(
|
|
f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
|
|
f"{completion_tokens=}, {recorded_completion_tokens=}"
|
|
)
|
|
|
|
|
|
def run_one_request(record):
|
|
# global success_ct, error_ct
|
|
|
|
try:
|
|
run_one_request_internal(record)
|
|
# success_ct += 1
|
|
except Exception:
|
|
# error_ct += 1
|
|
traceback = get_exception_traceback()
|
|
print(f"Hit an exception: {traceback}")
|
|
|
|
|
|
def main(records):
|
|
if len(records) == 0:
|
|
return
|
|
|
|
base_time = records[0][-2]
|
|
base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
|
|
print(f"{base_time_str=}")
|
|
replay_init_time = time.time()
|
|
|
|
for i in range(len(records)):
|
|
req, output, start_time, end_time = records[i]
|
|
start_time -= base_time
|
|
records[i] = (req, output, replay_init_time, start_time, end_time, i)
|
|
|
|
with ThreadPoolExecutor(args.parallel) as executor:
|
|
executor.map(run_one_request, records)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--host", type=str, default="localhost")
|
|
parser.add_argument("--port", type=int, default=30000)
|
|
parser.add_argument(
|
|
"--input-folder", type=str, default=None, help="Folder containing pickle files"
|
|
)
|
|
parser.add_argument(
|
|
"--input-file", type=str, default=None, help="Single pickle file to process"
|
|
)
|
|
parser.add_argument("--file-number", type=int, default=1)
|
|
parser.add_argument("--req-number", type=int, default=1000000)
|
|
parser.add_argument("--req-start", type=int, default=0)
|
|
parser.add_argument("--parallel", type=int, default=512)
|
|
parser.add_argument("--idx", type=int, default=None)
|
|
parser.add_argument("--ignore-eos", action="store_true")
|
|
parser.add_argument("--speed", type=float, default=1)
|
|
args = parser.parse_args()
|
|
|
|
set_ulimit()
|
|
|
|
files = []
|
|
if args.input_file:
|
|
files = [args.input_file]
|
|
if args.file_number > 1:
|
|
print("Warning: --file-number is ignored when --input-file is provided.")
|
|
elif args.input_folder:
|
|
files = glob.glob(f"{args.input_folder}/*.pkl")
|
|
files = files[: args.file_number]
|
|
else:
|
|
print("Error: Either --input-folder or --input-file must be provided.")
|
|
exit(1)
|
|
print(f"{files=}")
|
|
|
|
records = read_records(files)
|
|
# Sort by the receive time, before filtering
|
|
records.sort(key=lambda x: x[-2])
|
|
records = records[args.req_start :]
|
|
if args.idx:
|
|
records = [records[args.idx]]
|
|
print(f"testing {args.idx=}")
|
|
print(f"{records[0]}")
|
|
print(f"{len(records)=}")
|
|
main(records)
|