sglang0.4.5.post1/python/sglang/srt/managers/scheduler.py

2023 lines
78 KiB
Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""
import faulthandler
import logging
import os
import signal
import sys
import threading
import time
import warnings
from collections import defaultdict, deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
import psutil
import setproctitle
import torch
import zmq
from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
DecodePreallocQueue,
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
ReqToMetadataIdxAllocator,
)
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReq,
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
RpcReqInput,
RpcReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
MultimodalInputs,
Req,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
AddReqResult,
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
DynamicGradMode,
broadcast_pyobj,
configure_logger,
crash_on_warnings,
get_bool_env_var,
get_zmq_socket,
kill_itself_when_parent_died,
pyspy_dump_schedulers,
set_gpu_proc_affinity,
set_random_seed,
suppress_other_loggers,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
# Module-level ExpertDistributionRecorder instance.
expert_distribution_recorder = ExpertDistributionRecorder()
# Logger named after this module.
logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# When enabled, log_decode_stats appends the average per-step decode time to
# Scheduler.step_time_dict, keyed by the number of running requests.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@dataclass
class GenerationBatchResult:
"""Plain record holding the outputs produced for one generation batch."""
# Logits-processor output for the batch.
logits_output: LogitsProcessorOutput
# Next token ids for the batch (presumably one per request -- TODO confirm).
next_token_ids: List[int]
# Per-request extend input lengths.
extend_input_len_per_req: List[int]
# Per-request logprob start lengths used for extend.
extend_logprob_start_len_per_req: List[int]
# Presumably the id of the batch this result belongs to -- TODO confirm.
bid: int
@dataclass
class EmbeddingBatchResult:
"""Plain record holding the outputs produced for one embedding batch."""
# Embedding tensor for the batch (shape not visible here -- TODO confirm).
embeddings: torch.Tensor
# Presumably the id of the batch this result belongs to -- TODO confirm.
bid: int
class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
):
"""A scheduler that manages a tensor parallel GPU worker."""
def __init__(
    self,
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
):
    """Construct the scheduler and everything it drives.

    Sets up, in order: rank info, ZMQ sockets, model config and tokenizer,
    the tensor-parallel worker (plus an optional EAGLE draft worker), the
    memory pool and prefix cache, scheduling state, the watchdog thread,
    metrics, the request dispatcher and the disaggregation queues.

    Args:
        server_args: Server-wide configuration.
        port_args: IPC endpoint names and the NCCL port.
        gpu_id: Device index handed to the worker(s).
        tp_rank: Tensor-parallel rank of this scheduler.
        dp_rank: Data-parallel rank handed to the worker(s). Note that
            ``self.dp_rank`` is taken from
            ``compute_dp_attention_world_info`` instead.
    """
    # Parse args
    self.server_args = server_args
    self.tp_rank = tp_rank
    self.tp_size = server_args.tp_size
    self.schedule_policy = server_args.schedule_policy
    self.lora_paths = server_args.lora_paths
    self.max_loras_per_batch = server_args.max_loras_per_batch
    self.enable_overlap = not server_args.disable_overlap_schedule
    self.skip_tokenizer_init = server_args.skip_tokenizer_init
    self.enable_metrics = server_args.enable_metrics
    self.stream_interval = server_args.stream_interval
    self.spec_algorithm = SpeculativeAlgorithm.from_string(
        server_args.speculative_algorithm
    )
    self.gpu_id = gpu_id
    self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
    self.page_size = server_args.page_size

    # Distributed rank info
    self.dp_size = server_args.dp_size
    self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
        compute_dp_attention_world_info(
            server_args.enable_dp_attention,
            self.tp_rank,
            self.tp_size,
            self.dp_size,
        )
    )

    # Init inter-process communication
    context = zmq.Context(2)
    if self.attn_tp_rank == 0:
        self.recv_from_tokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.scheduler_input_ipc_name, False
        )
        self.send_to_tokenizer = get_zmq_socket(
            context, zmq.PUSH, port_args.tokenizer_ipc_name, False
        )
        if server_args.skip_tokenizer_init:
            # Directly send to the TokenizerManager
            self.send_to_detokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
        else:
            # Send to the DetokenizerManager
            self.send_to_detokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.detokenizer_ipc_name, False
            )
        self.recv_from_rpc = get_zmq_socket(
            context, zmq.DEALER, port_args.rpc_ipc_name, False
        )
    else:
        # Other ranks own no sockets; the senders become no-op stand-ins so
        # callers can invoke send_pyobj unconditionally.
        self.recv_from_tokenizer = None
        self.recv_from_rpc = None
        self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
        self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

    # Init tokenizer
    self.init_tokenizer()

    # Check whether overlap can be enabled
    if not self.is_generation:
        self.enable_overlap = False
        logger.info("Overlap scheduler is disabled for embedding models.")
    if self.model_config.is_multimodal:
        self.enable_overlap = False
        logger.info("Overlap scheduler is disabled for multimodal models.")

    # Launch a tensor parallel worker
    if self.enable_overlap:
        TpWorkerClass = TpModelWorkerClient
    else:
        TpWorkerClass = TpModelWorker
    self.tp_worker = TpWorkerClass(
        server_args=server_args,
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        dp_rank=dp_rank,
        nccl_port=port_args.nccl_port,
    )

    # Launch a draft worker for speculative decoding
    if self.spec_algorithm.is_eagle():
        from sglang.srt.speculative.eagle_worker import EAGLEWorker

        self.draft_worker = EAGLEWorker(
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            server_args=server_args,
            nccl_port=port_args.nccl_port,
            target_worker=self.tp_worker,
            dp_rank=dp_rank,
        )
    else:
        self.draft_worker = None

    # Get token and memory info from the model worker
    (
        self.max_total_num_tokens,
        self.max_prefill_tokens,
        self.max_running_requests,
        self.max_req_len,
        self.max_req_input_len,
        self.random_seed,
        self.device,
        worker_global_server_args_dict,
        _,
        _,
        _,
    ) = self.tp_worker.get_worker_info()
    self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
    self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
    self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
    global_server_args_dict.update(worker_global_server_args_dict)
    set_random_seed(self.random_seed)

    # Print debug info
    logger.info(
        f"max_total_num_tokens={self.max_total_num_tokens}, "
        f"chunked_prefill_size={server_args.chunked_prefill_size}, "
        f"max_prefill_tokens={self.max_prefill_tokens}, "
        f"max_running_requests={self.max_running_requests}, "
        f"context_len={self.model_config.context_len}"
    )

    # Init memory pool and cache
    self.init_memory_pool_and_cache()

    # Init running status
    self.waiting_queue: List[Req] = []
    # The running decoding batch for continuous batching
    self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
    # The current forward batch
    self.cur_batch: Optional[ScheduleBatch] = None
    # The last forward batch
    self.last_batch: Optional[ScheduleBatch] = None
    self.forward_ct = 0
    self.forward_ct_decode = 0
    self.num_generated_tokens = 0
    self.num_prefill_tokens = 0
    self.last_decode_stats_tic = time.time()
    self.last_prefill_stats_tic = time.time()
    self.return_health_check_ct = 0
    self.current_stream = torch.get_device_module(self.device).current_stream()
    if self.device == "cpu":
        self.current_stream.synchronize = lambda: None  # No-op for CPU

    # Init session info
    self.sessions: Dict[str, Session] = {}

    # Init chunked prefill
    self.chunked_prefill_size = server_args.chunked_prefill_size
    # A non-positive value (e.g. -1) disables chunked prefill. None is
    # accepted as "already disabled": init_memory_pool_and_cache treats None
    # as a valid value, so comparing it with 0 must not raise a TypeError.
    if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
        self.chunked_prefill_size = None
    self.chunked_req = None
    self.is_mixed_chunk = (
        self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
    )

    # Init the grammar backend for constrained generation
    self.grammar_queue: List[Req] = []
    if not server_args.skip_tokenizer_init:
        self.grammar_backend = create_grammar_backend(
            server_args, self.tokenizer, self.model_config.vocab_size
        )
    else:
        self.grammar_backend = None

    # Init schedule policy and new token estimation
    self.policy = SchedulePolicy(
        self.schedule_policy,
        self.tree_cache,
        self.enable_hierarchical_cache,
    )
    assert (
        server_args.schedule_conservativeness >= 0
    ), "Invalid schedule_conservativeness"
    self.init_new_token_ratio = min(
        global_config.default_init_new_token_ratio
        * server_args.schedule_conservativeness,
        1.0,
    )
    self.min_new_token_ratio = min(
        self.init_new_token_ratio
        * global_config.default_min_new_token_ratio_factor,
        1.0,
    )
    self.new_token_ratio_decay = (
        self.init_new_token_ratio - self.min_new_token_ratio
    ) / global_config.default_new_token_ratio_decay_steps
    self.new_token_ratio = self.init_new_token_ratio

    # Init watchdog thread
    self.watchdog_timeout = server_args.watchdog_timeout
    # Resolve the parent process before the thread starts so the watchdog
    # (which presumably uses it -- its body is not visible here) can never
    # run against a scheduler that lacks this attribute.
    self.parent_process = psutil.Process().parent()
    t = threading.Thread(target=self.watchdog_thread, daemon=True)
    t.start()

    # Init memory saver
    self.memory_saver_adapter = TorchMemorySaverAdapter.create(
        enable=server_args.enable_memory_saver
    )

    # Init profiler
    self.torch_profiler = None
    self.torch_profiler_output_dir: Optional[str] = None
    self.profiler_activities: Optional[List[str]] = None
    self.profiler_target_forward_ct: Optional[int] = None

    # Init metrics stats
    self.init_metrics()

    # Init request dispatcher: maps each incoming request type to its handler
    self._request_dispatcher = TypeBasedDispatcher(
        [
            (TokenizedGenerateReqInput, self.handle_generate_request),
            (TokenizedEmbeddingReqInput, self.handle_embedding_request),
            (FlushCacheReq, self.flush_cache_wrapped),
            (AbortReq, self.abort_request),
            (OpenSessionReqInput, self.open_session),
            (CloseSessionReqInput, self.close_session),
            (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
            (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
            (
                UpdateWeightsFromDistributedReqInput,
                self.update_weights_from_distributed,
            ),
            (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
            (GetWeightsByNameReqInput, self.get_weights_by_name),
            (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
            (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
            (ProfileReq, self.profile),
            (GetInternalStateReq, self.get_internal_state),
            (SetInternalStateReq, self.set_internal_state),
            (RpcReqInput, self.handle_rpc_request),
            (ExpertDistributionReq, self.expert_distribution_handle),
        ]
    )

    self.disaggregation_mode = DisaggregationMode(
        self.server_args.disaggregation_mode
    )
    self.init_disaggregation()
def init_tokenizer(self):
    """Build the model config and load the tokenizer (and processor).

    Sets ``self.model_config`` and ``self.is_generation``. When tokenizer
    initialisation is skipped, both ``self.tokenizer`` and ``self.processor``
    are set to None. Multimodal models load a processor and take the
    tokenizer from it; other models load only a tokenizer.
    """
    args = self.server_args

    self.model_config = ModelConfig(
        args.model_path,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
        context_length=args.context_length,
        model_override_args=args.json_model_override_args,
        is_embedding=args.is_embedding,
        dtype=args.dtype,
        quantization=args.quantization,
    )
    self.is_generation = self.model_config.is_generation

    if args.skip_tokenizer_init:
        self.tokenizer = self.processor = None
        return

    # The processor and the tokenizer loaders take the same options.
    loader_options = {
        "tokenizer_mode": args.tokenizer_mode,
        "trust_remote_code": args.trust_remote_code,
        "revision": args.revision,
    }
    if not self.model_config.is_multimodal:
        self.tokenizer = get_tokenizer(args.tokenizer_path, **loader_options)
        return

    self.processor = get_processor(args.tokenizer_path, **loader_options)
    self.tokenizer = self.processor.tokenizer
def init_memory_pool_and_cache(self):
    """Fetch the worker's memory pools and create the matching prefix cache.

    Chooses ``ChunkCache`` when chunked prefill is configured and the radix
    cache is disabled, ``HiRadixCache`` when the hierarchical cache is
    enabled, and ``RadixCache`` otherwise. Also sets
    ``self.decode_mem_cache_buf_multiplier``.
    """
    args = self.server_args
    pools = self.tp_worker.get_memory_pool()
    self.req_to_token_pool, self.token_to_kv_pool_allocator = pools

    wants_chunk_cache = (
        args.chunked_prefill_size is not None and args.disable_radix_cache
    )
    if wants_chunk_cache:
        self.tree_cache = ChunkCache(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )
    elif self.enable_hierarchical_cache:
        self.tree_cache = HiRadixCache(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            tp_cache_group=self.tp_worker.get_tp_cpu_group(),
            page_size=self.page_size,
            hicache_ratio=args.hicache_ratio,
        )
    else:
        self.tree_cache = RadixCache(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            page_size=self.page_size,
            disable=args.disable_radix_cache,
        )

    # Without speculative decoding the multiplier is 1; otherwise it is
    # num_draft_tokens + eagle_topk * num_steps.
    if self.spec_algorithm.is_none():
        self.decode_mem_cache_buf_multiplier = 1
    else:
        tree_tokens = args.speculative_eagle_topk * args.speculative_num_steps
        self.decode_mem_cache_buf_multiplier = (
            args.speculative_num_draft_tokens + tree_tokens
        )
def init_metrics(self):
    """Reset the statistics counters and, if enabled, create the collector."""
    # Largest prefill length of a single request
    self._largest_prefill_len: int = 0
    # Largest context length (prefill + generation) of a single request
    self._largest_prefill_decode_len: int = 0

    # Most recent throughput figures
    self.last_gen_throughput: float = 0.0
    self.last_input_throughput: float = 0.0

    # Dict[batch size -> step time]
    self.step_time_dict = defaultdict(list)

    # Speculative decoding counters: per-interval and cumulative
    self.spec_num_total_accepted_tokens = 0
    self.spec_num_total_forward_ct = 0
    self.cum_spec_accept_length = 0
    self.cum_spec_accept_count = 0

    self.stats = SchedulerStats()

    if not self.enable_metrics:
        return

    metric_labels = {
        "model_name": self.server_args.served_model_name,
        "engine_type": "unified",
    }
    self.metrics_collector = SchedulerMetricsCollector(labels=metric_labels)
def init_disaggregation(self):
    """Create the metadata buffers and queues for disaggregated serving.

    In DECODE mode this builds the transfer and pre-allocation queues. In
    PREFILL mode it builds the bootstrap queue and the list of requests
    whose KV cache is being sent. In any other mode nothing is created.
    """
    mode = self.disaggregation_mode

    # Both modes size the metadata buffer with 2x headroom, from different
    # base capacities.
    if mode == DisaggregationMode.DECODE:
        buffer_size = self.req_to_token_pool.size * 2
    elif mode == DisaggregationMode.PREFILL:
        buffer_size = self.max_running_requests * 2
    else:
        return

    idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
    aux_dtype = torch.int32
    # A list of metadata buffers of shape (b, metadata_size), where b is tied
    # to the max number of running requests. The row size in bytes
    # (last dim * dtype.itemsize) is padded for RDMA, per the original note.
    output_id_buffer = torch.zeros(
        (buffer_size, 16), dtype=aux_dtype, device="cpu"
    )
    metadata_buffers = [output_id_buffer]

    if mode == DisaggregationMode.DECODE:
        # The decode requests polling kv cache
        self.disagg_decode_transfer_queue = DecodeTransferQueue(
            gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
            req_to_metadata_buffer_idx_allocator=idx_allocator,
            metadata_buffers=metadata_buffers,
        )
        # The decode requests pending for pre-allocation
        self.disagg_decode_prealloc_queue = DecodePreallocQueue(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            req_to_metadata_buffer_idx_allocator=idx_allocator,
            metadata_buffers=metadata_buffers,
            aux_dtype=aux_dtype,
            scheduler=self,
            transfer_queue=self.disagg_decode_transfer_queue,
            tree_cache=self.tree_cache,
            gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            bootstrap_port=self.server_args.disaggregation_bootstrap_port,
        )
        return

    # PREFILL mode
    self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
        token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
        req_to_metadata_buffer_idx_allocator=idx_allocator,
        metadata_buffers=metadata_buffers,
        aux_dtype=aux_dtype,
        tp_rank=self.tp_rank,
        tp_size=self.tp_size,
        bootstrap_port=self.server_args.disaggregation_bootstrap_port,
        gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
    )
    # The prefill requests that are in the middle of kv sending
    self.disagg_prefill_infight_queue: List[Req] = []
@DynamicGradMode()
def event_loop_normal(self):
    """Run the non-overlapped scheduler loop forever.

    Each iteration receives and dispatches requests, picks the next batch,
    then either runs it and processes its result or, when there is nothing
    to run, performs the idle self-check.
    """
    while True:
        self.process_input_requests(self.recv_requests())

        next_batch = self.get_next_batch_to_run()
        self.cur_batch = next_batch

        if not next_batch:
            # When the server is idle, do self-check and re-init some states
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio
        else:
            outcome = self.run_batch(next_batch)
            self.process_batch_result(next_batch, outcome)

        self.last_batch = next_batch
@DynamicGradMode()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
# Results are consumed one iteration late: a batch launched in this iteration
# is queued together with its result, and the entry popped from the queue is
# processed only when a previous batch (self.last_batch) exists.
self.result_queue = deque()
while True:
# Receive new requests and dispatch them to their handlers.
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
# Queue a copy of the batch alongside its result for later processing.
self.result_queue.append((batch.copy(), result))
if self.last_batch is None:
# Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None)
if self.last_batch:
# Process the results of the last batch
tmp_batch, tmp_result = self.result_queue.popleft()
# Hand the current sampling info to the popped batch only when a new
# batch was launched in this iteration.
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
# Remember this iteration's batch so the next iteration processes its result.
self.last_batch = batch
@torch.no_grad()
def event_loop_normal_disagg_prefill(self):
"""A normal scheduler loop for prefill worker in disaggregation mode."""
while True:
# Receive new requests and dispatch them to their handlers.
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
# Move the requests returned by pop_bootstrapped() into the waiting queue.
self.waiting_queue.extend(
self.disagg_prefill_pending_queue.pop_bootstrapped()
)
self.process_prefill_chunk()
# Only prefill batches are built in this loop.
batch = self.get_new_batch_prefill()
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
self.process_batch_result_disagg_prefill(batch, result)
# Keep servicing requests whose KV cache is still being sent.
if len(self.disagg_prefill_infight_queue) > 0:
self.process_disagg_prefill_infight_queue()
# Idle only when nothing ran and nothing is in flight: run the self-check
# and reset new_token_ratio to its initial value.
if batch is None and len(self.disagg_prefill_infight_queue) == 0:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.last_batch = batch
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
# Otherwise, it hangs under high concurrency
self.running_batch.batch_is_full = False
@torch.no_grad()
def event_loop_normal_disagg_decode(self):
    """Run the scheduler loop of a decode worker in disaggregation mode.

    Each iteration receives and dispatches requests, services the decode
    queues, and then handles the next batch: extend-mode batches only emit
    stream output, all other batches are executed on the worker.
    """
    while True:
        self.process_input_requests(self.recv_requests())

        # polling and allocating kv cache
        self.process_decode_queue()

        next_batch = self.get_next_disagg_decode_batch_to_run()
        self.cur_batch = next_batch

        if next_batch:
            if not next_batch.forward_mode.is_extend():
                outcome = self.run_batch(next_batch)
                self.process_batch_result(next_batch, outcome)
            else:
                # Generate fake extend output.
                # Note: Logprobs should be handled on the prefill engine.
                self.stream_output(
                    next_batch.reqs, [False] * len(next_batch.reqs)
                )

        if next_batch is None:
            pending = len(self.disagg_decode_transfer_queue.queue) + len(
                self.disagg_decode_prealloc_queue.queue
            )
            if pending == 0:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = next_batch
def recv_requests(self) -> List[Req]:
    """Drain pending requests on attention-TP rank 0 and share them.

    Rank 0 polls the tokenizer socket and then the RPC socket without
    blocking until each reports a ZMQ error. With DP attention enabled,
    generate/embedding requests are broadcast inside the attention-TP group
    and all other requests across the whole TP group; otherwise everything
    is broadcast across the TP group when ``tp_size`` is not 1.
    """
    on_source_rank = self.attn_tp_rank == 0

    if on_source_rank:
        recv_reqs = []
        # Tokenizer messages first, then RPC messages.
        for sock in (self.recv_from_tokenizer, self.recv_from_rpc):
            while True:
                try:
                    message = sock.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(message)
    else:
        recv_reqs = None

    if self.server_args.enable_dp_attention:
        if on_source_rank:
            work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
            work_reqs, control_reqs = [], []
            for item in recv_reqs:
                bucket = work_reqs if isinstance(item, work_types) else control_reqs
                bucket.append(item)
        else:
            work_reqs = control_reqs = None

        if self.attn_tp_size != 1:
            source_rank = self.dp_rank * self.attn_tp_size
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_rank,
                self.attn_tp_cpu_group,
                src=source_rank,
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs, self.tp_rank, self.tp_cpu_group
            )
        recv_reqs = work_reqs + control_reqs
    elif self.tp_size != 1:
        recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)

    return recv_reqs
def process_input_requests(self, recv_reqs: List):
    """Dispatch each received request and send back any reply it produces.

    Health-check generation requests are counted and skipped while a chunked
    request exists or the running batch is not empty. ``RpcReqOutput``
    replies go back over the RPC socket (when this rank has one); every
    other reply goes to the tokenizer.
    """
    for incoming in recv_reqs:
        # If it is a health check generation request and there are running requests, ignore it.
        if is_health_check_generate_req(incoming):
            busy = self.chunked_req is not None or not self.running_batch.is_empty()
            if busy:
                self.return_health_check_ct += 1
                continue

        reply = self._request_dispatcher(incoming)
        if reply is None:
            continue

        if not isinstance(reply, RpcReqOutput):
            self.send_to_tokenizer.send_pyobj(reply)
        elif self.recv_from_rpc is not None:
            self.recv_from_rpc.send_pyobj(reply)
def handle_generate_request(
    self,
    recv_req: TokenizedGenerateReqInput,
):
    """Turn a tokenized generate request into a ``Req`` and enqueue it.

    The request is built either from scratch or from an existing session.
    Multimodal inputs are expanded, the prompt length and
    ``logprob_start_len`` are validated, ``max_new_tokens`` is clamped to
    the remaining context, and a grammar is attached when a constraint is
    given. Invalid requests are marked with ``FINISH_ABORT`` and still
    enqueued. Requests whose grammar is not cached yet go to
    ``self.grammar_queue``; all others go through ``_add_request_to_queue``.

    Args:
        recv_req: The tokenized generate request to process.
    """
    # Create a new request
    if (
        recv_req.session_params is None
        or recv_req.session_params.id is None
        or recv_req.session_params.id not in self.sessions
    ):
        if recv_req.input_embeds is not None:
            # Generate fake input_ids based on the length of input_embeds
            seq_length = len(recv_req.input_embeds)
            fake_input_ids = [1] * seq_length
            recv_req.input_ids = fake_input_ids

        # Handle custom logit processor passed to the request
        custom_logit_processor = recv_req.custom_logit_processor
        if (
            not self.server_args.enable_custom_logit_processor
            and custom_logit_processor is not None
        ):
            # Sentences are separated by a space, and the flag name matches
            # the server_args attribute checked above.
            logger.warning(
                "The SGLang server is not configured to enable custom logit processor. "
                "The custom logit processor passed in will be ignored. "
                "Please set --enable-custom-logit-processor to enable this feature."
            )
            custom_logit_processor = None

        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            return_logprob=recv_req.return_logprob,
            top_logprobs_num=recv_req.top_logprobs_num,
            token_ids_logprob=recv_req.token_ids_logprob,
            stream=recv_req.stream,
            lora_path=recv_req.lora_path,
            input_embeds=recv_req.input_embeds,
            custom_logit_processor=custom_logit_processor,
            return_hidden_states=recv_req.return_hidden_states,
            eos_token_ids=self.model_config.hf_eos_token_id,
        )
        req.tokenizer = self.tokenizer

        if (
            recv_req.session_params is not None
            and recv_req.session_params.id is not None
        ):
            # A session id was given but is unknown to this scheduler.
            req.finished_reason = FINISH_ABORT(
                f"Invalid request: session id {recv_req.session_params.id} does not exist"
            )
            self._add_request_to_queue(req)
            return
    else:
        # Create a new request from a previous session
        session = self.sessions[recv_req.session_params.id]
        req = session.create_req(recv_req, self.tokenizer)
        if isinstance(req.finished_reason, FINISH_ABORT):
            self._add_request_to_queue(req)
            return

    # Handle multimodal inputs
    if recv_req.mm_inputs is not None:
        image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
        # Expand a single image token into multiple dummy tokens for receiving image embeddings
        req.origin_input_ids = self.pad_input_ids_func(
            req.origin_input_ids, image_inputs
        )
        req.extend_image_inputs(image_inputs)

        if len(req.origin_input_ids) >= self.max_req_input_len:
            error_msg = (
                "Multimodal prompt is too long after expanding multimodal tokens. "
                f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
            )
            logger.error(error_msg)
            # Shrink the request to a single dummy token before aborting it.
            req.origin_input_ids = [0]
            req.multimodal_inputs = None
            req.sampling_params.max_new_tokens = 0
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )
            self._add_request_to_queue(req)
            return

    # Validate prompts length
    error_msg = validate_input_length(
        req,
        self.max_req_input_len,
        self.server_args.allow_auto_truncate,
    )
    if error_msg:
        req.origin_input_ids = [0]
        req.sampling_params.max_new_tokens = 0
        self._add_request_to_queue(req)
        return

    # Copy more attributes
    if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
        # By default, only return the logprobs for output tokens
        req.logprob_start_len = len(req.origin_input_ids) - 1
    else:
        req.logprob_start_len = recv_req.logprob_start_len

    if req.logprob_start_len >= len(req.origin_input_ids):
        req.finished_reason = FINISH_ABORT(
            f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
            HTTPStatus.BAD_REQUEST,
            "BadRequestError",
        )
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
        return

    # Clamp the generation budget to what still fits in the context.
    req.sampling_params.max_new_tokens = min(
        (
            req.sampling_params.max_new_tokens
            if req.sampling_params.max_new_tokens is not None
            else 1 << 30
        ),
        self.max_req_len - len(req.origin_input_ids) - 1,
    )

    # Init grammar cache for this request
    add_to_grammar_queue = False
    if (
        req.sampling_params.json_schema is not None
        or req.sampling_params.regex is not None
        or req.sampling_params.ebnf is not None
        or req.sampling_params.structural_tag is not None
    ):
        assert self.grammar_backend is not None
        if req.sampling_params.json_schema is not None:
            key = ("json", req.sampling_params.json_schema)
        elif req.sampling_params.regex is not None:
            key = ("regex", req.sampling_params.regex)
        elif req.sampling_params.ebnf is not None:
            key = ("ebnf", req.sampling_params.ebnf)
        else:
            # The outer condition guarantees structural_tag is not None here.
            # A truthiness test would leave `key` unbound for a falsy value
            # such as an empty string.
            key = ("structural_tag", req.sampling_params.structural_tag)

        req.grammar = self.grammar_backend.get_cached_value(key)
        if not req.grammar:
            # Not cached yet: wait on the future value in the grammar queue.
            req.grammar = self.grammar_backend.get_future_value(key)
            add_to_grammar_queue = True

    if add_to_grammar_queue:
        self.grammar_queue.append(req)
    else:
        self._add_request_to_queue(req)
