""" Copyright (c) 2025 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from typing import Any, Optional, Tuple, Union import torch from flashinfer.sampling import get_sampling_module from flashinfer.utils import _get_cache_buf, device_support_pdl from .op import ParameterizedOp from .types import TaggedTensor, TensorType def _to_tensor_scalar_tuple( x: Union[torch.Tensor, float, int], ) -> Tuple[Optional[torch.Tensor], Union[float, int]]: if isinstance(x, torch.Tensor): return (x, 0 if x.dtype == torch.int32 else 0.0) else: return (None, x) class TemperatureOp(ParameterizedOp): """ Temperature scaling operator. :attr:`TensorType.LOGITS` -> :attr:`TensorType.LOGITS` Parameters ---------- temperature : float or torch.Tensor Temperature value for scaling. """ IN = TensorType.LOGITS OUT = TensorType.LOGITS def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) temperature = self._get_param("temperature", kwargs, required=True) maybe_temperature_arr, temperature_val = _to_tensor_scalar_tuple(temperature) if maybe_temperature_arr is None and ( not isinstance(temperature_val, float) or temperature_val <= 0 ): raise ValueError("Temperature must be positive float or a tensor array") if maybe_temperature_arr is not None: temperature = maybe_temperature_arr else: temperature = temperature_val scaled_logits = tensor.data / temperature return TaggedTensor(scaled_logits, output_type) class SoftmaxOp(ParameterizedOp): """ Softmax operator. Converts logits to probabilities using softmax function. :attr:`TensorType.LOGITS` -> :attr:`TensorType.PROBS` Parameters ---------- enable_pdl: bool, optional Whether to enable PDL for the fused kernel. """ IN = TensorType.LOGITS OUT = TensorType.PROBS def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) enable_pdl = self.default_params.get("enable_pdl", None) if enable_pdl is None: enable_pdl = device_support_pdl(tensor.data.device) probs = torch.softmax(tensor.data, dim=-1) return TaggedTensor(probs, output_type) class ProbsTopKOp(ParameterizedOp): """ Top-k filtering operator for probabilities. Keeps top-k probabilities, zeros out others, and renormalizes. :attr:`TensorType.PROBS` -> :attr:`TensorType.PROBS` Parameters ---------- top_k : int or torch.Tensor Number of top tokens to keep. See Also -------- :meth:`~flashinfer.sampling.top_k_renorm_probs` """ IN = TensorType.PROBS OUT = TensorType.PROBS def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) top_k = self._get_param("top_k", kwargs, required=True) maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k) if maybe_top_k_arr is None and ( not isinstance(top_k_val, int) or top_k_val <= 0 ): raise ValueError("top_k must be a positive integer or a tensor array") renorm_probs = get_sampling_module().top_k_renorm_probs( tensor.data, maybe_top_k_arr, top_k_val ) return TaggedTensor(renorm_probs, output_type) class LogitsTopKOp(ParameterizedOp): """ Top-k filtering operator for logits. Masks rejected logits to -inf. :attr:`TensorType.LOGITS` -> :attr:`TensorType.LOGITS` Parameters ---------- top_k : int or torch.Tensor Number of top tokens to keep. See Also -------- :class:`~flashinfer.sampling.top_k_mask_logits` """ IN = TensorType.LOGITS OUT = TensorType.LOGITS def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) top_k = self._get_param("top_k", kwargs, required=True) maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k) if maybe_top_k_arr is None and ( not isinstance(top_k_val, int) or top_k_val <= 0 ): raise ValueError("top_k must be a positive integer or a tensor array") masked_logits = get_sampling_module().top_k_mask_logits( tensor.data, maybe_top_k_arr, top_k_val ) return TaggedTensor(masked_logits, output_type) class TopPOp(ParameterizedOp): """ Top-p (nucleus) filtering operator. Keeps tokens with cumulative probability up to threshold p, zeros out others, and renormalizes. :attr:`TensorType.PROBS` -> :attr:`TensorType.PROBS` Parameters ---------- top_p : float or torch.Tensor Cumulative probability threshold in (0, 1]. See Also -------- :meth:`~flashinfer.sampling.top_p_renorm_probs` """ IN = TensorType.PROBS OUT = TensorType.PROBS def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) top_p = self._get_param("top_p", kwargs, required=True) maybe_top_p_arr, top_p_val = _to_tensor_scalar_tuple(top_p) if maybe_top_p_arr is None and not (0 < top_p_val <= 1): raise ValueError("top_p must be float in (0, 1] or a tensor array") renorm_probs = get_sampling_module().top_p_renorm_probs( tensor.data, maybe_top_p_arr, top_p_val ) return TaggedTensor(renorm_probs, output_type) class MinPOp(ParameterizedOp): """ Min-p filtering operator. Keeps tokens with probability at least p times the maximum probability, zeros out others, and renormalizes. :attr:`TensorType.PROBS` -> :attr:`TensorType.PROBS` Parameters ---------- min_p : float or torch.Tensor Minimum probability threshold as ratio of max probability. See Also -------- :meth:`~flashinfer.sampling.min_p_renorm_probs` """ IN = TensorType.PROBS OUT = TensorType.PROBS def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) min_p = self._get_param("min_p", kwargs, required=True) maybe_min_p_arr, min_p_val = _to_tensor_scalar_tuple(min_p) if maybe_min_p_arr is None and not (0 < min_p_val <= 1): raise ValueError("min_p must be float in (0, 1] or a tensor array") if maybe_min_p_arr is not None: min_p_mask = tensor.data >= ( maybe_min_p_arr.unsqueeze(-1) * tensor.data.max(dim=-1, keepdim=True)[0] ) else: min_p_mask = tensor.data >= ( min_p_val * tensor.data.max(dim=-1, keepdim=True)[0] ) masked_probs = tensor.data.clone() masked_probs[~min_p_mask] = 0 probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True) return TaggedTensor(probs, output_type) class ProbsSampleOp(ParameterizedOp): """ Sampling operator for probabilities. Samples token indices from probability distribution using inverse transform sampling. :attr:`TensorType.PROBS` -> :attr:`TensorType.INDICES` Parameters ---------- deterministic : bool, optional Whether to use deterministic kernel implementation. indices : torch.Tensor, optional Indices for batched sampling. generator : torch.Generator, optional Random number generator. See Also -------- :meth:`~flashinfer.sampling.sampling_from_probs` """ IN = TensorType.PROBS OUT = TensorType.INDICES def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) deterministic = self.default_params.get("deterministic", True) indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) samples = get_sampling_module().sampling_from_probs( tensor.data, indices, deterministic, generator ) return TaggedTensor(samples, output_type) class LogitsSampleOp(ParameterizedOp): """ Sampling operator for logits. Samples token indices from logits using Gumbel-max trick. :attr:`TensorType.LOGITS` -> :attr:`TensorType.INDICES` Parameters ---------- deterministic : bool, optional Whether to use deterministic kernel implementation. indices : torch.Tensor, optional Indices for batched sampling. generator : torch.Generator, optional Random number generator. See Also -------- :meth:`~flashinfer.sampling.sampling_from_logits` """ IN = TensorType.LOGITS OUT = TensorType.INDICES def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) deterministic = self.default_params.get("deterministic", True) indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) samples = get_sampling_module().sampling_from_logits( tensor.data, indices, deterministic, generator ) return TaggedTensor(samples, output_type) # Fused operators class FusedTemperatureSoftmaxOp(ParameterizedOp): """ Fused temperature scaling and softmax operator. :attr:`TensorType.LOGITS` -> :attr:`TensorType.PROBS` Parameters ---------- enable_pdl: bool, optional Whether to enable PDL for the fused kernel. temperature : float or torch.Tensor Temperature value for scaling. See Also -------- :meth:`~flashinfer.sampling.softmax` """ IN = TensorType.LOGITS OUT = TensorType.PROBS def __init__(self, enable_pdl: Optional[bool] = None, **default_params: Any): super().__init__(enable_pdl=enable_pdl, **default_params) def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) temperature = self._get_param("temperature", kwargs, required=True) maybe_temperature_arr, temperature_val = _to_tensor_scalar_tuple(temperature) if maybe_temperature_arr is None and ( not isinstance(temperature_val, float) or temperature_val <= 0 ): raise ValueError("Temperature must be positive float or a tensor array") workspace_buffer = _get_cache_buf( "softmax_workspace", 1024 * 1024, tensor.data.device ) enable_pdl = self.default_params.get("enable_pdl", None) if enable_pdl is None: enable_pdl = device_support_pdl(tensor.data.device) probs = get_sampling_module().softmax( workspace_buffer, tensor.data, maybe_temperature_arr, temperature_val, enable_pdl, ) return TaggedTensor(probs, output_type) class FusedProbsTopKSampleOp(ParameterizedOp): """ Fused top-k filtering and sampling operator for probabilities. Use rejection sampling to directly sample from the top-k probabilities. :attr:`TensorType.PROBS` -> :attr:`TensorType.INDICES` Parameters ---------- deterministic : bool, optional Whether to use deterministic kernel implementation. top_k : int or torch.Tensor Number of top tokens to keep. indices : torch.Tensor, optional Indices for batched sampling. generator : torch.Generator, optional Random number generator. See Also -------- :meth:`~flashinfer.sampling.top_k_sampling_from_probs` """ IN = TensorType.PROBS OUT = TensorType.INDICES def __init__(self, deterministic: bool = True, **default_params: Any): super().__init__(deterministic=deterministic, **default_params) def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) deterministic = self.default_params.get("deterministic", True) top_k = self._get_param("top_k", kwargs, required=True) maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k) if maybe_top_k_arr is None and ( not isinstance(top_k_val, int) or top_k_val <= 0 ): raise ValueError("top_k must be a positive integer or a tensor array") indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) samples = get_sampling_module().top_k_sampling_from_probs( tensor.data, indices, maybe_top_k_arr, top_k_val, deterministic, generator ) return TaggedTensor(samples, output_type) class FusedProbsTopPSampleOp(ParameterizedOp): """ Fused top-p filtering and sampling operator for probabilities. Use rejection sampling to directly sample from the top-p probabilities. :attr:`TensorType.PROBS` -> :attr:`TensorType.INDICES` Parameters ---------- deterministic : bool, optional Whether to use deterministic kernel implementation. top_p : float or torch.Tensor Cumulative probability threshold. indices : torch.Tensor, optional Indices for batched sampling. generator : torch.Generator, optional Random number generator. See Also -------- :meth:`~flashinfer.sampling.top_p_sampling_from_probs` """ IN = TensorType.PROBS OUT = TensorType.INDICES def __init__(self, deterministic: bool = True, **default_params: Any): super().__init__(deterministic=deterministic, **default_params) def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) deterministic = self.default_params.get("deterministic", True) top_p = self._get_param("top_p", kwargs, required=True) maybe_top_p_arr, top_p_val = _to_tensor_scalar_tuple(top_p) if maybe_top_p_arr is None and not (0 < top_p_val <= 1): raise ValueError("top_p must be float in (0, 1] or a tensor array") indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) samples = get_sampling_module().top_p_sampling_from_probs( tensor.data, indices, maybe_top_p_arr, top_p_val, deterministic, generator ) return TaggedTensor(samples, output_type) class FusedProbsMinPSampleOp(ParameterizedOp): """ Fused min-p filtering and sampling operator for probabilities. Use rejection sampling to directly sample from the min-p probabilities. PROBS → INDICES Parameters ---------- deterministic : bool, optional Whether to use deterministic kernel implementation. min_p : float or torch.Tensor Minimum probability threshold. indices : torch.Tensor, optional Indices for batched sampling. generator : torch.Generator, optional Random number generator. See Also -------- :meth:`~flashinfer.sampling.min_p_sampling_from_probs` """ IN = TensorType.PROBS OUT = TensorType.INDICES def __init__(self, deterministic: bool = True, **default_params: Any): super().__init__(deterministic=deterministic, **default_params) def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) deterministic = self.default_params.get("deterministic", True) min_p = self._get_param("min_p", kwargs, required=True) maybe_min_p_arr, min_p_val = _to_tensor_scalar_tuple(min_p) if maybe_min_p_arr is None and not (0 < min_p_val <= 1): raise ValueError("min_p must be float in (0, 1] or a tensor array") indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) samples = get_sampling_module().min_p_sampling_from_probs( tensor.data, indices, maybe_min_p_arr, min_p_val, deterministic, generator ) return TaggedTensor(samples, output_type) class FusedProbsTopKTopPSampleOp(ParameterizedOp): """ Fused top-k, top-p filtering and sampling operator for probabilities. Use rejection sampling to directly sample from the probabilities, top-k and top-p filtering are applied jointly (rather than applying first -> renormalize -> second). :attr:`TensorType.PROBS` -> :attr:`TensorType.INDICES` Parameters ---------- deterministic : bool, optional Whether to use deterministic kernel implementation. top_k : int or torch.Tensor Number of top tokens to keep. top_p : float or torch.Tensor Cumulative probability threshold. indices : torch.Tensor, optional Indices for batched sampling. generator : torch.Generator, optional Random number generator. See Also -------- :meth:`~flashinfer.sampling.top_k_top_p_sampling_from_probs` """ IN = TensorType.PROBS OUT = TensorType.INDICES def __init__(self, deterministic: bool = True, **default_params: Any): super().__init__(deterministic=deterministic, **default_params) def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) deterministic = self.default_params.get("deterministic", True) top_k = self._get_param("top_k", kwargs, required=True) maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k) top_p = self._get_param("top_p", kwargs, required=True) maybe_top_p_arr, top_p_val = _to_tensor_scalar_tuple(top_p) if maybe_top_k_arr is None and ( not isinstance(top_k_val, int) or top_k_val <= 0 ): raise ValueError("top_k must be a positive integer or a tensor array") if maybe_top_p_arr is None and not (0 < top_p_val <= 1): raise ValueError("top_p must be float in (0, 1] or a tensor array") indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) samples = get_sampling_module().top_k_top_p_sampling_from_probs( tensor.data, indices, maybe_top_k_arr, top_k_val, maybe_top_p_arr, top_p_val, deterministic, generator, ) return TaggedTensor(samples, output_type)