# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/logits_processor/processors.py

"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from .op import Op
from .types import TensorType


class LogitsProcessor(ABC):
"""
LogitsProcessor defines a high-level transformation that can be applied to
logits or probabilities. Each processor is automatically legalized into one or
more low-level :class:`Op` or :class:`ParameterizedOp` instances that can be
type-checked, validated, and fused for optimal performance. Users can extend this
class to implement their own processors.
Parameters
----------
**params : Any
Processor-specific parameters at compile-time.
Examples
--------
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, TopK, Sample, TensorType
>>> torch.manual_seed(42)
>>>
>>> # Create a pipeline that legalizes to a fused op.
>>> pipe = LogitsPipe([
... TopK(), # Top-k filtering on logits
... Sample() # Sample from the filtered distribution
... ], input_type=TensorType.PROBS) # assume the input is probabilities
>>>
>>> pipe
LogitsPipe([TopK -> Sample], ops=[ProbsTopKOp -> ProbsSampleOp], compiled_ops=[FusedProbsTopKSampleOp])
Notes
-----
Subclasses must implement the :meth:`legalize` method to convert the high-level
processor into one or more low-level operators with specific input/output types.
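A minimal sketch of such a subclass (for illustration it simply reuses the existing
:class:`TemperatureOp` rather than defining a new operator, and only accepts logits;
the import paths are assumed to mirror those of the built-in processors):
>>> from flashinfer.logits_processor import LogitsProcessor, TensorType
>>> from flashinfer.logits_processor.operators import TemperatureOp
>>> class MyScaling(LogitsProcessor):
...     def legalize(self, input_type):
...         # reject anything other than logits, mirroring Temperature
...         if input_type != TensorType.LOGITS:
...             raise ValueError(f"MyScaling expects LOGITS, got {input_type}")
...         return [TemperatureOp(**self.params)]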
"""
def __init__(self, **params: Any):
"""
Initialize the processor.
Parameters
----------
**params : Any
Processor-specific parameters at compile-time.
"""
self.params = params
@abstractmethod
def legalize(self, input_type: TensorType) -> List[Op]:
"""
Legalize the processor into a list of low-level operators.
Parameters
----------
input_type : TensorType
The expected input tensor type of the processor.
Returns
-------
List[Op]
A list of low-level operators.
"""
raise NotImplementedError
def __repr__(self) -> str:
params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
return f"{self.__class__.__name__}({params_str})"


class Temperature(LogitsProcessor):
"""
Temperature scaling processor for logits.
Scales logits by dividing by a temperature value.
:attr:`TensorType.LOGITS` -> :attr:`TensorType.LOGITS`
Parameters
----------
temperature : float or torch.Tensor, Runtime
Temperature value for scaling. Must be positive. Can be a scalar or per-batch tensor.
Examples
--------
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Temperature, Sample
>>> torch.manual_seed(42)
>>> pipe = LogitsPipe([Temperature()])
>>> logits = torch.randn(2, 2, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614],
        [-0.1721,  0.8491]], device='cuda:0')
>>> scaled_logits = pipe(logits, temperature=0.8)
>>> scaled_logits
tensor([[ 0.2425,  2.7017],
        [-0.2151,  1.0613]], device='cuda:0')
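``temperature`` may also be a per-batch tensor, in which case each row of the logits
is expected to be divided by its own temperature (a sketch; the output follows
directly from the values above, so it is not shown):
>>> temps = torch.tensor([0.5, 1.5], device="cuda")
>>> scaled_per_row = pipe(logits, temperature=temps)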
"""
def __init__(self, **params: Any):
"""
Constructor for Temperature processor. No compile-time parameters are needed.
"""
super().__init__(**params)
def legalize(self, input_type: TensorType) -> List[Op]:
"""
Legalize the processor into a list of low-level operators.
"""
from .operators import TemperatureOp
if input_type != TensorType.LOGITS:
raise ValueError(
f"Temperature can only be applied to LOGITS, got {input_type}"
)
return [TemperatureOp(**self.params)]


class Softmax(LogitsProcessor):
"""
Softmax processor to convert logits to probabilities.
Applies the softmax function.
:attr:`TensorType.LOGITS` -> :attr:`TensorType.PROBS`
Parameters
----------
enable_pdl : bool, optional, Compile-time
Whether to enable PDL (Programmatic Dependent Launch) for the kernel implementation.
Default is None, in which case PDL is enabled automatically when the device supports it.
Examples
--------
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Softmax, Sample
>>> torch.manual_seed(42)
>>> pipe = LogitsPipe([Softmax()])
>>> logits = torch.randn(2, 2, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614],
        [-0.1721,  0.8491]], device='cuda:0')
>>> probs = pipe(logits)
>>> probs
tensor([[0.1227, 0.8773],
        [0.2648, 0.7352]], device='cuda:0')
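``enable_pdl`` can also be set explicitly at compile time; PDL is a launch-time
optimization, so the resulting probabilities should be unchanged (a sketch):
>>> pipe_no_pdl = LogitsPipe([Softmax(enable_pdl=False)])
>>> probs_no_pdl = pipe_no_pdl(logits)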
Notes
-----
Can only appear once in a pipeline.
"""
def __init__(self, enable_pdl: Optional[bool] = None, **params: Any):
"""
Constructor for Softmax processor.
Parameters
----------
enable_pdl : bool, optional, Compile-time
Whether to enable PDL for the kernel implementation.
Default is None, in which case PDL is enabled automatically when the device supports it.
"""
super().__init__(enable_pdl=enable_pdl, **params)
def legalize(self, input_type: TensorType) -> List[Op]:
"""
Legalize the processor into a list of low-level operators.
"""
from .operators import SoftmaxOp
if input_type != TensorType.LOGITS:
raise ValueError(f"Softmax can only be applied to LOGITS, got {input_type}")
return [SoftmaxOp(**self.params)]


class TopK(LogitsProcessor):
"""
Top-k filtering processor.
Keeps only the top-k highest probability tokens and masks out the rest.
:attr:`TensorType.LOGITS` -> :attr:`TensorType.LOGITS` | :attr:`TensorType.PROBS` -> :attr:`TensorType.PROBS`
Parameters
----------
joint_topk_topp : bool, optional, Compile-time
Whether to enable joint top-k and top-p filtering when followed by TopP.
Default is False.
top_k : int or torch.Tensor, Runtime
Number of top tokens to keep. Can be a scalar or per-batch tensor.
Examples
--------
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, TopK, Sample, TensorType
>>> torch.manual_seed(42)
>>>
>>> # Top-k filtering on logits
>>> pipe = LogitsPipe([TopK()], input_type=TensorType.LOGITS)
>>> logits = torch.randn(2, 2, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614],
        [-0.1721,  0.8491]], device='cuda:0')
>>> topk_logits = pipe(logits, top_k=1)
>>> topk_logits
tensor([[  -inf, 2.1614],
        [  -inf, 0.8491]], device='cuda:0')
>>>
>>> # Top-k filtering on probabilities
>>> pipe = LogitsPipe([TopK()], input_type=TensorType.PROBS)
>>> probs = torch.randn(2, 2, device="cuda")
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)
>>> probs_normed
tensor([[  4.4998,  -3.4998],
        [-18.2893,  19.2893]], device='cuda:0')
>>> topk_probs = pipe(probs_normed, top_k=1)
>>> topk_probs
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
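When TopK is immediately followed by TopP, ``joint_topk_topp=True`` requests joint
top-k/top-p filtering at compile time (a sketch; the sampled ids depend on the seed,
so no output is shown):
>>> from flashinfer.logits_processor import Softmax, TopP
>>> pipe_joint = LogitsPipe([Softmax(), TopK(joint_topk_topp=True), TopP(), Sample()])
>>> tokens = pipe_joint(logits, top_k=1, top_p=0.9)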
Notes
-----
When applied to :attr:`TensorType.LOGITS`, sets non-top-k values to -inf.
When applied to :attr:`TensorType.PROBS`, zeros out non-top-k values and renormalizes.
See Also
--------
:meth:`~flashinfer.sampling.top_k_mask_logits`
:meth:`~flashinfer.sampling.top_k_renorm_probs`
"""
def __init__(self, joint_topk_topp: bool = False, **params: Any):
"""
Constructor for TopK processor.
Parameters
----------
joint_topk_topp : bool, optional, Compile-time
Whether to enable joint top-k and top-p filtering when followed by TopP.
Default is False.
"""
super().__init__(joint_topk_topp=joint_topk_topp, **params)
def legalize(self, input_type: TensorType) -> List[Op]:
"""
Legalize the processor into a list of low-level operators.
"""
from .operators import LogitsTopKOp, ProbsTopKOp
if input_type == TensorType.LOGITS:
return [LogitsTopKOp(**self.params)]
elif input_type == TensorType.PROBS:
return [ProbsTopKOp(**self.params)]
else:
raise ValueError(f"TopK cannot be applied to {input_type}")


class TopP(LogitsProcessor):
"""
Top-p (nucleus) filtering processor.
Keeps tokens with cumulative probability up to threshold p.
:attr:`TensorType.PROBS` -> :attr:`TensorType.PROBS`
Parameters
----------
top_p : float or torch.Tensor, Runtime
Cumulative probability threshold in (0, 1]. Can be a scalar or per-batch tensor.
Examples
--------
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Softmax, TopP, Sample
>>> torch.manual_seed(42)
>>> pipe = LogitsPipe([TopP()])
>>> probs = torch.randn(2, 2, device="cuda")
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)
>>> probs_normed
tensor([[ 0.0824,  0.9176],
        [-0.2541,  1.2541]], device='cuda:0')
>>> topp_probs = pipe(probs_normed, top_p=0.9)
>>> topp_probs
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
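TopP also composes with the Softmax and Sample processors imported above into a full
pipeline that starts from logits (a sketch; the sampled ids depend on the seed, so no
output is shown):
>>> pipe_full = LogitsPipe([Softmax(), TopP(), Sample()])
>>> logits = torch.randn(2, 5, device="cuda")
>>> tokens = pipe_full(logits, top_p=0.9)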
See Also
--------
:meth:`~flashinfer.sampling.top_p_renorm_probs`
"""
def __init__(self, **params: Any):
"""
Constructor for TopP processor. No compile-time parameters are needed.
"""
super().__init__(**params)
def legalize(self, input_type: TensorType) -> List[Op]:
"""
Legalize the processor into a list of low-level operators.
"""
from .operators import TopPOp
if input_type != TensorType.PROBS:
raise ValueError(f"TopP can only be applied to PROBS, got {input_type}")
return [TopPOp(**self.params)]


class MinP(LogitsProcessor):
"""
Min-p filtering processor.
Keeps tokens with probability at least p times the maximum probability.
:attr:`TensorType.PROBS` -> :attr:`TensorType.PROBS`
Parameters
----------
min_p : float or torch.Tensor, Runtime
Minimum probability threshold as a ratio of max probability.
Must be in (0, 1]. Can be a scalar or per-batch tensor.
Examples
--------
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Softmax, MinP, Sample
>>> torch.manual_seed(42)
>>> pipe = LogitsPipe([MinP()])
>>> probs = torch.randn(2, 2, device="cuda")
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)
>>> probs_normed
tensor([[ 0.0824,  0.9176],
        [-0.2541,  1.2541]], device='cuda:0')
>>> minp_probs = pipe(probs_normed, min_p=0.05)
>>> minp_probs
tensor([[0.0824, 0.9176],
        [0.0000, 1.0000]], device='cuda:0')
"""
def __init__(self, **params: Any):
"""
Constructor for MinP processor. No compile-time parameters are needed.
"""
super().__init__(**params)
def legalize(self, input_type: TensorType) -> List[Op]:
"""
Legalize the processor into a list of low-level operators.
"""
from .operators import MinPOp
if input_type != TensorType.PROBS:
raise ValueError(f"MinP can only be applied to PROBS, got {input_type}")
return [MinPOp(**self.params)]


class Sample(LogitsProcessor):
"""
Sampling processor to generate token indices.
Samples tokens from logits or probability distributions.
:attr:`TensorType.LOGITS` -> :attr:`TensorType.INDICES` | :attr:`TensorType.PROBS` -> :attr:`TensorType.INDICES`
Parameters
----------
deterministic : bool, optional, Compile-time
Whether to use deterministic kernel implementation.
Default is True.
indices : torch.Tensor, optional, Runtime
Indices for batched sampling when probability tensors are shared.
generator : torch.Generator, optional, Runtime
Random number generator for reproducible sampling.
Examples
--------
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Sample, TensorType
>>> torch.manual_seed(42)
>>>
>>> # Sampling from logits
>>> pipe = LogitsPipe([Sample(deterministic=True)], input_type=TensorType.LOGITS)
>>> logits = torch.randn(2, 5, device="cuda")
>>> logits
tensor([[ 0.1940, 2.1614, -0.1721, 0.8491, -1.9244],
[ 0.6530, -0.6494, -0.8175, 0.5280, -1.2753]], device='cuda:0')
>>> tokens = pipe(logits, top_k=1)
>>> tokens
tensor([0, 1], device='cuda:0')
>>>
>>> # Sampling from probabilities
>>> pipe = LogitsPipe([Sample(deterministic=True)], input_type=TensorType.PROBS)
>>> probs = torch.randn(2, 5, device="cuda")
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)
>>> probs_normed
tensor([[ 2.8827, 0.0870, 0.2340, -3.2731, 1.0694],
[ 0.3526, 0.0928, 0.1601, -0.1737, 0.5683]], device='cuda:0')
>>> tokens = pipe(probs_normed, top_k=1)
>>> tokens
tensor([0, 0], device='cuda:0')
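For reproducible draws, a ``torch.Generator`` can be supplied at runtime (a sketch;
the runtime keyword is assumed to match the ``generator`` parameter above, and the
sampled ids are not shown since they depend on the generator state):
>>> gen = torch.Generator(device="cuda").manual_seed(0)
>>> tokens = pipe(probs_normed, generator=gen)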
Notes
-----
Outputs :attr:`TensorType.INDICES`; no further operators can follow in the pipeline.
See Also
--------
:meth:`~flashinfer.sampling.sampling_from_logits`
:meth:`~flashinfer.sampling.sampling_from_probs`
"""
def __init__(self, deterministic: bool = True, **params: Any):
"""
Constructor for Sample processor.
Parameters
----------
deterministic : bool, optional
Whether to use deterministic kernel implementation.
Default is True.
"""
super().__init__(deterministic=deterministic, **params)
def legalize(self, input_type: TensorType) -> List[Op]:
"""
Legalize the processor into a list of low-level operators.
"""
from .operators import LogitsSampleOp, ProbsSampleOp
if input_type == TensorType.LOGITS:
return [LogitsSampleOp(**self.params)]
elif input_type == TensorType.PROBS:
return [ProbsSampleOp(**self.params)]
else:
raise ValueError(f"Sampling cannot be applied to {input_type}")