import logging
from typing import List

import numpy as np
import tqdm

from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__file__)

# Maps a warmup name to its registered async warmup function.
_warmup_registry = {}


def warmup(name: str) -> callable:
    """Register the decorated async warmup function under `name`."""

    def decorator(fn: callable):
        _warmup_registry[name] = fn
        return fn

    return decorator
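

# Example (illustrative; "my_warmup" is a hypothetical name): the decorator
# registers an async callable that must accept (disaggregation_mode,
# tokenizer_manager), since execute_warmups below invokes every registered
# warmup with exactly those two arguments.
#
#   @warmup("my_warmup")
#   async def my_warmup(disaggregation_mode: str, tokenizer_manager: TokenizerManager):
#       ...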


async def execute_warmups(
    disaggregation_mode: str,
    warmup_names: List[str],
    tokenizer_manager: TokenizerManager,
):
    for warmup_name in warmup_names:
        if warmup_name not in _warmup_registry:
            logger.warning(f"Could not find custom warmup {warmup_name}")
            continue
        logger.info(f"Running warmup {warmup_name}")
        await _warmup_registry[warmup_name](disaggregation_mode, tokenizer_manager)
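
# Illustrative call (assumes an already-constructed TokenizerManager):
#
#   await execute_warmups("null", ["voice_chat"], tokenizer_manager)
#
# A "null" disaggregation mode skips the bootstrap fields set in voice_chat below.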


@warmup("voice_chat")
async def voice_chat(disaggregation_mode: str, tokenizer_manager: TokenizerManager):
    # Warm up the fused_moe Triton kernels over a sweep of input sizes so they
    # are compiled and cached up front; without this, first-request compilation
    # stalls break real-time inference for voice chat.
    for i in tqdm.trange(1, 512):
        size = i * 4
        generate_req_input = GenerateReqInput(
            input_ids=(np.random.randint(2**16, size=[size])).tolist(),
            sampling_params={
                "max_new_tokens": 30,
                "temperature": 0.8,
                "stop_token_ids": [1],
                "min_p": 0.0,
            },
        )
        if disaggregation_mode != "null":
            # In disaggregated prefill/decode mode, attach placeholder
            # bootstrap metadata so the warmup request is accepted.
            generate_req_input.bootstrap_room = 0
            generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST

        # Pull the first result from the async generator to drive execution.
        await tokenizer_manager.generate_request(generate_req_input, None).__anext__()