""" Life cycle of a request in the decode server 1. PreallocQueue: a. Initialize a receiver for each request b. The request handshakes first, and pre-allocate kv once there is available kv. c. Move the request to TransferQueue. 2. TransferQueue: a. Poll the receiver to check the transfer state b. If the transfer has finished, move the request to waiting queue 3. WaitingQueue: a. Use the requests in the queue to construct a PrebuiltExtendBatch b. Skip the prefill forward but only populate metadata 4. RunningBatch: a. Merge the resolved PrebuiltExtendBatch into running batch to run decoding """ from __future__ import annotations import logging from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Tuple import torch from torch.distributed import ProcessGroup from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver from sglang.srt.disaggregation.utils import ( ReqToMetadataIdxAllocator, poll_and_all_reduce, ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo logger = logging.getLogger(__name__) if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.scheduler import Scheduler from sglang.srt.server_args import ServerArgs @dataclass class DecodeRequest: req: Req kv_receiver: KVReceiver waiting_for_input: bool = False metadata_buffer_index: int = -1 class DecodePreallocQueue: """ Store the requests that are preallocating. """ def __init__( self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator, req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, metadata_buffers: List[torch.Tensor], aux_dtype: torch.dtype, scheduler: Scheduler, transfer_queue: DecodeTransferQueue, tree_cache: BasePrefixCache, gloo_group: ProcessGroup, tp_rank: int, tp_size: int, bootstrap_port: int, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache() self.aux_dtype = aux_dtype self.metadata_buffers = metadata_buffers self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.scheduler = scheduler self.transfer_queue = transfer_queue self.tree_cache = tree_cache # this is always a chunk cache self.gloo_group = gloo_group self.tp_rank = tp_rank self.tp_size = tp_size self.bootstrap_port = bootstrap_port self.num_reserved_decode_tokens = 512 # Queue for requests pending pre-allocation self.queue: List[DecodeRequest] = [] self.kv_manager = self._init_kv_manager() def _init_kv_manager(self) -> KVManager: kv_args = KVArgs() kv_args.engine_rank = self.tp_rank kv_data_ptrs, kv_data_lens, kv_item_lens = ( self.token_to_kv_pool.get_contiguous_buf_infos() ) kv_args.kv_data_ptrs = kv_data_ptrs kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens kv_args.aux_data_ptrs = [ output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers ] kv_args.aux_data_lens = [ metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers ] kv_args.aux_item_lens = [ metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers ] kv_args.ib_device = "mock-ib-device" kv_manager = KVManager(kv_args) return kv_manager def add(self, req: Req) -> None: """Add a request to the pending queue.""" kv_receiver = KVReceiver( mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", bootstrap_room=req.bootstrap_room, ) self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver)) def extend(self, reqs: List[Req]) -> None: """Add a request to the pending queue.""" for req in reqs: self.add(req) def _update_handshake_waiters(self) -> None: if not self.queue: return if all(decode_req.waiting_for_input for decode_req in self.queue): return polls = poll_and_all_reduce( [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): if poll == KVPoll.Bootstrapping: pass elif poll == KVPoll.WaitingForInput: decode_req.waiting_for_input = True elif poll == KVPoll.Failed: raise Exception("Handshake failed") def pop_preallocated(self) -> List[DecodeRequest]: """Pop the preallocated requests from the pending queue (FIFO).""" self._update_handshake_waiters() preallocated_reqs = [] indices_to_remove = set() allocatable_tokens = self._allocatable_tokens() for i, decode_req in enumerate(self.queue): if not decode_req.waiting_for_input: continue if self.req_to_token_pool.available_size() <= 0: break if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0: break required_tokens_for_request = ( len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens ) if required_tokens_for_request > allocatable_tokens: break allocatable_tokens -= required_tokens_for_request self._pre_alloc(decode_req.req) kv_indices = ( self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][ : len(decode_req.req.origin_input_ids) ] .cpu() .numpy() ) decode_req.metadata_buffer_index = ( self.req_to_metadata_buffer_idx_allocator.alloc() ) assert decode_req.metadata_buffer_index is not None decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index) preallocated_reqs.append(decode_req) indices_to_remove.add(i) self.queue = [ entry for i, entry in enumerate(self.queue) if i not in indices_to_remove ] return preallocated_reqs def _allocatable_tokens(self) -> int: allocatable_tokens = ( self.token_to_kv_pool_allocator.available_size() - self.num_reserved_decode_tokens * ( len(self.scheduler.running_batch.reqs) + len(self.transfer_queue.queue) + len(self.scheduler.waiting_queue) ) ) # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration # the extend batch is not in any queue, so we need to explicitly add the tokens slots here if ( self.scheduler.last_batch and self.scheduler.last_batch.forward_mode.is_extend() ): allocatable_tokens -= self.num_reserved_decode_tokens * len( self.scheduler.last_batch.reqs ) return allocatable_tokens def _pre_alloc(self, req: Req) -> torch.Tensor: """Pre-allocate the memory for req_to_token and token_kv_pool""" req_pool_indices = self.req_to_token_pool.alloc(1) assert req_pool_indices is not None req.req_pool_idx = req_pool_indices[0] kv_loc = self.token_to_kv_pool_allocator.alloc( len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) ) assert kv_loc is not None self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc) # populate metadata req.fill_ids = req.origin_input_ids + req.output_ids req.extend_input_len = len(req.origin_input_ids) return kv_loc class DecodeTransferQueue: """ Store the requests that is polling kv """ def __init__( self, gloo_group: ProcessGroup, req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, metadata_buffers: torch.Tensor, ): self.queue: List[DecodeRequest] = [] self.gloo_group = gloo_group self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.metadata_buffers = metadata_buffers def add(self, req_conn: DecodeRequest) -> None: self.queue.append(req_conn) def extend(self, req_conns) -> None: self.queue.extend(req_conns) def pop_transferred(self) -> List[Req]: if not self.queue: return [] polls = poll_and_all_reduce( [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) transferred_reqs = [] indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): if poll == KVPoll.Failed: raise Exception("Transfer failed") elif poll == KVPoll.Success: # pop and push it to waiting queue idx = decode_req.metadata_buffer_index assert len(decode_req.req.output_ids) == 0 output_id_buffer = self.metadata_buffers[0] # the last dimension is padded by the same values. output_id = output_id_buffer[idx][0].item() assert len(decode_req.req.output_ids) == 0 assert decode_req.req.transferred_output_id is None decode_req.req.transferred_output_id = output_id transferred_reqs.append(decode_req.req) indices_to_remove.add(i) elif poll in [ KVPoll.Bootstrapping, KVPoll.WaitingForInput, KVPoll.Transferring, ]: pass else: raise ValueError(f"Unexpected poll case: {poll}") for i in indices_to_remove: idx = self.queue[i].metadata_buffer_index assert idx != -1 self.req_to_metadata_buffer_idx_allocator.free(idx) self.queue = [ entry for i, entry in enumerate(self.queue) if i not in indices_to_remove ] return transferred_reqs class ScheduleBatchDisaggregationDecodeMixin: def prepare_for_prebuilt_extend(self: ScheduleBatch): """ Prepare a prebuilt extend by populate metadata Adapted from .prepare_for_extend(). """ self.forward_mode = ForwardMode.EXTEND reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [] pre_lens = [] req_pool_indices = [] # Pre-calculate total size total_size = sum(req.extend_input_len for req in reqs) out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device) # Fill the tensor in one pass offset = 0 for i, req in enumerate(reqs): req_pool_indices.append(req.req_pool_idx) chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][ : req.extend_input_len ] assert ( offset + req.extend_input_len <= total_size ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}" out_cache_loc[offset : offset + req.extend_input_len] = chunk offset += req.extend_input_len pre_len = len(req.prefix_indices) seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1) seq_lens.append(seq_len) if len(req.output_ids) == 0: assert ( seq_len - pre_len == req.extend_input_len ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}" req.cached_tokens += pre_len - req.already_computed req.already_computed = seq_len req.is_retracted = False pre_lens.append(pre_len) req.extend_logprob_start_len = 0 extend_input_logprob_token_ids = None # Set fields self.input_ids = torch.tensor( sum(input_ids, []), dtype=torch.int32, device=self.device ) self.req_pool_indices = torch.tensor( req_pool_indices, dtype=torch.int64, device=self.device ) self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device) self.out_cache_loc = out_cache_loc self.seq_lens_sum = sum(seq_lens) self.extend_num_tokens = extend_num_tokens self.prefix_lens = [len(r.prefix_indices) for r in reqs] self.extend_lens = [r.extend_input_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] self.extend_input_logprob_token_ids = extend_input_logprob_token_ids # Build sampling info self.sampling_info = SamplingBatchInfo.from_schedule_batch( self, self.model_config.vocab_size, ) def process_prebuilt_extend( self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig ): """Assign the buffered last input id to schedule batch""" self.output_ids = [] for req in self.reqs: if req.output_ids and len(req.output_ids) > 0: # resumed retracted req self.output_ids.append(req.output_ids[-1]) else: assert req.transferred_output_id is not None req.output_ids.append(req.transferred_output_id) self.output_ids.append(req.transferred_output_id) self.tree_cache.cache_unfinished_req(req) self.output_ids = torch.tensor(self.output_ids, device=self.device) class SchedulerDisaggregationDecodeMixin: def get_next_disagg_decode_batch_to_run( self: Scheduler, ) -> Optional[Tuple[ScheduleBatch, bool]]: """Create fake completed prefill if possible and merge with running batch""" # Merge the prefill batch into the running batch last_batch = self.last_batch if last_batch and last_batch.forward_mode.is_extend(): # chunked prefill doesn't happen in decode instance. assert self.chunked_req is None # Filter finished batches. last_batch.filter_batch() if not last_batch.is_empty(): if self.running_batch.is_empty(): self.running_batch = last_batch else: # merge running_batch with prefill batch self.running_batch.merge_batch(last_batch) new_prebuilt_batch = self.get_new_prebuilt_batch() ret: Optional[ScheduleBatch] = None if new_prebuilt_batch: ret = new_prebuilt_batch else: if self.running_batch.is_empty(): ret = None else: self.running_batch = self.update_running_batch(self.running_batch) ret = self.running_batch if not self.running_batch.is_empty() else None return ret def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: """Create a schedulebatch for fake completed prefill""" if len(self.waiting_queue) == 0: return None curr_batch_size = self.running_batch.batch_size() batch_size = min(self.req_to_token_pool.size, self.max_running_requests) num_not_used_batch = batch_size - curr_batch_size # pop req from waiting queue can_run_list: List[Req] = [] waiting_queue: List[Req] = [] for i in range(len(self.waiting_queue)): req = self.waiting_queue[i] # we can only add at least `num_not_used_batch` new batch to the running queue if i < num_not_used_batch: can_run_list.append(req) req.init_next_round_input(self.tree_cache) else: waiting_queue.append(req) self.waiting_queue = waiting_queue if len(can_run_list) == 0: return None # local import to avoid circular import from sglang.srt.managers.schedule_batch import ScheduleBatch # construct a schedule batch with those requests and mark as decode new_batch = ScheduleBatch.init_new( can_run_list, self.req_to_token_pool, self.token_to_kv_pool_allocator, self.tree_cache, self.model_config, self.enable_overlap, self.spec_algorithm, self.server_args.enable_custom_logit_processor, ) # construct fake completed prefill new_batch.prepare_for_prebuilt_extend() new_batch.process_prebuilt_extend(self.server_args, self.model_config) return new_batch def process_decode_queue(self: Scheduler): req_conns = self.disagg_decode_prealloc_queue.pop_preallocated() self.disagg_decode_transfer_queue.extend(req_conns) alloc_reqs = ( self.disagg_decode_transfer_queue.pop_transferred() ) # the requests which kv has arrived self.waiting_queue.extend(alloc_reqs)