def _add_request_to_queue(self, req: Req):
    """Enqueue one request on the queue that matches the disaggregation mode.

    PREFILL mode uses the prefill pending queue, DECODE mode the decode
    pre-allocation queue, and any other mode the plain waiting queue.
    """
    mode = self.disaggregation_mode
    if mode == DisaggregationMode.PREFILL:
        enqueue = self.disagg_prefill_pending_queue.add
    elif mode == DisaggregationMode.DECODE:
        enqueue = self.disagg_decode_prealloc_queue.add
    else:
        enqueue = self.waiting_queue.append
    enqueue(req)
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
    """Enqueue several requests at once.

    DECODE mode extends the decode pre-allocation queue; every other mode
    extends the waiting queue. ``is_retracted`` is accepted but not used.
    """
    if self.disaggregation_mode != DisaggregationMode.DECODE:
        self.waiting_queue.extend(reqs)
        return
    self.disagg_decode_prealloc_queue.extend(reqs)
def handle_embedding_request(
    self,
    recv_req: TokenizedEmbeddingReqInput,
):
    """Turn a tokenized embedding request into a ``Req`` and enqueue it.

    Multimodal inputs are expanded and the prompt length is validated.
    Every exit path enqueues through ``_add_request_to_queue`` so the
    request lands on the queue that matches the disaggregation mode.

    Args:
        recv_req: The tokenized embedding request to process.
    """
    req = Req(
        recv_req.rid,
        recv_req.input_text,
        recv_req.input_ids,
        recv_req.sampling_params,
    )
    req.tokenizer = self.tokenizer

    # Handle multimodal inputs
    if recv_req.image_inputs is not None:
        image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
        # Expand a single image token into multiple dummy tokens for receiving image embeddings
        req.origin_input_ids = self.pad_input_ids_func(
            req.origin_input_ids, image_inputs
        )
        req.extend_image_inputs(image_inputs)

        if len(req.origin_input_ids) >= self.max_req_input_len:
            error_msg = (
                "Multimodal prompt is too long after expanding multimodal tokens. "
                f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
            )
            logger.error(error_msg)
            # Shrink the request to a single dummy token before aborting it.
            req.origin_input_ids = [0]
            req.multimodal_inputs = None
            req.sampling_params.max_new_tokens = 0
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )
            # Route through the shared helper, as handle_generate_request
            # does, instead of appending to waiting_queue directly.
            self._add_request_to_queue(req)
            return

    # Validate prompts length
    error_msg = validate_input_length(
        req,
        self.max_req_input_len,
        self.server_args.allow_auto_truncate,
    )
    if error_msg:
        self._add_request_to_queue(req)
        return

    # Copy more attributes
    req.logprob_start_len = len(req.origin_input_ids) - 1
    self._add_request_to_queue(req)
