import torch

from sglang.srt.sampling.penaltylib.orchestrator import (
    BatchedPenalizerOrchestrator,
    _BatchedPenalizer,
)


class BatchedFrequencyPenalizer(_BatchedPenalizer):
    """
    Frequency penalizer penalizes tokens based on their frequency in the output.
    """

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator
        self._is_prepared = False

    def _is_required(self) -> bool:
        # Only needed if at least one request in the batch uses a non-zero penalty.
        return any(
            req.sampling_params.frequency_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        # (batch_size, vocab_size) accumulator of penalties applied so far.
        self.cumulated_frequency_penalties = torch.zeros(
            (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )
        # (batch_size, 1) column of per-request penalty values, shaped so it can
        # be used directly as the `src` of scatter_add_ below.
        self.frequency_penalties = torch.tensor(
            data=[
                req.sampling_params.frequency_penalty
                for req in self.orchestrator.reqs()
            ],
            dtype=torch.float32,
            device=self.orchestrator.device,
        ).unsqueeze_(1)

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        # For each request, add its frequency_penalty to the accumulator entry
        # of the token it just emitted, so the penalty grows with token frequency.
        self.cumulated_frequency_penalties.scatter_add_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.frequency_penalties,
        )

    def _apply(self, logits: torch.Tensor):
        # Penalize in place: subtract the accumulated per-token penalties.
        logits.sub_(self.cumulated_frequency_penalties)

    def _filter(self, keep_indices: torch.Tensor):
        # Drop state for requests that have left the batch.
        self.frequency_penalties = self.frequency_penalties[keep_indices]
        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
            keep_indices
        ]

    def _merge(self, their: "BatchedFrequencyPenalizer"):
        # Concatenate the two batches' state along the batch dimension.
        self.frequency_penalties = torch.cat(
            [self.frequency_penalties, their.frequency_penalties], dim=0
        )
        self.cumulated_frequency_penalties = torch.cat(
            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
            dim=0,
        )
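

# --- Minimal usage sketch (illustrative only, not part of sglang) ---
# The demo below exercises the penalizer's lifecycle (prepare -> cumulate ->
# apply) against a mocked orchestrator. The `SimpleNamespace` objects are
# hypothetical stand-ins for the real `BatchedPenalizerOrchestrator` and
# request/sampling-param objects; only the attributes this penalizer actually
# touches (`reqs()`, `vocab_size`, `device`,
# `sampling_params.frequency_penalty`) are mocked, so this shows the expected
# call pattern rather than the real integration.
if __name__ == "__main__":
    from types import SimpleNamespace

    vocab_size = 8
    reqs = [
        # One request with a penalty, one without.
        SimpleNamespace(sampling_params=SimpleNamespace(frequency_penalty=0.5)),
        SimpleNamespace(sampling_params=SimpleNamespace(frequency_penalty=0.0)),
    ]
    fake_orchestrator = SimpleNamespace(
        reqs=lambda: reqs, vocab_size=vocab_size, device="cpu"
    )

    penalizer = BatchedFrequencyPenalizer(fake_orchestrator)
    assert penalizer._is_required()  # request 0 has a non-zero penalty

    penalizer._prepare()
    # Both requests emit token 3 twice; each emission adds that request's
    # frequency_penalty to the accumulator entry for token 3.
    penalizer._cumulate_output_tokens(torch.tensor([3, 3]))
    penalizer._cumulate_output_tokens(torch.tensor([3, 3]))

    logits = torch.zeros((2, vocab_size))
    penalizer._apply(logits)
    # Row 0 now has -1.0 (= 2 * 0.5) at index 3; row 1 is all zeros.
    print(logits)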