66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
import torch
|
|
|
|
from sglang.srt.sampling.penaltylib.orchestrator import (
|
|
BatchedPenalizerOrchestrator,
|
|
_BatchedPenalizer,
|
|
)
|
|
|
|
|
|
class BatchedFrequencyPenalizer(_BatchedPenalizer):
    """
    Frequency penalizer penalizes tokens based on their frequency in the output.

    Each time a request emits a token, that request's ``frequency_penalty`` is
    added to a per-request, per-vocabulary accumulator. The accumulator is
    subtracted from the logits, so a token's logit is lowered by
    ``frequency_penalty * count``, where ``count`` is the number of times the
    request has already produced that token.
    """

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator
        self._is_prepared = False

    def _is_required(self) -> bool:
        """Return True if any request in the batch has a non-zero frequency penalty."""
        return any(
            req.sampling_params.frequency_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        """Allocate the per-request penalty state.

        Creates two tensors on ``orchestrator.device``:
          * ``cumulated_frequency_penalties``: float32 zeros of shape
            ``(num_reqs, vocab_size)``. It holds the penalty accumulated so far
            for every (request, token) pair.
          * ``frequency_penalties``: float32 column of shape ``(num_reqs, 1)``.
            It holds each request's ``frequency_penalty``, in ``reqs()`` order.
        """
        reqs = self.orchestrator.reqs()
        device = self.orchestrator.device

        self.cumulated_frequency_penalties = torch.zeros(
            (len(reqs), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=device,
        )

        # Stored as a (num_reqs, 1) column so it can serve directly as the
        # `src` of the row-wise scatter_add_ in _cumulate_output_tokens.
        self.frequency_penalties = torch.tensor(
            data=[req.sampling_params.frequency_penalty for req in reqs],
            dtype=torch.float32,
            device=device,
        ).unsqueeze_(1)

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        """Add each request's penalty at the token id it just produced.

        Row ``i`` of ``cumulated_frequency_penalties`` receives
        ``frequency_penalties[i]`` at column ``output_ids[i]``. Repeated tokens
        therefore keep accumulating.

        Args:
            output_ids: presumably a 1-D integer tensor with one new token id
                per request, in batch order (it is unsqueezed to a column to
                match ``frequency_penalties``). TODO confirm against caller.
        """
        self.cumulated_frequency_penalties.scatter_add_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.frequency_penalties,
        )

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Subtract the accumulated frequency penalties from ``logits`` in place.

        Args:
            logits: presumably ``(num_reqs, vocab_size)``, or broadcast-compatible
                with ``cumulated_frequency_penalties``. TODO confirm.

        Returns:
            The same ``logits`` tensor, already modified in place. It is
            returned to honor the declared return type; callers that ignore
            the result are unaffected.
        """
        logits.sub_(self.cumulated_frequency_penalties)
        return logits

    def _filter(self, keep_indices: torch.Tensor):
        """Keep only the batch rows selected by ``keep_indices`` in both state tensors."""
        self.frequency_penalties = self.frequency_penalties[keep_indices]
        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
            keep_indices
        ]

    def _merge(self, their: "BatchedFrequencyPenalizer"):
        """Append ``their`` per-request state after this penalizer's rows (dim 0)."""
        self.frequency_penalties = torch.cat(
            [self.frequency_penalties, their.frequency_penalties], dim=0
        )
        self.cumulated_frequency_penalties = torch.cat(
            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
            dim=0,
        )
|