def log_prefill_stats(
    self,
    adder: PrefillAdder,
    can_run_list: List[Req],
    running_bs: int,
):
    """Log one line of prefill statistics and update the metrics, if enabled.

    Also refreshes ``last_input_throughput``, resets the prefill token
    counter and tracks the largest prefill length seen.

    Args:
        adder: Supplies ``log_input_tokens`` and ``log_hit_tokens``.
        can_run_list: Requests admitted to this prefill batch.
        running_bs: Number of running requests to report.
    """
    # Read the clock once so the measured gap and the new tick agree.
    now = time.time()
    gap_latency = now - self.last_prefill_stats_tic
    self.last_prefill_stats_tic = now
    # Guard against a zero gap (coarse clocks, back-to-back calls).
    self.last_input_throughput = (
        self.num_prefill_tokens / gap_latency if gap_latency > 0 else 0.0
    )
    self.num_prefill_tokens = 0

    num_used = self.max_total_num_tokens - (
        self.token_to_kv_pool_allocator.available_size()
        + self.tree_cache.evictable_size()
    )
    self._largest_prefill_len = max(
        self._largest_prefill_len, adder.log_input_tokens
    )

    msg = (
        f"Prefill batch. "
        f"#new-seq: {len(can_run_list)}, "
        f"#new-token: {adder.log_input_tokens}, "
        f"#cached-token: {adder.log_hit_tokens}, "
        f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        f"#running-req: {running_bs}, "
        f"#queue-req: {len(self.waiting_queue)}, "
    )
    logger.info(msg)

    if self.enable_metrics:
        total_tokens = adder.log_input_tokens + adder.log_hit_tokens
        # With no tokens at all there is no meaningful hit rate; report 0.
        cache_hit_rate = (
            adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
        )
        self.stats.num_running_reqs = running_bs
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.cache_hit_rate = cache_hit_rate
        self.metrics_collector.log_stats(self.stats)
