# Source: sglang 0.4.5.post1, python/sglang/srt/managers/scheduler_output_processor_... (615 lines, 26 KiB, Python)

from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import (
EmbeddingBatchResult,
GenerationBatchResult,
ScheduleBatch,
)
class SchedulerOutputProcessorMixin:
"""
This class implements the output processing logic for Scheduler.
We put them into a separate file to make the `scheduler.py` shorter.
"""
def process_batch_result_prefill(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
skip_stream_req = None
if self.is_generation:
(
logits_output,
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
bid,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.bid,
)
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
else:
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
hidden_state_offset = 0
# Check finish conditions
logprob_pt = 0
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_retracted:
continue
if self.is_mixed_chunk and self.enable_overlap and req.finished():
# Free the one delayed token for the mixed decode batch
j = len(batch.out_cache_loc) - len(batch.reqs) + i
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
continue
if req.is_chunked <= 0:
# req output_ids are set here
req.output_ids.append(next_token_id)
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
# This updates radix so others can match
self.tree_cache.cache_unfinished_req(req)
if req.return_logprob:
assert extend_logprob_start_len_per_req is not None
assert extend_input_len_per_req is not None
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_logprob_return_values(
i,
req,
logprob_pt,
next_token_ids,
num_input_logprobs,
logits_output,
)
logprob_pt += num_input_logprobs
if (
req.return_hidden_states
and logits_output.hidden_states is not None
):
req.hidden_states.append(
logits_output.hidden_states[
hidden_state_offset : (
hidden_state_offset := hidden_state_offset
+ len(req.origin_input_ids)
)
]
.cpu()
.clone()
.tolist()
)
if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
else:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
# There is only at most one request being currently chunked.
# Because this request does not finish prefill,
# we don't want to stream the request currently being chunked.
skip_stream_req = req
# Incrementally update input logprobs.
if req.return_logprob:
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
if extend_logprob_start_len < extend_input_len:
# Update input logprobs.
num_input_logprobs = (
extend_input_len - extend_logprob_start_len
)
self.add_input_logprob_return_values(
i,
req,
logits_output,
logprob_pt,
num_input_logprobs,
last_prefill_chunk=False,
)
logprob_pt += num_input_logprobs
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
embeddings, bid = result.embeddings, result.bid
embeddings = embeddings.tolist()
# Check finish conditions
for i, req in enumerate(batch.reqs):
if req.is_retracted:
continue
req.embedding = embeddings[i]
if req.is_chunked <= 0:
# Dummy output token for embedding models
req.output_ids.append(0)
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
else:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def process_batch_result_decode(
self,
batch: ScheduleBatch,
result: GenerationBatchResult,
):
logits_output, next_token_ids, bid = (
result.logits_output,
result.next_token_ids,
result.bid,
)
self.num_generated_tokens += len(batch.reqs)
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
next_token_logprobs = logits_output.next_token_logprobs
elif batch.spec_algorithm.is_none():
# spec decoding handles output logprobs inside verify process.
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
self.token_to_kv_pool_allocator.free_group_begin()
# Check finish condition
# NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
# We should ignore using next_token_ids for spec decoding cases.
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_retracted:
continue
if self.enable_overlap and req.finished():
# Free the one extra delayed token
if self.page_size == 1:
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
else:
# Only free when the extra token is in a new page
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0:
self.token_to_kv_pool_allocator.free(
batch.out_cache_loc[i : i + 1]
)
continue
if batch.spec_algorithm.is_none():
# speculative worker will solve the output_ids in speculative decoding
req.output_ids.append(next_token_id)
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
if req.return_logprob and batch.spec_algorithm.is_none():
# speculative worker handles logprob in speculative decoding
req.output_token_logprobs_val.append(next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_id)
if req.top_logprobs_num > 0:
req.output_top_logprobs_val.append(
logits_output.next_token_top_logprobs_val[i]
)
req.output_top_logprobs_idx.append(
logits_output.next_token_top_logprobs_idx[i]
)
if req.token_ids_logprob is not None:
req.output_token_ids_logprobs_val.append(
logits_output.next_token_token_ids_logprobs_val[i]
)
req.output_token_ids_logprobs_idx.append(
logits_output.next_token_token_ids_logprobs_idx[i]
)
if req.return_hidden_states and logits_output.hidden_states is not None:
req.hidden_states.append(
logits_output.hidden_states[i].cpu().clone().tolist()
)
if req.grammar is not None and batch.spec_algorithm.is_none():
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool_allocator.free_group_end()
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if (
self.attn_tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.log_decode_stats()
def add_input_logprob_return_values(
self,
i: int,
req: Req,
output: LogitsProcessorOutput,
logprob_pt: int,
num_input_logprobs: int,
last_prefill_chunk: bool, # If True, it means prefill is finished.
):
"""Incrementally add input logprobs to `req`.
Args:
i: The request index in a batch.
req: The request. Input logprobs inside req are modified as a
consequence of the API
fill_ids: The prefill ids processed.
output: Logit processor output that's used to compute input logprobs
last_prefill_chunk: True if it is the last prefill (when chunked).
Some of input logprob operation should only happen at the last
prefill (e.g., computing input token logprobs).
"""
assert output.input_token_logprobs is not None
if req.input_token_logprobs is None:
req.input_token_logprobs = []
if req.temp_input_top_logprobs_val is None:
req.temp_input_top_logprobs_val = []
if req.temp_input_top_logprobs_idx is None:
req.temp_input_top_logprobs_idx = []
if req.temp_input_token_ids_logprobs_val is None:
req.temp_input_token_ids_logprobs_val = []
if req.temp_input_token_ids_logprobs_idx is None:
req.temp_input_token_ids_logprobs_idx = []
if req.input_token_logprobs_val is not None:
# The input logprob has been already computed. It only happens
# upon retract.
if req.top_logprobs_num > 0:
assert req.input_token_logprobs_val is not None
return
# Important for the performance.
assert isinstance(output.input_token_logprobs, tuple)
input_token_logprobs: Tuple[int] = output.input_token_logprobs
input_token_logprobs = input_token_logprobs[
logprob_pt : logprob_pt + num_input_logprobs
]
req.input_token_logprobs.extend(input_token_logprobs)
if req.top_logprobs_num > 0:
req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
if req.token_ids_logprob is not None:
req.temp_input_token_ids_logprobs_val.append(
output.input_token_ids_logprobs_val[i]
)
req.temp_input_token_ids_logprobs_idx.append(
output.input_token_ids_logprobs_idx[i]
)
if last_prefill_chunk:
input_token_logprobs = req.input_token_logprobs
req.input_token_logprobs = None
assert req.input_token_logprobs_val is None
assert req.input_token_logprobs_idx is None
assert req.input_top_logprobs_val is None
assert req.input_top_logprobs_idx is None
# Compute input_token_logprobs_val
# Always pad the first one with None.
req.input_token_logprobs_val = [None]
req.input_token_logprobs_val.extend(input_token_logprobs)
# The last input logprob is for sampling, so just pop it out.
req.input_token_logprobs_val.pop()
# Compute input_token_logprobs_idx
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
# Clip the padded hash values from image tokens.
# Otherwise, it will lead to detokenization errors.
input_token_logprobs_idx = [
x if x < self.model_config.vocab_size - 1 else 0
for x in input_token_logprobs_idx
]
req.input_token_logprobs_idx = input_token_logprobs_idx
if req.top_logprobs_num > 0:
req.input_top_logprobs_val = [None]
req.input_top_logprobs_idx = [None]
assert len(req.temp_input_token_ids_logprobs_val) == len(
req.temp_input_token_ids_logprobs_idx
)
for val, idx in zip(
req.temp_input_top_logprobs_val,
req.temp_input_top_logprobs_idx,
strict=True,
):
req.input_top_logprobs_val.extend(val)
req.input_top_logprobs_idx.extend(idx)
# Last token is a sample token.
req.input_top_logprobs_val.pop()
req.input_top_logprobs_idx.pop()
req.temp_input_top_logprobs_idx = None
req.temp_input_top_logprobs_val = None
if req.token_ids_logprob is not None:
req.input_token_ids_logprobs_val = [None]
req.input_token_ids_logprobs_idx = [None]
for val, idx in zip(
req.temp_input_token_ids_logprobs_val,
req.temp_input_token_ids_logprobs_idx,
strict=True,
):
req.input_token_ids_logprobs_val.extend(val)
req.input_token_ids_logprobs_idx.extend(idx)
# Last token is a sample token.
req.input_token_ids_logprobs_val.pop()
req.input_token_ids_logprobs_idx.pop()
req.temp_input_token_ids_logprobs_idx = None
req.temp_input_token_ids_logprobs_val = None
if req.return_logprob:
relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
assert len(req.input_token_logprobs_val) == relevant_tokens_len
assert len(req.input_token_logprobs_idx) == relevant_tokens_len
if req.top_logprobs_num > 0:
assert len(req.input_top_logprobs_val) == relevant_tokens_len
assert len(req.input_top_logprobs_idx) == relevant_tokens_len
if req.token_ids_logprob is not None:
assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
def add_logprob_return_values(
self,
i: int,
req: Req,
pt: int,
next_token_ids: List[int],
num_input_logprobs: int,
output: LogitsProcessorOutput,
):
"""Attach logprobs to the return values."""
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_ids[i])
self.add_input_logprob_return_values(
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
)
if req.top_logprobs_num > 0:
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
if req.token_ids_logprob is not None:
req.output_token_ids_logprobs_val.append(
output.next_token_token_ids_logprobs_val[i]
)
req.output_token_ids_logprobs_idx.append(
output.next_token_token_ids_logprobs_idx[i]
)
return num_input_logprobs
def stream_output(
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
):
"""Stream the output to detokenizer."""
if self.is_generation:
self.stream_output_generation(reqs, return_logprob, skip_req)
else: # embedding or reward model
self.stream_output_embedding(reqs)
def stream_output_generation(
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
):
rids = []
finished_reasons: List[BaseFinishReason] = []
decoded_texts = []
decode_ids_list = []
read_offsets = []
output_ids = []
skip_special_tokens = []
spaces_between_special_tokens = []
no_stop_trim = []
prompt_tokens = []
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
output_hidden_states = None
if return_logprob:
input_token_logprobs_val = []
input_token_logprobs_idx = []
output_token_logprobs_val = []
output_token_logprobs_idx = []
input_top_logprobs_val = []
input_top_logprobs_idx = []
output_top_logprobs_val = []
output_top_logprobs_idx = []
input_token_ids_logprobs_val = []
input_token_ids_logprobs_idx = []
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
else:
input_token_logprobs_val = input_token_logprobs_idx = (
output_token_logprobs_val
) = output_token_logprobs_idx = input_top_logprobs_val = (
input_top_logprobs_idx
) = output_top_logprobs_val = output_top_logprobs_idx = (
input_token_ids_logprobs_val
) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
output_token_ids_logprobs_idx
) = None
for req in reqs:
if req is skip_req:
continue
# Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
if self.model_config.is_multimodal_gen and req.to_abort:
continue
if (
req.finished()
# If stream, follow the given stream_interval
or (req.stream and len(req.output_ids) % self.stream_interval == 0)
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
# TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
# always increase one-by-one.
or (
not req.stream
and len(req.output_ids) % 50 == 0
and not self.model_config.is_multimodal_gen
)
):
rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
)
decoded_texts.append(req.decoded_text)
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
read_offsets.append(read_offset)
if self.skip_tokenizer_init:
output_ids.append(req.output_ids)
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
no_stop_trim.append(req.sampling_params.no_stop_trim)
prompt_tokens.append(len(req.origin_input_ids))
completion_tokens.append(len(req.output_ids))
cached_tokens.append(req.cached_tokens)
if not self.spec_algorithm.is_none():
spec_verify_ct.append(req.spec_verify_ct)
if return_logprob:
input_token_logprobs_val.append(req.input_token_logprobs_val)
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
output_token_logprobs_val.append(req.output_token_logprobs_val)
output_token_logprobs_idx.append(req.output_token_logprobs_idx)
input_top_logprobs_val.append(req.input_top_logprobs_val)
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
input_token_ids_logprobs_val.append(
req.input_token_ids_logprobs_val
)
input_token_ids_logprobs_idx.append(
req.input_token_ids_logprobs_idx
)
output_token_ids_logprobs_val.append(
req.output_token_ids_logprobs_val
)
output_token_ids_logprobs_idx.append(
req.output_token_ids_logprobs_idx
)
if req.return_hidden_states:
if output_hidden_states is None:
output_hidden_states = []
output_hidden_states.append(req.hidden_states)
# Send to detokenizer
if rids:
if self.model_config.is_multimodal_gen:
return
self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut(
rids,
finished_reasons,
decoded_texts,
decode_ids_list,
read_offsets,
output_ids,
skip_special_tokens,
spaces_between_special_tokens,
no_stop_trim,
prompt_tokens,
completion_tokens,
cached_tokens,
spec_verify_ct,
input_token_logprobs_val,
input_token_logprobs_idx,
output_token_logprobs_val,
output_token_logprobs_idx,
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
output_token_ids_logprobs_val,
output_token_ids_logprobs_idx,
output_hidden_states,
)
)
def stream_output_embedding(self, reqs: List[Req]):
rids = []
finished_reasons: List[BaseFinishReason] = []
embeddings = []
prompt_tokens = []
cached_tokens = []
for req in reqs:
if req.finished():
rids.append(req.rid)
finished_reasons.append(req.finished_reason.to_json())
embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids))
cached_tokens.append(req.cached_tokens)
self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(
rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
)
)