# sglang0.4.5.post1/python/sglang/srt/disaggregation/decode.py


"""
Life cycle of a request in the decode server
1. PreallocQueue:
a. Initialize a receiver for each request
b. The request handshakes first, and pre-allocate kv once there is available kv.
c. Move the request to TransferQueue.
2. TransferQueue:
a. Poll the receiver to check the transfer state
b. If the transfer has finished, move the request to waiting queue
3. WaitingQueue:
a. Use the requests in the queue to construct a PrebuiltExtendBatch
b. Skip the prefill forward but only populate metadata
4. RunningBatch:
a. Merge the resolved PrebuiltExtendBatch into running batch to run decoding
"""
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
from torch.distributed import ProcessGroup

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.utils import (
    ReqToMetadataIdxAllocator,
    poll_and_all_reduce,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
    from sglang.srt.managers.scheduler import Scheduler
    from sglang.srt.server_args import ServerArgs


@dataclass
class DecodeRequest:
    req: Req
    kv_receiver: KVReceiver
    waiting_for_input: bool = False
    metadata_buffer_index: int = -1


class DecodePreallocQueue:
    """Store the requests that are waiting for pre-allocation."""

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: List[torch.Tensor],
        aux_dtype: torch.dtype,
        scheduler: Scheduler,
        transfer_queue: DecodeTransferQueue,
        tree_cache: BasePrefixCache,
        gloo_group: ProcessGroup,
        tp_rank: int,
        tp_size: int,
        bootstrap_port: int,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
        self.aux_dtype = aux_dtype
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = (
            req_to_metadata_buffer_idx_allocator
        )
        self.scheduler = scheduler
        self.transfer_queue = transfer_queue
        self.tree_cache = tree_cache  # this is always a chunk cache
        self.gloo_group = gloo_group
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.bootstrap_port = bootstrap_port
        # KV slots reserved per request for future decode tokens when deciding
        # whether a new request can be pre-allocated.
        self.num_reserved_decode_tokens = 512
        # Queue for requests pending pre-allocation
        self.queue: List[DecodeRequest] = []
        self.kv_manager = self._init_kv_manager()
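
    # _init_kv_manager registers the addresses and sizes of the contiguous KV cache
    # buffers and of the metadata (aux) buffers with the transfer backend, so the
    # incoming KV data and the first output token can be written directly into this
    # decode server's memory. The "mock-ib-device" value below looks like a placeholder
    # for a real InfiniBand/RDMA device name.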
    def _init_kv_manager(self) -> KVManager:
        kv_args = KVArgs()
        kv_args.engine_rank = self.tp_rank
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )
        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens
        kv_args.aux_data_ptrs = [
            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
        ]
        kv_args.aux_data_lens = [
            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.aux_item_lens = [
            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.ib_device = "mock-ib-device"
        kv_manager = KVManager(kv_args)
        return kv_manager

    def add(self, req: Req) -> None:
        """Add a request to the pending queue."""
        kv_receiver = KVReceiver(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
        )
        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

    def extend(self, reqs: List[Req]) -> None:
        """Add multiple requests to the pending queue."""
        for req in reqs:
            self.add(req)

    def _update_handshake_waiters(self) -> None:
        if not self.queue:
            return
        if all(decode_req.waiting_for_input for decode_req in self.queue):
            return

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Bootstrapping:
                pass
            elif poll == KVPoll.WaitingForInput:
                decode_req.waiting_for_input = True
            elif poll == KVPoll.Failed:
                raise Exception("Handshake failed")
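
    # A request is admitted from the prealloc queue only when its handshake has finished
    # (waiting_for_input is True), a req_to_token_pool slot and a metadata buffer slot are
    # free, and enough KV tokens are allocatable for its prompt plus the reserved decode
    # headroom. Requests that have not finished the handshake are skipped; admission stops
    # early once any resource check fails.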
    def pop_preallocated(self) -> List[DecodeRequest]:
        """Pop the preallocated requests from the pending queue (FIFO)."""
        self._update_handshake_waiters()
        preallocated_reqs = []
        indices_to_remove = set()
        allocatable_tokens = self._allocatable_tokens()

        for i, decode_req in enumerate(self.queue):
            if not decode_req.waiting_for_input:
                continue
            if self.req_to_token_pool.available_size() <= 0:
                break
            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                break

            required_tokens_for_request = (
                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
            )
            if required_tokens_for_request > allocatable_tokens:
                break
            allocatable_tokens -= required_tokens_for_request

            self._pre_alloc(decode_req.req)
            kv_indices = (
                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
                    : len(decode_req.req.origin_input_ids)
                ]
                .cpu()
                .numpy()
            )
            decode_req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert decode_req.metadata_buffer_index is not None
            decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
            preallocated_reqs.append(decode_req)
            indices_to_remove.add(i)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]
        return preallocated_reqs
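
    # Illustrative accounting for _allocatable_tokens below (numbers are made up):
    # with 10_000 free KV tokens, 3 running requests, 2 requests in the transfer queue,
    # and 1 request in the scheduler's waiting queue, the budget is
    # 10_000 - 512 * (3 + 2 + 1) = 6_928 tokens, and a new request with a 1_000-token
    # prompt consumes 1_000 + 512 = 1_512 of it in pop_preallocated above.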
    def _allocatable_tokens(self) -> int:
        allocatable_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            - self.num_reserved_decode_tokens
            * (
                len(self.scheduler.running_batch.reqs)
                + len(self.transfer_queue.queue)
                + len(self.scheduler.waiting_queue)
            )
        )

        # Note: if the last fake extend batch has just finished and we enter
        # `pop_preallocated` immediately in the next iteration, that extend batch is not
        # in any queue yet, so we need to explicitly reserve its token slots here.
        if (
            self.scheduler.last_batch
            and self.scheduler.last_batch.forward_mode.is_extend()
        ):
            allocatable_tokens -= self.num_reserved_decode_tokens * len(
                self.scheduler.last_batch.reqs
            )
        return allocatable_tokens

    def _pre_alloc(self, req: Req) -> torch.Tensor:
        """Pre-allocate the memory for req_to_token and token_kv_pool"""
        req_pool_indices = self.req_to_token_pool.alloc(1)
        assert req_pool_indices is not None
        req.req_pool_idx = req_pool_indices[0]
        kv_loc = self.token_to_kv_pool_allocator.alloc(
            len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
        )
        assert kv_loc is not None
        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

        # populate metadata
        req.fill_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.origin_input_ids)
        return kv_loc


class DecodeTransferQueue:
    """Store the requests whose KV transfer is in flight (being polled)."""

    def __init__(
        self,
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: List[torch.Tensor],
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = (
            req_to_metadata_buffer_idx_allocator
        )
        self.metadata_buffers = metadata_buffers

    def add(self, req_conn: DecodeRequest) -> None:
        self.queue.append(req_conn)

    def extend(self, req_conns: List[DecodeRequest]) -> None:
        self.queue.extend(req_conns)
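
    # metadata_buffers[0] is the output-id buffer: one row per metadata slot, with the
    # last dimension padded with the same value, so pop_transferred below reads element
    # [idx][0] to recover the first output token generated for the request on the
    # prefill side.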
    def pop_transferred(self) -> List[Req]:
        if not self.queue:
            return []
        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Failed:
                raise Exception("Transfer failed")
            elif poll == KVPoll.Success:
                # pop it and push it to the waiting queue
                idx = decode_req.metadata_buffer_index
                output_id_buffer = self.metadata_buffers[0]
                # the last dimension is padded with the same value
                output_id = output_id_buffer[idx][0].item()
                assert len(decode_req.req.output_ids) == 0
                assert decode_req.req.transferred_output_id is None
                decode_req.req.transferred_output_id = output_id
                transferred_reqs.append(decode_req.req)
                indices_to_remove.add(i)
            elif poll in [
                KVPoll.Bootstrapping,
                KVPoll.WaitingForInput,
                KVPoll.Transferring,
            ]:
                pass
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

        for i in indices_to_remove:
            idx = self.queue[i].metadata_buffer_index
            assert idx != -1
            self.req_to_metadata_buffer_idx_allocator.free(idx)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]
        return transferred_reqs
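

# A "prebuilt extend" is an extend-shaped batch whose KV cache has already been filled by
# the KV transfer above, so prepare_for_prebuilt_extend only populates batch metadata and
# the actual prefill forward pass is skipped (see steps 3-4 in the module docstring).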
class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend batch by populating its metadata.
        Adapted from .prepare_for_extend().
        """
        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)

            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to the schedule batch."""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)


class SchedulerDisaggregationDecodeMixin:

    def get_next_disagg_decode_batch_to_run(
        self: Scheduler,
    ) -> Optional[ScheduleBatch]:
        """Create a fake completed prefill batch if possible and merge it with the running batch."""
        # Merge the prefill batch into the running batch
        last_batch = self.last_batch
        if last_batch and last_batch.forward_mode.is_extend():
            # chunked prefill doesn't happen in the decode instance
            assert self.chunked_req is None
            # Filter out finished requests
            last_batch.filter_batch()
            if not last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = last_batch
                else:
                    # merge the running batch with the prefill batch
                    self.running_batch.merge_batch(last_batch)

        new_prebuilt_batch = self.get_new_prebuilt_batch()

        ret: Optional[ScheduleBatch] = None
        if new_prebuilt_batch:
            ret = new_prebuilt_batch
        else:
            if self.running_batch.is_empty():
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
        return ret
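
    # Sizing note: the prebuilt batch can take at most
    # min(req_to_token_pool.size, max_running_requests) - current running batch size
    # requests from the waiting queue. Illustrative numbers only: with
    # max_running_requests = 128 and 100 requests already running, up to 28 waiting
    # requests are packed into the fake prefill batch.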
    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
        """Create a ScheduleBatch for the fake completed prefill."""
        if len(self.waiting_queue) == 0:
            return None

        curr_batch_size = self.running_batch.batch_size()
        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)
        num_not_used_batch = batch_size - curr_batch_size

        # pop requests from the waiting queue
        can_run_list: List[Req] = []
        waiting_queue: List[Req] = []
        for i in range(len(self.waiting_queue)):
            req = self.waiting_queue[i]
            # we can add at most `num_not_used_batch` new requests to the running batch
            if i < num_not_used_batch:
                can_run_list.append(req)
                req.init_next_round_input(self.tree_cache)
            else:
                waiting_queue.append(req)

        self.waiting_queue = waiting_queue
        if len(can_run_list) == 0:
            return None

        # local import to avoid circular import
        from sglang.srt.managers.schedule_batch import ScheduleBatch

        # construct a schedule batch with those requests and mark it as a decode batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )

        # construct the fake completed prefill
        new_batch.prepare_for_prebuilt_extend()
        new_batch.process_prebuilt_extend(self.server_args, self.model_config)
        return new_batch

    def process_decode_queue(self: Scheduler):
        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
        self.disagg_decode_transfer_queue.extend(req_conns)
        alloc_reqs = (
            self.disagg_decode_transfer_queue.pop_transferred()
        )  # the requests whose KV has arrived
        self.waiting_queue.extend(alloc_reqs)