def log_decode_stats(self):
    """Log one line of decode statistics and update the metrics, if enabled.

    Refreshes ``last_gen_throughput``, resets the generated-token counter
    and, with speculative decoding, folds the per-interval acceptance
    counters into the cumulative ones before resetting them.
    """
    # Read the clock once so the measured gap and the new tick agree.
    now = time.time()
    gap_latency = now - self.last_decode_stats_tic
    self.last_decode_stats_tic = now
    # Guard against a zero gap (coarse clocks, back-to-back calls).
    self.last_gen_throughput = (
        self.num_generated_tokens / gap_latency if gap_latency > 0 else 0.0
    )
    self.num_generated_tokens = 0

    num_running_reqs = len(self.running_batch.reqs)
    num_used = self.max_total_num_tokens - (
        self.token_to_kv_pool_allocator.available_size()
        + self.tree_cache.evictable_size()
    )

    if RECORD_STEP_TIME:
        self.step_time_dict[num_running_reqs].append(
            gap_latency / self.server_args.decode_log_interval
        )

    if self.spec_algorithm.is_none():
        spec_accept_length = 0
        accept_part = ""
    else:
        # No forward passes counted in this interval means no accept length.
        if self.spec_num_total_forward_ct > 0:
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
        else:
            spec_accept_length = 0.0
        self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
        self.cum_spec_accept_count += self.spec_num_total_forward_ct
        self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
        accept_part = f"accept len: {spec_accept_length:.2f}, "

    # Both variants share one template; only the accept-length field differs.
    msg = (
        f"Decode batch. "
        f"#running-req: {num_running_reqs}, "
        f"#token: {num_used}, "
        f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        f"{accept_part}"
        f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
        f"#queue-req: {len(self.waiting_queue)}, "
    )
    logger.info(msg)

    if self.enable_metrics:
        self.stats.num_running_reqs = num_running_reqs
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = num_used / self.max_total_num_tokens
        self.stats.cache_hit_rate = 0.0
        self.stats.gen_throughput = self.last_gen_throughput
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.spec_accept_length = spec_accept_length
        self.metrics_collector.log_stats(self.stats)
