"""
|
|
Life cycle of a request in the decode server
|
|
|
|
1. PreallocQueue:
|
|
a. Initialize a receiver for each request
|
|
b. The request handshakes first, and pre-allocate kv once there is available kv.
|
|
c. Move the request to TransferQueue.
|
|
|
|
2. TransferQueue:
|
|
a. Poll the receiver to check the transfer state
|
|
b. If the transfer has finished, move the request to waiting queue
|
|
|
|
3. WaitingQueue:
|
|
a. Use the requests in the queue to construct a PrebuiltExtendBatch
|
|
b. Skip the prefill forward but only populate metadata
|
|
|
|
4. RunningBatch:
|
|
a. Merge the resolved PrebuiltExtendBatch into running batch to run decoding
|
|
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
from torch.distributed import ProcessGroup

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.utils import (
    ReqToMetadataIdxAllocator,
    poll_and_all_reduce,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
    from sglang.srt.managers.scheduler import Scheduler
    from sglang.srt.server_args import ServerArgs


@dataclass
class DecodeRequest:
    req: Req
    kv_receiver: KVReceiver
    waiting_for_input: bool = False
    metadata_buffer_index: int = -1


class DecodePreallocQueue:
    """
    Store the requests that are preallocating.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: List[torch.Tensor],
        aux_dtype: torch.dtype,
        scheduler: Scheduler,
        transfer_queue: DecodeTransferQueue,
        tree_cache: BasePrefixCache,
        gloo_group: ProcessGroup,
        tp_rank: int,
        tp_size: int,
        bootstrap_port: int,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
        self.aux_dtype = aux_dtype
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.scheduler = scheduler
        self.transfer_queue = transfer_queue
        self.tree_cache = tree_cache  # this is always a chunk cache
        self.gloo_group = gloo_group
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.bootstrap_port = bootstrap_port

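        # Extra KV token slots reserved per request for its future decode steps;
        # used by `pop_preallocated` / `_allocatable_tokens` to decide whether a
        # new request fits. The value 512 is this module's fixed headroom choice.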
        self.num_reserved_decode_tokens = 512

        # Queue for requests pending pre-allocation
        self.queue: List[DecodeRequest] = []
        self.kv_manager = self._init_kv_manager()

    def _init_kv_manager(self) -> KVManager:
        kv_args = KVArgs()
        kv_args.engine_rank = self.tp_rank
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )

        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens

        kv_args.aux_data_ptrs = [
            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
        ]
        kv_args.aux_data_lens = [
            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.aux_item_lens = [
            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.ib_device = "mock-ib-device"
        kv_manager = KVManager(kv_args)
        return kv_manager

    def add(self, req: Req) -> None:
        """Add a request to the pending queue."""
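        # The receiver is keyed by the bootstrap address (host taken from the
        # request, port from this queue) and the bootstrap room; it performs the
        # handshake described in the module docstring.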
        kv_receiver = KVReceiver(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
        )
        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

    def extend(self, reqs: List[Req]) -> None:
        """Add multiple requests to the pending queue."""
        for req in reqs:
            self.add(req)

    def _update_handshake_waiters(self) -> None:
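        """Poll the receivers and mark requests whose handshake has completed
        (poll == KVPoll.WaitingForInput) as ready for pre-allocation."""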
        if not self.queue:
            return

        if all(decode_req.waiting_for_input for decode_req in self.queue):
            return

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Bootstrapping:
                pass
            elif poll == KVPoll.WaitingForInput:
                decode_req.waiting_for_input = True
            elif poll == KVPoll.Failed:
                raise Exception("Handshake failed")

    def pop_preallocated(self) -> List[DecodeRequest]:
        """Pop the preallocated requests from the pending queue (FIFO)."""
        self._update_handshake_waiters()

        preallocated_reqs = []
        indices_to_remove = set()
        allocatable_tokens = self._allocatable_tokens()

        for i, decode_req in enumerate(self.queue):
            if not decode_req.waiting_for_input:
                continue

            if self.req_to_token_pool.available_size() <= 0:
                break

            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                break

            required_tokens_for_request = (
                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
            )

            if required_tokens_for_request > allocatable_tokens:
                break

            allocatable_tokens -= required_tokens_for_request
            self._pre_alloc(decode_req.req)

            kv_indices = (
                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
                    : len(decode_req.req.origin_input_ids)
                ]
                .cpu()
                .numpy()
            )

            decode_req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert decode_req.metadata_buffer_index is not None
            decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
            preallocated_reqs.append(decode_req)
            indices_to_remove.add(i)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return preallocated_reqs

    def _allocatable_tokens(self) -> int:
        allocatable_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            - self.num_reserved_decode_tokens
            * (
                len(self.scheduler.running_batch.reqs)
                + len(self.transfer_queue.queue)
                + len(self.scheduler.waiting_queue)
            )
        )

        # Note: if the last fake extend batch has just finished and we enter
        # `pop_preallocated` immediately in the next iteration, that batch is not
        # in any of the queues above, so we explicitly reserve its token slots here.
        if (
            self.scheduler.last_batch
            and self.scheduler.last_batch.forward_mode.is_extend()
        ):
            allocatable_tokens -= self.num_reserved_decode_tokens * len(
                self.scheduler.last_batch.reqs
            )

        return allocatable_tokens

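    # Illustrative accounting with made-up numbers: with 10 running requests, 2
    # requests in the transfer queue, 3 in the waiting queue, and
    # num_reserved_decode_tokens = 512, a KV pool reporting 20_000 free slots
    # yields 20_000 - 512 * (10 + 2 + 3) = 12_320 allocatable tokens; a new
    # request with a 1_000-token prompt then needs 1_000 + 512 = 1_512 of them
    # (see `pop_preallocated`).
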
    def _pre_alloc(self, req: Req) -> torch.Tensor:
        """Pre-allocate the memory for req_to_token and token_kv_pool."""
        req_pool_indices = self.req_to_token_pool.alloc(1)

        assert req_pool_indices is not None

        req.req_pool_idx = req_pool_indices[0]
        kv_loc = self.token_to_kv_pool_allocator.alloc(
            len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
        )

        assert kv_loc is not None

        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

        # populate metadata
        req.fill_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.origin_input_ids)

        return kv_loc


class DecodeTransferQueue:
    """
    Store the requests that are polling KV.
    """

    def __init__(
        self,
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: torch.Tensor,
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.metadata_buffers = metadata_buffers

    def add(self, req_conn: DecodeRequest) -> None:
        self.queue.append(req_conn)

    def extend(self, req_conns) -> None:
        self.queue.extend(req_conns)

    def pop_transferred(self) -> List[Req]:
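        """Pop requests whose KV transfer has completed, attaching the first
        output token id read from the metadata buffer to each request."""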
        if not self.queue:
            return []

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Failed:
                raise Exception("Transfer failed")
            elif poll == KVPoll.Success:
                # pop it from this queue and push it to the waiting queue
                idx = decode_req.metadata_buffer_index
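                # metadata_buffers[0] holds the first output token id that the
                # prefill side wrote for this metadata slot.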
                output_id_buffer = self.metadata_buffers[0]
                # the last dimension is padded with the same value.
                output_id = output_id_buffer[idx][0].item()
                assert len(decode_req.req.output_ids) == 0
                assert decode_req.req.transferred_output_id is None
                decode_req.req.transferred_output_id = output_id
                transferred_reqs.append(decode_req.req)
                indices_to_remove.add(i)
            elif poll in [
                KVPoll.Bootstrapping,
                KVPoll.WaitingForInput,
                KVPoll.Transferring,
            ]:
                pass
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

        for i in indices_to_remove:
            idx = self.queue[i].metadata_buffer_index
            assert idx != -1
            self.req_to_metadata_buffer_idx_allocator.free(idx)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return transferred_reqs


class ScheduleBatchDisaggregationDecodeMixin:
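    """Extra ScheduleBatch methods used by the disaggregated decode server."""
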
    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend batch by populating metadata.
        Adapted from .prepare_for_extend().
        """
        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)

            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last output ids (transferred from prefill or kept
        from a retracted request) to the schedule batch."""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)


class SchedulerDisaggregationDecodeMixin:
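    """Extra Scheduler methods used by the disaggregated decode server."""
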
    def get_next_disagg_decode_batch_to_run(
        self: Scheduler,
    ) -> Optional[Tuple[ScheduleBatch, bool]]:
        """Create a fake completed prefill batch if possible and merge it with the running batch."""
        # Merge the prefill batch into the running batch
        last_batch = self.last_batch
        if last_batch and last_batch.forward_mode.is_extend():
            # chunked prefill doesn't happen on the decode instance.
            assert self.chunked_req is None
            # Filter out finished requests.
            last_batch.filter_batch()
            if not last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = last_batch
                else:
                    # merge running_batch with the prefill batch
                    self.running_batch.merge_batch(last_batch)

        new_prebuilt_batch = self.get_new_prebuilt_batch()

        ret: Optional[ScheduleBatch] = None
        if new_prebuilt_batch:
            ret = new_prebuilt_batch
        else:
            if self.running_batch.is_empty():
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None

        return ret

    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
        """Create a ScheduleBatch for a fake completed prefill."""
        if len(self.waiting_queue) == 0:
            return None

        curr_batch_size = self.running_batch.batch_size()

        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)

        num_not_used_batch = batch_size - curr_batch_size

        # pop reqs from the waiting queue
        can_run_list: List[Req] = []
        waiting_queue: List[Req] = []

        for i in range(len(self.waiting_queue)):
            req = self.waiting_queue[i]
            # we can add at most `num_not_used_batch` new requests to the running batch
            if i < num_not_used_batch:
                can_run_list.append(req)
                req.init_next_round_input(self.tree_cache)
            else:
                waiting_queue.append(req)

        self.waiting_queue = waiting_queue
        if len(can_run_list) == 0:
            return None
        # local import to avoid a circular import
        from sglang.srt.managers.schedule_batch import ScheduleBatch

        # construct a schedule batch with those requests and mark it as decode
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )

        # construct the fake completed prefill
        new_batch.prepare_for_prebuilt_extend()
        new_batch.process_prebuilt_extend(self.server_args, self.model_config)

        return new_batch

    def process_decode_queue(self: Scheduler):
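        """Move requests through the prealloc and transfer queues, then append
        requests whose KV cache has fully arrived to the waiting queue."""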
        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
        self.disagg_decode_transfer_queue.extend(req_conns)
        alloc_reqs = (
            self.disagg_decode_transfer_queue.pop_transferred()
        )  # the requests whose KV cache has arrived
        self.waiting_queue.extend(alloc_reqs)