"""Registry of named async warmup routines that are run against a TokenizerManager."""

import logging
from typing import Awaitable, Callable, Dict, List

import numpy as np
import tqdm

from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager

# Name the logger after the module's dotted path (not its filesystem path) so it
# participates in the standard package logger hierarchy.
logger = logging.getLogger(__name__)

# Signature of a registered warmup: an async function taking the tokenizer manager.
_WarmupFn = Callable[[TokenizerManager], Awaitable[None]]

# Maps warmup name -> async warmup function; populated by the ``@warmup`` decorator.
_warmup_registry: Dict[str, _WarmupFn] = {}


def warmup(name: str) -> Callable[[_WarmupFn], _WarmupFn]:
    """Return a decorator that registers an async warmup function under *name*.

    The decorated function is stored in the module registry and returned
    unchanged, so it can still be called directly. Registering another
    function under an existing name replaces the earlier entry.
    """

    def decorator(fn: _WarmupFn) -> _WarmupFn:
        _warmup_registry[name] = fn
        return fn

    return decorator


async def execute_warmups(
    warmup_names: List[str], tokenizer_manager: TokenizerManager
) -> None:
    """Run the registered warmups named in *warmup_names*, sequentially and in order.

    Names with no registered warmup are logged as a warning and skipped;
    every found warmup is awaited with *tokenizer_manager* before the next starts.
    """
    for warmup_name in warmup_names:
        warmup_fn = _warmup_registry.get(warmup_name)
        if warmup_fn is None:
            logger.warning("Could not find custom warmup %s", warmup_name)
            continue
        logger.info("Running warmup %s", warmup_name)
        await warmup_fn(tokenizer_manager)


@warmup("voice_chat")
async def voice_chat(tokenizer_manager: TokenizerManager) -> None:
    """Send 511 requests of random token ids with prompt lengths 4, 8, ..., 2044.

    This warms up the fused_moe triton kernels and caches them; without it,
    real-time inference for voice chat breaks.
    """
    # i runs over [1, 511], so prompt lengths step by 4 from 4 up to 2044.
    for i in tqdm.trange(1, 512):
        size = i * 4
        generate_req_input = GenerateReqInput(
            # Random ids drawn from [0, 2**16); assumes every id in that range is
            # acceptable to the loaded model's tokenizer/vocab — TODO confirm.
            input_ids=(np.random.randint(2**16, size=[size])).tolist(),
            # A fresh dict per request, in case the request object mutates it.
            sampling_params={
                "max_new_tokens": 30,
                "temperature": 0.8,
                "stop_token_ids": [1],
                "min_p": 0.0,
            },
        )
        # Await only the first item produced by the request's async generator.
        await tokenizer_manager.generate_request(generate_req_input, None).__anext__()