def check_memory(self):
    """Check the KV-cache and request pools for leaks and emit idle metrics.

    Warns (and raises ``ValueError`` when ``crash_on_warnings()`` is true) if
    the free + evictable KV tokens do not add up to the expected total, or if
    the request-to-token pool has slots that were not returned. Afterwards,
    when metrics are enabled on attention-TP rank 0 and more than 30 seconds
    passed since the collector last logged, the current stats are published.
    """
    available_size = (
        self.token_to_kv_pool_allocator.available_size()
        + self.tree_cache.evictable_size()
    )
    protected_size = self.tree_cache.protected_size()

    # With the hierarchical cache, protected tokens are excluded from the
    # amount that is expected to be available.
    if self.enable_hierarchical_cache:
        expected_size = self.max_total_num_tokens - protected_size
    else:
        expected_size = self.max_total_num_tokens

    if available_size != expected_size:
        msg = (
            "KV cache pool leak detected! "
            f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
            f"{self.token_to_kv_pool_allocator.available_size()=}\n"
            f"{self.tree_cache.evictable_size()=}\n"
        )
        warnings.warn(msg)
        if crash_on_warnings():
            raise ValueError(msg)

    req_pool = self.req_to_token_pool
    if len(req_pool.free_slots) != req_pool.size:
        msg = (
            "Memory pool leak detected!"
            f"available_size={len(self.req_to_token_pool.free_slots)}, "
            f"total_size={self.req_to_token_pool.size}\n"
        )
        warnings.warn(msg)
        if crash_on_warnings():
            raise ValueError(msg)

    idle_metrics_due = (
        self.enable_metrics
        and self.attn_tp_rank == 0
        and time.time() > self.metrics_collector.last_log_time + 30
    )
    if not idle_metrics_due:
        return

    # During idle time, also collect metrics every 30 seconds.
    num_used = self.max_total_num_tokens - (
        self.token_to_kv_pool_allocator.available_size()
        + self.tree_cache.evictable_size()
    )
    stats = self.stats
    stats.num_running_reqs = len(self.running_batch.reqs)
    stats.num_used_tokens = num_used
    stats.token_usage = num_used / self.max_total_num_tokens
    stats.gen_throughput = 0
    stats.num_queue_reqs = len(self.waiting_queue)
    self.metrics_collector.log_stats(stats)
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
    """Pick the batch for the next forward pass.

    First folds the previous extend (prefill) batch into ``running_batch``,
    then prefers a freshly built prefill batch over decoding the running
    batch. When DP attention or SP layernorm is enabled, the chosen batch is
    passed through ``prepare_dp_attn_batch`` before being returned.

    Returns:
        The batch to run, or ``None`` when there is nothing to run.
    """
    # Merge the prefill batch into the running batch
    if self.last_batch and self.last_batch.forward_mode.is_extend():
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
            self.running_batch.batch_is_full = False

        # Filter batch
        last_bs = self.last_batch.batch_size()
        self.last_batch.filter_batch()
        # The batch shrank during filtering, so the "full" flag is cleared.
        if self.last_batch.batch_size() < last_bs:
            self.running_batch.batch_is_full = False

        # Merge the new batch into the running batch
        if not self.last_batch.is_empty():
            if self.running_batch.is_empty():
                self.running_batch = self.last_batch
            else:
                # Merge running_batch with prefill batch
                self.running_batch.merge_batch(self.last_batch)

    new_batch = self.get_new_batch_prefill()
    if new_batch is not None:
        # Run prefill first if possible
        ret = new_batch
    else:
        # Run decode
        if not self.running_batch.is_empty():
            self.running_batch = self.update_running_batch(self.running_batch)
            # update_running_batch may leave the batch empty; run nothing then.
            ret = self.running_batch if not self.running_batch.is_empty() else None
        else:
            ret = None

    # Handle DP attention
    if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
        # The second element of the returned pair is not needed here.
        ret, _ = self.prepare_dp_attn_batch(ret)

    return ret
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
    """Build a new prefill (extend) batch from the waiting queue.

    Requests are taken from ``waiting_queue`` in the order left by
    ``self.policy.calc_priority`` until the ``PrefillAdder`` reports that no
    more can be added, the LoRA-per-batch limit would be exceeded, or
    ``max_running_requests`` is reached. An in-flight chunked request, if
    any, is continued first. Scheduled requests are removed from the waiting
    queue.

    Returns:
        The prepared batch, or ``None`` when prefill is not possible or no
        request could be scheduled.
    """
    # Check if the grammar is ready in the grammar queue
    if self.grammar_queue:
        self.move_ready_grammar_requests()

    # Handle the cases where prefill is not allowed
    if (
        self.running_batch.batch_is_full or len(self.waiting_queue) == 0
    ) and self.chunked_req is None:
        return None

    running_bs = len(self.running_batch.reqs)
    if running_bs >= self.max_running_requests:
        self.running_batch.batch_is_full = True
        return None

    if self.enable_hierarchical_cache:
        # check for completion of hierarchical cache activities to release memory
        self.tree_cache.writing_check()
        self.tree_cache.loading_check()

    # Get priority queue
    prefix_computed = self.policy.calc_priority(self.waiting_queue)

    # Prefill policy
    adder = PrefillAdder(
        self.tree_cache,
        self.token_to_kv_pool_allocator,
        self.running_batch,
        self.new_token_ratio,
        self.max_prefill_tokens,
        self.chunked_prefill_size,
        running_bs if self.is_mixed_chunk else 0,
    )

    if self.chunked_req is not None:
        self.chunked_req.init_next_round_input()
        self.chunked_req = adder.add_chunked_req(self.chunked_req)

    if self.lora_paths:
        lora_set = set([req.lora_path for req in self.running_batch.reqs])

    # Get requests from the waiting queue to a new prefill batch
    for req in self.waiting_queue:
        if (
            self.lora_paths
            and len(
                lora_set
                | set([req.lora_path for req in adder.can_run_list])
                | set([req.lora_path])
            )
            > self.max_loras_per_batch
        ):
            self.running_batch.batch_is_full = True
            break

        if running_bs + len(adder.can_run_list) >= self.max_running_requests:
            self.running_batch.batch_is_full = True
            break

        # When the policy already computed the prefix, no tree cache is passed.
        req.init_next_round_input(
            None if prefix_computed else self.tree_cache,
            self.enable_hierarchical_cache,
        )

        res = adder.add_one_req(
            req, self.chunked_req, self.enable_hierarchical_cache
        )
        if res != AddReqResult.CONTINUE:
            if res == AddReqResult.NO_TOKEN:
                if self.enable_hierarchical_cache:
                    # Set batch_is_full after making sure there are requests that can be served
                    self.running_batch.batch_is_full = len(
                        adder.can_run_list
                    ) > 0 or (
                        self.running_batch is not None
                        and not self.running_batch.is_empty()
                    )
                else:
                    self.running_batch.batch_is_full = True
            break

    # Update waiting queue
    can_run_list: List[Req] = adder.can_run_list
    if len(can_run_list) == 0:
        return None
    # Build the membership set once; rebuilding it for every waiting request
    # made this filter quadratic in the queue length.
    scheduled = set(can_run_list)
    self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

    if self.enable_hierarchical_cache:
        self.tree_cache.read_to_load_cache()

    if adder.new_chunked_req is not None:
        assert self.chunked_req is None
        self.chunked_req = adder.new_chunked_req

    if self.chunked_req:
        self.chunked_req.is_chunked += 1

    # Print stats
    if self.attn_tp_rank == 0:
        self.log_prefill_stats(adder, can_run_list, running_bs)

    # Create a new batch
    new_batch = ScheduleBatch.init_new(
        can_run_list,
        self.req_to_token_pool,
        self.token_to_kv_pool_allocator,
        self.tree_cache,
        self.model_config,
        self.enable_overlap,
        self.spec_algorithm,
        self.server_args.enable_custom_logit_processor,
    )
    new_batch.prepare_for_extend()

    # Mixed-style chunked prefill
    if (
        self.is_mixed_chunk
        and not self.running_batch.is_empty()
        and not (new_batch.return_logprob or self.running_batch.return_logprob)
    ):
        # TODO (lianmin): support return_logprob + mixed chunked prefill
        self.running_batch.filter_batch()
        if not self.running_batch.is_empty():
            self.running_batch.prepare_for_decode()
            new_batch.mix_with_running(self.running_batch)
            new_batch.decoding_reqs = self.running_batch.reqs
        # The running requests now live in new_batch; start over with an
        # empty running batch that keeps the current batch_is_full flag.
        self.running_batch = ScheduleBatch(
            reqs=[], batch_is_full=self.running_batch.batch_is_full
        )
    else:
        new_batch.decoding_reqs = None

    return new_batch
def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
    """Filter the running decode batch and prepare it for the next decode step.

    If the batch does not have enough memory for the next decode step (or the
    ``TEST_RETRACT`` switch forces it for batches larger than 10), requests are
    retracted back to the queue and ``new_token_ratio`` is replaced by the
    value returned from ``retract_decode``. Otherwise ``new_token_ratio``
    decays towards ``min_new_token_ratio``.
    """
    bs_before = batch.batch_size()

    batch.filter_batch()
    if batch.is_empty():
        batch.batch_is_full = False
        return batch

    # Check if decode out of memory
    enough_memory = batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
    should_retract = (not enough_memory) or (TEST_RETRACT and batch.batch_size() > 10)

    if should_retract:
        old_ratio = self.new_token_ratio
        retracted_reqs, self.new_token_ratio = batch.retract_decode(self.server_args)

        logger.info(
            "Decode out of memory happened. "
            f"#retracted_reqs: {len(retracted_reqs)}, "
            f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
        )
        self._extend_requests_to_queue(retracted_reqs)
    else:
        decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
        self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)

    # Any shrinkage (filtering or retraction) makes room for new requests.
    if batch.batch_size() < bs_before:
        batch.batch_is_full = False

    # Update batch tensors
    batch.prepare_for_decode()
    return batch
