from __future__ import annotations

import dataclasses
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import torch

import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ScheduleBatch


@dataclasses.dataclass
class SamplingBatchInfo:
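    """Batched sampling state for one ScheduleBatch.

    The per-request sampling parameters are stacked into batched tensors
    (temperatures, top_ps, top_ks, min_ps). The class also tracks the
    penalizer orchestrator, grammar vocab masks, and any custom logit
    processors so they can be applied to the batched logits.
    """
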
    # Basic batched sampling params
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor

    # Whether all requests use greedy sampling
    is_all_greedy: bool

    # Whether any request needs min_p sampling
    need_min_p_sampling: bool

    # Masking tensors for grammar-guided structured outputs
    vocab_size: int
    grammars: Optional[List] = None
    vocab_mask: Optional[torch.Tensor] = None
    apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None

    # An event used for overlap schedule
    sampling_info_done: Optional[threading.Event] = None

    # Penalizer
    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
    linear_penalty: Optional[torch.Tensor] = None

    # Whether any request has a custom logit processor
    has_custom_logit_processor: bool = False
    # Custom parameters
    custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
    # Custom logit processor
    custom_logit_processor: Optional[
        Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
    ] = None

    # Device
    device: str = "cuda"

    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
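        """Build a SamplingBatchInfo by stacking the sampling params of all
        requests in `batch` into tensors on `batch.device`."""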
        reqs = batch.reqs
        device = batch.device
        temperatures = (
            torch.tensor(
                [r.sampling_params.temperature for r in reqs],
                dtype=torch.float,
            )
            .view(-1, 1)
            .to(device, non_blocking=True)
        )
        top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)
        top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
        ).to(device, non_blocking=True)
        min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)

        # Check if any request has a custom logit processor
        has_custom_logit_processor = (
            batch.enable_custom_logit_processor  # check the flag first.
            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
        )

        if has_custom_logit_processor:
            # Merge the same type of custom logit processors together
            processor_dict = {}
            for i, r in enumerate(reqs):
                if r.custom_logit_processor is None:
                    continue
                processor_str = r.custom_logit_processor
                if processor_str not in processor_dict:
                    processor_dict[processor_str] = []
                processor_dict[processor_str].append(i)

            merged_custom_logit_processor = {
                hash(processor_str): (
                    # The deserialized custom logit processor object
                    CustomLogitProcessor.from_str(processor_str),
                    # The mask tensor for the requests that use this custom logit processor
                    torch.zeros(len(reqs), dtype=torch.bool)
                    .scatter_(0, torch.tensor(true_indices), True)
                    .to(device, non_blocking=True),
                )
                for processor_str, true_indices in processor_dict.items()
            }
            custom_params = [r.sampling_params.custom_params for r in reqs]
        else:
            merged_custom_logit_processor = None
            custom_params = None

        # Each penalizer does nothing if it determines that it is not required by
        # looking at the sampling_params of the requests (see {_is_required()} of each
        # penalizer), so this should not add hefty computation overhead beyond simple checks.
        #
        # While we could avoid creating the penalizer instances when they are not required,
        # that would add additional complexity to the {ScheduleBatch} class, especially
        # since we would also need to handle the {filter_batch()} and {merge_batch()} cases.
        penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=batch,
            penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
                penaltylib.BatchedPresencePenalizer,
            },
        )

        ret = cls(
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
            min_ps=min_ps,
            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
            vocab_size=vocab_size,
            penalizer_orchestrator=penalizer_orchestrator,
            has_custom_logit_processor=has_custom_logit_processor,
            custom_params=custom_params,
            custom_logit_processor=merged_custom_logit_processor,
            device=device,
        )
        return ret

    def __len__(self):
        return len(self.temperatures)

    def update_regex_vocab_mask(self):
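        """Rebuild the grammar vocab mask for the current batch, or clear it
        when no request uses grammar-guided decoding."""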
        if not self.grammars:
            self.vocab_mask = None
            self.apply_mask_func = None
            return

        # Find a grammar from the list
        first_grammar = next(grammar for grammar in self.grammars if grammar)

        # TODO(lianmin): Maybe we can reuse the existing mask?
        self.vocab_mask = first_grammar.allocate_vocab_mask(
            vocab_size=self.vocab_size,
            batch_size=len(self.temperatures),
            device=self.device,
        )
        self.apply_mask_func = (
            first_grammar.apply_vocab_mask
        )  # force to use static method

        # Apply the mask
        for i, grammar in enumerate(self.grammars):
            if grammar and not grammar.finished:
                grammar.fill_vocab_mask(self.vocab_mask, i)

        # Move the mask to the device if needed
        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

    def update_penalties(self):
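        """Precompute the batched additive penalties into `linear_penalty`
        (used in the overlap mode), or reset it to None if no penalizer is
        required."""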
        if self.penalizer_orchestrator.is_required:
            self.linear_penalty = torch.zeros(
                (len(self.temperatures), self.vocab_size),
                dtype=torch.float32,
                device=self.temperatures.device,
            )
            self.penalizer_orchestrator.apply(self.linear_penalty)
        else:
            self.linear_penalty = None

    def apply_logits_bias(self, logits: torch.Tensor):
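        """Apply penalties and the grammar vocab mask to `logits` in place."""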
        if self.linear_penalty is not None:
            # Used in the overlap mode
            logits.add_(self.linear_penalty)

        if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
            # Used in the non-overlap mode
            self.penalizer_orchestrator.apply(logits)

        if self.vocab_mask is not None:
            self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)

    def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
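        """Keep only the requests selected by `keep_indices`
        (`keep_indices_device` holds the same indices as a tensor on the
        target device)."""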
        self.penalizer_orchestrator.filter(keep_indices_device)

        if self.has_custom_logit_processor:
            self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            value = getattr(self, item, None)
            setattr(self, item, value[keep_indices_device])

    def _filter_batch_custom_logit_processor(
        self, keep_indices: List[int], keep_indices_device: torch.Tensor
    ):
        """Filter the custom logit processor and custom params"""
        self.custom_logit_processor = {
            k: (p, mask[keep_indices_device])
            for k, (p, mask) in self.custom_logit_processor.items()
            if torch.any(
                mask[keep_indices_device]
            )  # ignore the custom logit processor whose mask is all False
        }
        self.custom_params = [self.custom_params[i] for i in keep_indices]

        # If the custom logit processor is an empty dict, set the flag to False,
        # and set the custom logit processor and custom params to None.
        if len(self.custom_logit_processor) == 0:
            self.custom_logit_processor = None
            self.custom_params = None
            self.has_custom_logit_processor = False

    @staticmethod
    def merge_custom_logit_processor(
        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
        bs1: int,
        bs2: int,
        device: str,
    ):
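        """Merge two custom-logit-processor dicts whose masks cover batches of
        size `bs1` and `bs2` into one dict whose masks cover the concatenated
        batch of size `bs1 + bs2`."""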
        if lhs is None and rhs is None:
            return None
        lhs, rhs = lhs or {}, rhs or {}

        keys = set(lhs.keys()).union(set(rhs.keys()))
        merged_dict = {}

        for k in keys:
            # Get the logit processor object
            processor = lhs[k][0] if k in lhs else rhs[k][0]
            # Get and merge the mask tensors from the two dicts
            left_mask = (
                lhs[k][1]
                if k in lhs
                else torch.zeros(bs1, dtype=torch.bool, device=device)
            )
            right_mask = (
                rhs[k][1]
                if k in rhs
                else torch.zeros(bs2, dtype=torch.bool, device=device)
            )
            merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))

            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
                f"\n{lhs=}\n{rhs=}"
            )

        return merged_dict

    def merge_batch(self, other: "SamplingBatchInfo"):
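        """Merge `other` into this batch by concatenating the batched sampling
        parameters and merging the penalizer and custom-logit-processor
        state."""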
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

        # Merge the custom logit processors and custom params lists
        if self.has_custom_logit_processor or other.has_custom_logit_processor:
            # Merge the custom logit processors
            self.custom_logit_processor = (
                SamplingBatchInfo.merge_custom_logit_processor(
                    self.custom_logit_processor,
                    other.custom_logit_processor,
                    len(self),
                    len(other),
                    self.device,
                )
            )
            # Merge the custom params lists
            self.custom_params = self.custom_params or [None] * len(self)
            other.custom_params = other.custom_params or [None] * len(other)
            self.custom_params.extend(other.custom_params)

            # Set the flag to True if either of the two has a custom logit processor
            self.has_custom_logit_processor = True

        # Note: because __len__() is defined via the temperatures tensor, make sure
        # any merge operation that uses len(self) or len(other) is done before the
        # temperatures tensors are concatenated below.
        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.cat([self_val, other_val]))

        self.is_all_greedy |= other.is_all_greedy
        self.need_min_p_sampling |= other.need_min_p_sampling