def run_batch(
    self, batch: ScheduleBatch
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
    """Run one forward pass for ``batch`` and wrap the outputs in a result object.

    Generation models go through the TP worker, or through the draft worker
    when a speculative algorithm is active; other models (embedding / reward)
    go through ``forward_batch_embedding``.
    """
    self.forward_ct += 1

    # Check profiler: stop once the requested number of steps has been run.
    target_ct = self.profiler_target_forward_ct
    if target_ct and target_ct <= self.forward_ct:
        self.stop_profile()

    if not self.is_generation:
        # embedding or reward model
        worker_batch = batch.get_model_worker_batch()
        return EmbeddingBatchResult(
            embeddings=self.tp_worker.forward_batch_embedding(worker_batch),
            bid=worker_batch.bid,
        )

    # Run forward
    if self.spec_algorithm.is_none():
        worker_batch = batch.get_model_worker_batch()
        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
            worker_batch
        )
        bid = worker_batch.bid
    else:
        spec_output = self.draft_worker.forward_batch_speculative_generation(batch)
        logits_output, next_token_ids, bid, num_accepted_tokens = spec_output
        self.spec_num_total_accepted_tokens += (
            num_accepted_tokens + batch.batch_size()
        )
        self.spec_num_total_forward_ct += batch.batch_size()
        self.num_generated_tokens += num_accepted_tokens
    batch.output_ids = next_token_ids

    # These 2 values are needed for processing the output, but the values can be
    # modified by overlap schedule. So we have to copy them here so that
    # we can use the correct values in output processing.
    input_lens = None
    logprob_start_lens = None
    if batch.return_logprob:
        input_lens = [req.extend_input_len for req in batch.reqs]
        logprob_start_lens = [req.extend_logprob_start_len for req in batch.reqs]

    return GenerationBatchResult(
        logits_output=logits_output,
        next_token_ids=next_token_ids,
        extend_input_len_per_req=input_lens,
        extend_logprob_start_len_per_req=logprob_start_lens,
        bid=bid,
    )
def process_batch_result(
    self,
    batch: ScheduleBatch,
    result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
    """Dispatch the result of a forward pass according to the batch's forward mode.

    Decode and extend batches are handed to their dedicated processors. Idle
    batches (with overlap scheduling) and dummy-first batches only finalize
    the sampling info of the next batch. Afterwards one pending health-check
    reply, if any, is sent to the tokenizer.
    """

    def _finalize_next_sampling_info(info):
        # Update the vocab mask, wait for the stream, then signal completion.
        info.update_regex_vocab_mask()
        self.current_stream.synchronize()
        info.sampling_info_done.set()

    mode = batch.forward_mode
    if mode.is_decode():
        self.process_batch_result_decode(batch, result)
    elif mode.is_extend():
        self.process_batch_result_prefill(batch, result)
    elif mode.is_idle():
        if self.enable_overlap:
            self.tp_worker.resolve_batch_result(result.bid)
            if batch.next_batch_sampling_info:
                _finalize_next_sampling_info(batch.next_batch_sampling_info)
    elif mode.is_dummy_first():
        _finalize_next_sampling_info(batch.next_batch_sampling_info)

    if not self.return_health_check_ct:
        return

    # Return some signal for the health check.
    # This is used to prevent the health check signal being blocked by long context prefill.
    # However, one minor issue is that this code path does not check the status of detokenizer manager.
    self.return_health_check_ct -= 1
    self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
def prepare_dp_attn_batch(
    self, local_batch: Optional[ScheduleBatch]
) -> Tuple[Optional[ScheduleBatch], bool]:
    """Exchange per-rank batch information and attach the global view to the batch.

    Every rank contributes four integers (token count, whether it can use a
    CUDA graph, token count for logprob computation, whether its batch is an
    extend batch) to an all-gather over ``tp_cpu_group``. A rank without a
    local batch receives an idle batch when any other rank has tokens to
    process.

    Returns:
        ``(local_batch, any_extend)`` where ``local_batch`` may still be
        ``None`` if no rank has work, and ``any_extend`` tells whether any
        gathered entry reported an extend batch.
    """
    # Check if other DP workers have running batches
    if local_batch is None:
        num_tokens = 0
        global_num_tokens_for_logprob = 0
    elif local_batch.forward_mode.is_decode():
        num_tokens = local_batch.batch_size()
        # With EAGLE, each request contributes speculative_num_draft_tokens tokens.
        if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
            num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
        global_num_tokens_for_logprob = num_tokens
    else:
        num_tokens = local_batch.extend_num_tokens
        global_num_tokens_for_logprob = sum(
            [
                # We should have at least 1 token for sample in every case.
                max(extend_len - logprob_start_len, 1)
                for logprob_start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            ]
        )

    # A missing batch or a decode/idle batch does not prevent CUDA graph usage.
    if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
        can_cuda_graph = 1
    else:
        can_cuda_graph = 0

    if not self.spec_algorithm.is_none():
        # TODO(sang): Support cuda graph when idle batch is there.
        if local_batch is None or local_batch.forward_mode.is_idle():
            can_cuda_graph = 0

    is_extend_in_batch = (
        local_batch.forward_mode.is_extend() if local_batch else False
    )
    # Pack the four local values; the bool is stored as an int64.
    local_info = torch.tensor(
        [
            num_tokens,
            can_cuda_graph,
            global_num_tokens_for_logprob,
            is_extend_in_batch,
        ],
        dtype=torch.int64,
    )
    # One row of 4 values per (dp rank, attention-TP rank) pair.
    global_info = torch.empty(
        (self.server_args.dp_size, self.attn_tp_size, 4),
        dtype=torch.int64,
    )
    torch.distributed.all_gather_into_tensor(
        global_info.flatten(),
        local_info,
        group=self.tp_cpu_group,
    )
    # Only index 0 along the attention-TP axis is read for each DP rank.
    global_num_tokens = global_info[:, 0, 0].tolist()
    # The CUDA graph flag is the minimum over all DP ranks.
    can_cuda_graph = min(global_info[:, 0, 1].tolist())
    global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
    is_extend_in_batch = global_info[:, 0, 3].tolist()

    # Some other rank has work: take part with an idle batch.
    if local_batch is None and max(global_num_tokens) > 0:
        local_batch = self.get_idle_batch()

    if local_batch is not None:
        local_batch.global_num_tokens = global_num_tokens
        local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

        # Check forward mode for cuda graph
        if not self.server_args.disable_cuda_graph:
            local_batch.can_run_dp_cuda_graph = can_cuda_graph

    return local_batch, any(is_extend_in_batch)
def get_idle_batch(self):
    """Create a ``ScheduleBatch`` without requests and prepare it for idle mode."""
    batch = ScheduleBatch.init_new(
        [],  # no requests
        self.req_to_token_pool,
        self.token_to_kv_pool_allocator,
        self.tree_cache,
        self.model_config,
        self.enable_overlap,
        self.spec_algorithm,
        self.server_args.enable_custom_logit_processor,
    )
    batch.prepare_for_idle()
    return batch
def move_ready_grammar_requests(self):
    """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
    # Count the leading requests whose grammar future resolves within 50 ms;
    # stop at the first one that is not ready yet.
    ready_count = 0
    for pending in self.grammar_queue:
        try:
            pending.grammar = pending.grammar.result(timeout=0.05)
        except futures.TimeoutError:
            break
        ready_count += 1

    if self.server_args.enable_dp_attention:
        tp_size, tp_group = self.attn_tp_size, self.attn_tp_cpu_group
    else:
        tp_size, tp_group = self.tp_size, self.tp_cpu_group

    if tp_size > 1:
        # Sync across TP ranks to make sure they have the same number of ready requests
        counter = torch.tensor(ready_count, dtype=torch.int32)
        torch.distributed.all_reduce(
            counter, op=torch.distributed.ReduceOp.MAX, group=tp_group
        )
        max_ready_count = counter.item()

        # Block on the futures that were ready elsewhere but not here.
        for idx in range(ready_count, max_ready_count):
            lagging = self.grammar_queue[idx]
            lagging.grammar = lagging.grammar.result()
        ready_count = max_ready_count

    ready_reqs = self.grammar_queue[:ready_count]
    self._extend_requests_to_queue(ready_reqs)
    self.grammar_queue = self.grammar_queue[ready_count:]
def watchdog_thread(self):
    """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
    self.watchdog_last_forward_ct = 0
    self.watchdog_last_time = time.time()

    while True:
        now = time.time()
        if self.cur_batch is not None:
            if self.watchdog_last_forward_ct != self.forward_ct:
                # The forward counter moved: remember the new checkpoint.
                self.watchdog_last_forward_ct = self.forward_ct
                self.watchdog_last_time = now
            elif now > self.watchdog_last_time + self.watchdog_timeout:
                logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                break
        time.sleep(self.watchdog_timeout // 2)

    # Print batch size and memory pool info to check whether there are de-sync issues.
    logger.error(
        f"{self.cur_batch.batch_size()=}, "
        f"{self.cur_batch.reqs=}, "
        f"{self.token_to_kv_pool_allocator.available_size()=}, "
        f"{self.tree_cache.evictable_size()=}, "
    )
    pyspy_dump_schedulers()
    # Emit an empty line on both standard streams and flush them.
    for stream in (sys.stderr, sys.stdout):
        print(file=stream, flush=True)

    # Wait for some time so that the parent process can print the error.
    time.sleep(5)
    self.parent_process.send_signal(signal.SIGQUIT)
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
    """Request-handler form of ``flush_cache``.

    ``recv_req`` carries no data that is used here, and the outcome of the
    flush is not returned.
    """
    # Deliberately discard the success flag.
    _ = self.flush_cache()
def flush_cache(self):
    """Flush the memory pool and cache.

    The flush only happens when there is no queued and no running request.
    It resets the tree cache, the grammar backend (if any), the request and
    KV pools (including those of the draft worker when a speculative
    algorithm is active) and the generation / speculative counters.

    Returns:
        ``True`` if the cache was flushed, ``False`` if it was skipped
        because requests are still pending.
    """
    if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        if_success = True
    else:
        # Use the module logger (not the root ``logging`` functions) so this
        # message is emitted the same way as every other scheduler message.
        logger.warning(
            f"Cache not flushed because there are pending requests. "
            f"#queue-req: {len(self.waiting_queue)}, "
            f"#running-req: {len(self.running_batch.reqs)}"
        )
        if_success = False
    return if_success
def get_internal_state(self, recv_req: GetInternalStateReq):
    """Return a snapshot of the global server args plus runtime statistics.

    The snapshot is a copy of ``global_server_args_dict`` extended with the
    last generation throughput, the average speculative accept length (when
    available) and, if ``RECORD_STEP_TIME`` is set, the recorded step times.
    """
    state = dict(global_server_args_dict)
    state["last_gen_throughput"] = self.last_gen_throughput

    spec_enabled = not self.spec_algorithm.is_none()
    if spec_enabled and self.cum_spec_accept_count > 0:
        avg_accept_length = self.cum_spec_accept_length / self.cum_spec_accept_count
        state["avg_spec_accept_length"] = avg_accept_length

    if RECORD_STEP_TIME:
        state["step_time_dict"] = self.step_time_dict

    return GetInternalStateReqOutput(internal_state=state)
def set_internal_state(self, recv_req: SetInternalStateReq):
    """Update the whitelisted entries of ``global_server_args_dict``.

    Only the speculative accept thresholds may be changed. If the request
    contains any other key, nothing is updated. On a successful update the
    cumulative speculative accept counters are reset (after logging their
    average when available).

    Returns:
        ``SetInternalStateReqOutput`` whose ``updated`` flag tells whether the
        update was applied, together with the current global server args.
    """
    server_args_dict = recv_req.server_args
    args_allow_update = {
        "speculative_accept_threshold_single",
        "speculative_accept_threshold_acc",
    }

    if_success = True
    for k in server_args_dict:
        if k not in args_allow_update:
            # Use the module logger rather than the root logging functions,
            # like the rest of this method.
            logger.warning(f"Updating {k} is not supported.")
            if_success = False
            break

    if if_success:
        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            avg_spec_accept_length = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
            logger.info(f"{avg_spec_accept_length=}")
        self.cum_spec_accept_length = self.cum_spec_accept_count = 0
        for k, v in server_args_dict.items():
            global_server_args_dict[k] = v
        logger.info(f"Global server args updated! " f"{global_server_args_dict=}")

    # Report the real outcome: previously ``updated`` was always True, even
    # when an unsupported key caused the whole update to be rejected.
    return SetInternalStateReqOutput(
        updated=if_success,
        server_args=global_server_args_dict,
    )
def handle_rpc_request(self, recv_req: RpcReqInput):
    """Invoke the scheduler method named by an RPC request and report the outcome.

    Any exception raised by the method is logged and turned into a failed
    ``RpcReqOutput`` carrying the exception text. ``barrier()`` is called in
    both cases before the reply is built.
    """
    logger.info(
        f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
    )

    failure = None
    # NOTE(review): the method name comes straight from the request and is
    # resolved with getattr on self - confirm callers are trusted.
    try:
        getattr(self, recv_req.method)(recv_req.parameters)
    except Exception as e:
        failure = e
        logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

    barrier()
    return RpcReqOutput(failure is None, str(failure) if failure else "")
def save_remote_model(self, params):
    """Ask the TP worker's model runner to save the model to ``params["url"]``."""
    target_url = params["url"]
    model_runner = self.tp_worker.worker.model_runner
    model_runner.save_remote_model(target_url)
def save_sharded_model(self, params):
    """Ask the TP worker's model runner to save a sharded checkpoint.

    ``params`` must provide the keys ``"path"``, ``"pattern"`` and
    ``"max_size"``, which are forwarded as keyword arguments.
    """
    save = self.tp_worker.worker.model_runner.save_sharded_model
    save_kwargs = {key: params[key] for key in ("path", "pattern", "max_size")}
    save(**save_kwargs)
def abort_request(self, recv_req: AbortReq):
    """Abort the first request whose rid starts with ``recv_req.rid``.

    The waiting queue is searched first; a match there is removed from the
    queue. Otherwise the first unfinished matching request of the running
    batch is flagged with ``to_abort``. At most one request is affected.
    """
    rid_prefix = recv_req.rid

    # Delete requests in the waiting queue
    for idx, queued in enumerate(self.waiting_queue):
        if queued.rid.startswith(rid_prefix):
            req = self.waiting_queue.pop(idx)
            logger.debug(f"Abort queued request. {req.rid=}")
            return

    # Delete requests in the running batch
    for req in self.running_batch.reqs:
        if not req.rid.startswith(rid_prefix) or req.finished():
            continue
        logger.debug(f"Abort running request. {req.rid=}")
        req.to_abort = True
        return
def _pause_engine(self) -> Tuple[List[Req], int]:
    """Placeholder: pausing the engine is not implemented."""
    raise NotImplementedError
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
    """In-place update of the weights from disk."""
    success, message = self.tp_worker.update_weights_from_disk(recv_req)
    if not success:
        logger.error(message)
    else:
        # The cache is flushed after a successful weight update.
        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
    return UpdateWeightFromDiskReqOutput(success, message, 0)
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
    """Initialize the online model parameter update group."""
    outcome = self.tp_worker.init_weights_update_group(recv_req)
    ok, detail = outcome
    return InitWeightsUpdateGroupReqOutput(ok, detail)
def update_weights_from_distributed(
    self,
    recv_req: UpdateWeightsFromDistributedReqInput,
) -> UpdateWeightsFromDistributedReqOutput:
    """Update the online model parameter.

    Delegates the update to the TP worker. On success the cache is flushed
    (and the flush is asserted to have succeeded); on failure the worker's
    message is logged.

    Returns:
        ``UpdateWeightsFromDistributedReqOutput`` built from the worker's
        ``(success, message)`` pair. The annotation previously claimed
        ``Tuple[bool, str]``, which did not match the returned object.
    """
    success, message = self.tp_worker.update_weights_from_distributed(recv_req)
    if success:
        flash_cache_success = self.flush_cache()
        assert flash_cache_success, "Cache flush failed after updating weights"
    else:
        logger.error(message)
    return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Update the online model parameter from tensors."""
    success, message = self.tp_worker.update_weights_from_tensor(recv_req)
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    if not success:
        logger.error(message)
    elif recv_req.flush_cache:
        # Flushing is optional here and controlled by the request.
        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
    return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Fetch a parameter through the TP worker and wrap it in the reply object."""
    return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
    """Stash the model's buffers, pause the memory saver adapter and flush the cache."""
    adapter = self.memory_saver_adapter
    adapter.check_validity(caller_name="release_memory_occupation")

    # Keep a copy of the buffers so resume_memory_occupation can restore them.
    model = self.tp_worker.worker.model_runner.model
    self.stashed_model_static_state = _export_static_state(model)

    self.memory_saver_adapter.pause()
    self.flush_cache()
    return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
    """Resume the memory saver adapter and restore the stashed model buffers."""
    adapter = self.memory_saver_adapter
    adapter.check_validity(caller_name="resume_memory_occupation")
    self.memory_saver_adapter.resume()

    model = self.tp_worker.worker.model_runner.model
    _import_static_state(model, self.stashed_model_static_state)
    # The stash is single-use; drop it once it has been restored.
    del self.stashed_model_static_state

    return ResumeMemoryOccupationReqOutput()
def profile(self, recv_req: ProfileReq):
    """Start or stop profiling depending on the request type.

    A ``START_PROFILE`` request starts profiling with the options carried by
    the request; any other type stops it.
    """
    if recv_req.type != ProfileReqType.START_PROFILE:
        return self.stop_profile()

    return self.start_profile(
        output_dir=recv_req.output_dir,
        num_steps=recv_req.num_steps,
        activities=recv_req.activities,
        with_stack=recv_req.with_stack,
        record_shapes=recv_req.record_shapes,
    )
def start_profile(
    self,
    output_dir: Optional[str],
    num_steps: Optional[int],
    activities: Optional[List[str]],
    with_stack: Optional[bool],
    record_shapes: Optional[bool],
) -> Optional[ProfileReqOutput]:
    """Start a profiling session.

    Args:
        output_dir: Directory for the traces. Defaults to the
            ``SGLANG_TORCH_PROFILER_DIR`` environment variable, or ``/tmp``.
        num_steps: If truthy, profiling is scheduled to stop once
            ``forward_ct`` has advanced by this many steps.
        activities: Any of ``"CPU"``, ``"GPU"``, ``"MEM"``,
            ``"CUDA_PROFILER"``. Defaults to ``["CPU", "GPU"]``.
        with_stack: Passed to ``torch.profiler.profile``; ``None`` means True.
        record_shapes: Passed to ``torch.profiler.profile``; ``None`` means
            False.

    Returns:
        A failed ``ProfileReqOutput`` if profiling is already running, a
        successful one when no ``num_steps`` was given, and ``None`` when
        ``num_steps`` was given (the reply is then sent from ``stop_profile``).
    """
    # A non-empty activity list marks a session that is still running.
    if self.profiler_activities:
        return ProfileReqOutput(
            success=False,
            message="Profiling is already in progress. Call /stop_profile first.",
        )

    if output_dir is None:
        output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
    if activities is None:
        activities = ["CPU", "GPU"]

    self.torch_profiler_output_dir = output_dir
    self.profiler_activities = activities
    logger.info(
        "Profiling starts. Traces will be saved to: %s",
        self.torch_profiler_output_dir,
    )

    # Only "CPU" and "GPU" map to torch profiler activities; the other
    # activity names are handled separately below.
    activity_map = {
        "CPU": torch.profiler.ProfilerActivity.CPU,
        "GPU": torch.profiler.ProfilerActivity.CUDA,
    }
    torchprof_activities = [
        activity_map[a] for a in activities if a in activity_map
    ]

    if torchprof_activities:
        self.torch_profiler = torch.profiler.profile(
            activities=torchprof_activities,
            with_stack=with_stack if with_stack is not None else True,
            record_shapes=record_shapes if record_shapes is not None else False,
        )
        self.torch_profiler.start()

    if "MEM" in activities:
        torch.cuda.memory._record_memory_history(max_entries=100000)

    if "CUDA_PROFILER" in activities:
        torch.cuda.cudart().cudaProfilerStart()

    if num_steps:
        self.profiler_target_forward_ct = self.forward_ct + num_steps
        # The caller will be notified when reaching profiler_target_forward_ct
    else:
        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(self) -> None:
    """Stop the running profiling session and write out its artifacts.

    Does nothing when no session is active. Otherwise exports the torch
    profiler trace, the memory snapshot (``"MEM"``) and stops the CUDA
    profiler (``"CUDA_PROFILER"``) as requested at start, clears the session
    state, and notifies the tokenizer when a target forward count is set.
    """
    if self.profiler_activities is None:
        return

    logger.info("Stop profiling...")

    if self.torch_profiler is not None:
        self.torch_profiler.stop()
        trace_name = f"{time.time()}-TP-{self.tp_rank}.trace.json.gz"
        self.torch_profiler.export_chrome_trace(
            os.path.join(self.torch_profiler_output_dir, trace_name)
        )

    if "MEM" in self.profiler_activities:
        snapshot_name = f"{time.time()}-TP-{self.tp_rank}-memory.pickle"
        memory_profile_path = os.path.join(
            self.torch_profiler_output_dir, snapshot_name
        )
        torch.cuda.memory._dump_snapshot(memory_profile_path)
        torch.cuda.memory._record_memory_history(enabled=None)

    if "CUDA_PROFILER" in self.profiler_activities:
        torch.cuda.cudart().cudaProfilerStop()

    logger.info(
        "Profiling done. Traces are saved to: %s",
        self.torch_profiler_output_dir,
    )

    # Clear the session state so a new session can be started.
    self.torch_profiler = None
    self.torch_profiler_output_dir = None
    self.profiler_activities = None

    if not self.profiler_target_forward_ct:
        return
    self.send_to_tokenizer.send_pyobj(
        ProfileReqOutput(success=True, message="Succeeded.")
    )
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
    """Start, stop or dump the expert distribution recording.

    Raises:
        ValueError: If ``recv_req`` is none of the known request values.
    """
    actions = (
        (ExpertDistributionReq.START_RECORD, expert_distribution_recorder.start_record),
        (ExpertDistributionReq.STOP_RECORD, expert_distribution_recorder.stop_record),
        (ExpertDistributionReq.DUMP_RECORD, expert_distribution_recorder.dump_record),
    )
    for request_kind, action in actions:
        if recv_req == request_kind:
            action()
            return ExpertDistributionReqOutput()
    raise ValueError("Unrecognized ExpertDistributionReq value")
def open_session(self, recv_req: OpenSessionReqInput):
    """Register a new session under the requested id.

    Returns:
        ``OpenSessionReqOutput(session_id, True)`` on success, or with
        ``False`` when the id is already in use or is ``None``.
    """
    session_id = recv_req.session_id

    # handle error
    if session_id in self.sessions:
        logger.warning(f"session id {session_id} already exist, cannot open.")
        return OpenSessionReqOutput(session_id, False)
    if session_id is None:
        logger.warning("session id is None, cannot open.")
        return OpenSessionReqOutput(session_id, False)

    self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
    return OpenSessionReqOutput(session_id, True)
def close_session(self, recv_req: CloseSessionReqInput):
    """Remove the session with the requested id; warn if it does not exist."""
    session_id = recv_req.session_id
    if session_id in self.sessions:
        del self.sessions[session_id]
    else:
        # handle error
        logger.warning(f"session id {session_id} does not exist, cannot delete.")
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` has a string ``rid`` starting with "HEALTH_CHECK".

    Requests without a ``rid`` attribute, or whose ``rid`` is not a string
    (for example ``None``), are not health-check requests; previously a
    non-string ``rid`` raised ``AttributeError``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")
def _export_static_state(model):
    """Return a dict with detached clones of every named buffer of ``model``.

    The result has a single key, ``"buffers"``, holding ``(name, clone)``
    pairs in the order produced by ``model.named_buffers()``.
    """
    cloned_buffers = []
    for name, buffer in model.named_buffers():
        cloned_buffers.append((name, buffer.detach().clone()))
    return {"buffers": cloned_buffers}
def _import_static_state(model, static_params):
    """Write the saved buffers in ``static_params["buffers"]`` back into ``model``.

    Each saved tensor is assigned in place (``buffer[...] = tensor``) to the
    buffer of the same name.
    """
    buffers_by_name = dict(model.named_buffers())
    for name, saved in static_params["buffers"]:
        target = buffers_by_name[name]
        target[...] = saved
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process.

    Configures the process (title, fault handler, logger, optional CPU
    affinity), creates the ``Scheduler``, reports readiness through
    ``pipe_writer`` and then runs the event loop that matches the
    scheduler's disaggregation mode. Any exception raised while doing so is
    logged and ``SIGQUIT`` is sent to the parent process.

    Args:
        server_args: Server configuration passed to the scheduler.
        port_args: Port configuration passed to the scheduler.
        gpu_id: GPU id passed to the scheduler and used for CPU affinity.
        tp_rank: Tensor-parallel rank of this process.
        dp_rank: Data-parallel rank, or ``None``; when ``None`` it may be
            taken from the ``SGLANG_DP_RANK`` environment variable.
        pipe_writer: Object whose ``send`` receives the readiness message.
    """
    # Generate the prefix
    # NOTE: the prefix is built from the dp_rank argument, before the
    # SGLANG_DP_RANK override below is applied.
    if dp_rank is None:
        prefix = f" TP{tp_rank}"
    else:
        prefix = f" DP{dp_rank} TP{tp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    # Remember the parent so it can be signalled if the scheduler fails.
    parent_process = psutil.Process().parent()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        # Tell the launcher that initialization finished, along with the
        # limits computed by the scheduler.
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        # Pick the event loop for the disaggregation mode; without
        # disaggregation, the overlap flag selects between two loops.
        if disaggregation_mode == DisaggregationMode.NULL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)