sglang_v0.5.2/flashinfer_0.3.1/flashinfer/sampling.py

"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
from types import SimpleNamespace
from typing import Optional, Union
import torch
from .jit import JitSpec
from .jit import env as jit_env
from .jit import gen_jit_spec
from .utils import (
_get_cache_buf,
device_support_pdl,
register_custom_op,
register_fake_op,
)
def gen_sampling_module() -> JitSpec:
return gen_jit_spec(
"sampling",
[
jit_env.FLASHINFER_CSRC_DIR / "sampling.cu",
jit_env.FLASHINFER_CSRC_DIR / "renorm.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_sampling_ops.cu",
],
)
@functools.cache
def get_sampling_module():
module = gen_sampling_module().build_and_load()
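    # Each kernel below is wrapped as a torch custom op and paired with a fake (meta) op,
    # so the sampling entry points compose with torch.compile and shape tracing.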
@register_custom_op("flashinfer::softmax", mutates_args=("workspace_buffer",))
def softmax(
workspace_buffer: torch.Tensor,
logits: torch.Tensor,
maybe_temperature_arr: Optional[torch.Tensor],
temperature_val: float,
enable_pdl: bool,
) -> torch.Tensor:
logits = logits.float()
probs = torch.empty_like(logits, device=logits.device)
maybe_temperature_arr = (
maybe_temperature_arr.float() if maybe_temperature_arr is not None else None
)
module.softmax.default(
workspace_buffer,
logits,
probs,
maybe_temperature_arr,
temperature_val,
enable_pdl,
)
return probs
@register_fake_op("flashinfer::softmax")
def _fake_softmax(
workspace_buffer: torch.Tensor,
logits: torch.Tensor,
maybe_temperature_arr: Optional[torch.Tensor],
temperature_val: float,
enable_pdl: bool,
) -> torch.Tensor:
return torch.empty_like(logits, device=logits.device, dtype=torch.float32)
# torch library for sampling_from_logits
@register_custom_op("flashinfer::sampling_from_logits", mutates_args=())
def sampling_from_logits(
logits: torch.Tensor,
indices: Optional[torch.Tensor],
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
device = logits.device
# TODO: support more data types in logits to avoid conversion
# to float32
logits = logits.float()
batch_size = indices.size(0) if indices is not None else logits.size(0)
samples = torch.empty(batch_size, dtype=torch.int32, device=device)
module.sampling_from_logits.default(
logits,
samples,
indices,
deterministic,
generator,
)
return samples
@register_fake_op("flashinfer::sampling_from_logits")
def _fake_sampling_from_logits(
logits: torch.Tensor,
indices: Optional[torch.Tensor],
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
batch_size = indices.size(0) if indices is not None else logits.size(0)
return torch.empty(batch_size, dtype=torch.int32, device=logits.device)
# torch library for sampling_from_probs
@register_custom_op("flashinfer::sampling_from_probs", mutates_args=())
def sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
device = probs.device
probs = probs.float()
batch_size = indices.size(0) if indices is not None else probs.size(0)
samples = torch.empty(batch_size, dtype=torch.int32, device=device)
module.sampling_from_probs.default(
probs,
samples,
indices,
deterministic,
generator,
)
return samples
# torch library for sampling_from_probs
@register_fake_op("flashinfer::sampling_from_probs")
def _fake_sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
batch_size = indices.size(0) if indices is not None else probs.size(0)
return torch.empty(batch_size, dtype=torch.int32, device=probs.device)
# torch library for top_p_sampling_from_probs
@register_custom_op("flashinfer::top_p_sampling_from_probs", mutates_args=())
def top_p_sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
device = probs.device
probs = probs.float()
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
batch_size = indices.size(0) if indices is not None else probs.size(0)
samples = torch.empty(batch_size, dtype=torch.int32, device=device)
module.top_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_top_p_arr,
top_p_val,
deterministic,
generator,
)
return samples
@register_fake_op("flashinfer::top_p_sampling_from_probs")
def _fake_top_p_sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
        batch_size = indices.size(0) if indices is not None else probs.size(0)
        return torch.empty(batch_size, dtype=torch.int32, device=probs.device)
# torch library for top_k_sampling_from_probs
@register_custom_op("flashinfer::top_k_sampling_from_probs", mutates_args=())
def top_k_sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
device = probs.device
probs = probs.float()
batch_size = indices.size(0) if indices is not None else probs.size(0)
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
samples = torch.empty(batch_size, dtype=torch.int32, device=device)
module.top_k_sampling_from_probs.default(
probs,
samples,
indices,
maybe_top_k_arr,
top_k_val,
deterministic,
generator,
)
return samples
@register_fake_op("flashinfer::top_k_sampling_from_probs")
def _fake_top_k_sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
batch_size = indices.size(0) if indices is not None else probs.size(0)
sample = torch.empty(batch_size, dtype=torch.int32, device=probs.device)
return sample
# torch library for min_p_sampling_from_probs
@register_custom_op("flashinfer::min_p_sampling_from_probs", mutates_args=())
def min_p_sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_min_p_arr: Optional[torch.Tensor],
min_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
device = probs.device
probs = probs.float()
maybe_min_p_arr = (
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
)
batch_size = indices.size(0) if indices is not None else probs.size(0)
samples = torch.empty(batch_size, dtype=torch.int32, device=device)
module.min_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_min_p_arr,
min_p_val,
deterministic,
generator,
)
return samples
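    # A fake (meta) op for min_p_sampling_from_probs, mirroring the pattern of the other
    # sampling ops above (a minimal sketch: it only reports the output shape/dtype).
    @register_fake_op("flashinfer::min_p_sampling_from_probs")
    def _fake_min_p_sampling_from_probs(
        probs: torch.Tensor,
        indices: Optional[torch.Tensor],
        maybe_min_p_arr: Optional[torch.Tensor],
        min_p_val: float,
        deterministic: bool,
        generator: Optional[torch.Generator],
    ) -> torch.Tensor:
        batch_size = indices.size(0) if indices is not None else probs.size(0)
        return torch.empty(batch_size, dtype=torch.int32, device=probs.device)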
# torch library for top_k_top_p_sampling_from_probs
@register_custom_op("flashinfer::top_k_top_p_sampling_from_probs", mutates_args=())
def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
device = probs.device
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
batch_size = indices.size(0) if indices is not None else probs.size(0)
samples = torch.empty(batch_size, dtype=torch.int32, device=device)
module.top_k_top_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_top_k_arr,
top_k_val,
maybe_top_p_arr,
top_p_val,
deterministic,
generator,
)
return samples
@register_fake_op("flashinfer::top_k_top_p_sampling_from_probs")
def _fake_top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
batch_size = indices.size(0) if indices is not None else probs.size(0)
sample = torch.empty(batch_size, dtype=torch.int32, device=probs.device)
return sample
# torch library for top_p_renorm_probs
@register_custom_op("flashinfer::top_p_renorm_probs", mutates_args=())
def top_p_renorm_probs(
probs: torch.Tensor,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
) -> torch.Tensor:
probs = probs.float()
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
renorm_probs = torch.empty_like(probs)
module.top_p_renorm_probs.default(
probs,
renorm_probs,
maybe_top_p_arr,
top_p_val,
)
return renorm_probs
@register_fake_op("flashinfer::top_p_renorm_probs")
def _fake_top_p_renorm_probs(
probs: torch.Tensor,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
) -> torch.Tensor:
return torch.empty_like(probs)
# torch library for top_k_renorm_probs
@register_custom_op("flashinfer::top_k_renorm_probs", mutates_args=())
def top_k_renorm_probs(
probs: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
renorm_probs = torch.empty_like(probs)
module.top_k_renorm_probs.default(
probs,
renorm_probs,
maybe_top_k_arr,
top_k_val,
)
return renorm_probs
@register_fake_op("flashinfer::top_k_renorm_probs")
def _fake_top_k_renorm_probs(
probs: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
return torch.empty_like(probs)
# torch library for top_k_mask_logits
@register_custom_op("flashinfer::top_k_mask_logits", mutates_args=())
def top_k_mask_logits(
logits: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
logits = logits.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
mask_logits = torch.empty_like(logits)
module.top_k_mask_logits.default(
logits,
mask_logits,
maybe_top_k_arr,
top_k_val,
)
return mask_logits
@register_fake_op("flashinfer::top_k_mask_logits")
def _fake_top_k_mask_logits(
logits: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
return torch.empty_like(logits)
# torch library for chain_speculative_sampling
@register_custom_op(
"flashinfer::chain_speculative_sampling",
mutates_args=(
"output_accepted_token_num",
"output_emitted_draft_token_num",
),
)
def chain_speculative_sampling(
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
target_probs: torch.Tensor,
output_accepted_token_num: torch.Tensor,
output_emitted_draft_token_num: torch.Tensor,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
device = draft_probs.device
draft_probs = draft_probs.float()
draft_token_ids = draft_token_ids.int()
target_probs = target_probs.float()
output_accepted_token_num = output_accepted_token_num.int()
output_emitted_draft_token_num = output_emitted_draft_token_num.int()
b, n = draft_token_ids.shape
output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device)
module.chain_speculative_sampling.default(
draft_probs,
draft_token_ids,
target_probs,
output_token_ids,
output_accepted_token_num,
output_emitted_draft_token_num,
deterministic,
generator,
)
return output_token_ids
@register_fake_op("flashinfer::chain_speculative_sampling")
def _fake_chain_speculative_sampling(
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
target_probs: torch.Tensor,
output_accepted_token_num: torch.Tensor,
output_emitted_draft_token_num: torch.Tensor,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
b, n = draft_token_ids.shape
device = draft_token_ids.device
return torch.empty((b, n + 1), dtype=torch.int32, device=device)
    # Expose the wrapped sampling ops as a module-like namespace.
return SimpleNamespace(
softmax=softmax,
sampling_from_probs=sampling_from_probs,
sampling_from_logits=sampling_from_logits,
top_p_sampling_from_probs=top_p_sampling_from_probs,
top_k_sampling_from_probs=top_k_sampling_from_probs,
min_p_sampling_from_probs=min_p_sampling_from_probs,
top_k_top_p_sampling_from_probs=top_k_top_p_sampling_from_probs,
top_p_renorm_probs=top_p_renorm_probs,
top_k_renorm_probs=top_k_renorm_probs,
top_k_mask_logits=top_k_mask_logits,
chain_speculative_sampling=chain_speculative_sampling,
)
def _to_tensor_scalar_tuple(x):
    """Split a tensor-or-scalar argument into the ``(maybe_arr, val)`` pair expected by the
    kernels: a per-request tensor is passed through with a placeholder scalar value, while a
    plain scalar is passed with no tensor."""
    if isinstance(x, torch.Tensor):
        return (x, 0)
    else:
        return (None, x)
def softmax(
logits: torch.Tensor,
temperature: Optional[Union[torch.Tensor, float]] = None,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
r"""Fused GPU kernel for `online safe softmax <https://arxiv.org/abs/1805.02867>`_ with temperature scaling.
Parameters
----------
logits : torch.Tensor
Input tensor of logits.
temperature: Optional[Union[torch.Tensor, float]]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the temperature for temperature scaling.
If a scalar, the same temperature is used for all requests.
If a tensor, each request has its own temperature.
enable_pdl : Optional[bool]
Whether to enable Programmatic Dependent Launch (PDL) for improved performance on supported hardware.
If None (default), PDL will be automatically enabled on devices with compute capability >= 9.0.
Returns
-------
probs : torch.Tensor
Tensor of the same shape as input containing the softmax probabilities.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> logits = torch.rand(batch_size, vocab_size).to(0)
>>> logits
tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
[0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
[0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739, 0.2666, 0.6274]], device='cuda:0')
>>> probs = flashinfer.sampling.softmax(logits, temperature=1.0)
>>> probs
tensor([[0.2309, 0.2385, 0.1401, 0.2493, 0.1412],
[0.2019, 0.1431, 0.2448, 0.2837, 0.1265],
[0.2401, 0.1707, 0.2249, 0.1664, 0.1979],
[0.1724, 0.2719, 0.1991, 0.1465, 0.2101]], device='cuda:0')
"""
workspace_buffer = _get_cache_buf("softmax_workspace", 1024 * 1024, logits.device)
if temperature is None:
temperature = 1.0
# Auto-detect PDL support if not specified
if enable_pdl is None:
enable_pdl = device_support_pdl(logits.device)
return get_sampling_module().softmax(
workspace_buffer, logits, *_to_tensor_scalar_tuple(temperature), enable_pdl
)
def sampling_from_logits(
logits: torch.Tensor,
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Fused GPU kernel for category sampling from logits. It's equivalent to sampling
from :attr:`logits` after applying softmax.
Parameters
----------
logits: torch.Tensor
Logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of logits. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in logits.
For example, if indices[i] = j, then the i-th output will be sampled from logits[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of logits.
deterministic: bool
        Since this kernel does not use CUB's BlockScan, sampling is always deterministic; the
        argument is kept for API compatibility with the other sampling functions.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`logits`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape (batch_size,). It's equivalent to sampling from
:attr:`logits` after applying softmax.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> logits = torch.rand(batch_size, vocab_size).to(0)
>>> logits
tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
[0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
[0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739, 0.2666, 0.6274]], device='cuda:0')
>>> samples = flashinfer.sampling.sampling_from_logits(logits)
>>> samples
tensor([0, 1, 1, 1], device='cuda:0', dtype=torch.int32)
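
    The result is equivalent (in distribution, not sample-for-sample) to sampling from the
    softmax of the logits, e.g.::

        samples = flashinfer.sampling.sampling_from_probs(torch.softmax(logits, dim=-1))
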
"""
if check_nan:
if torch.any(torch.isnan(logits)):
raise ValueError("Input logits contains NaN.")
return get_sampling_module().sampling_from_logits(
logits, indices, deterministic, generator
)
def sampling_from_probs(
probs: torch.Tensor,
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Fused GPU kernel for category sampling from probabilities.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape (batch_size,).
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.sampling_from_probs(norm_prob)
>>> samples
tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
Note
----
This function expects float32 inputs, and the output is int32.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return get_sampling_module().sampling_from_probs(
probs, indices, deterministic, generator
)
def top_p_sampling_from_probs(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_p = 0.5
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.top_p_sampling_from_probs(norm_prob, top_p)
>>> samples
tensor([1, 2, 0, 4], device='cuda:0', dtype=torch.int32)
Note
----
This function expects float32 inputs, and the output is int32.
See Also
--------
top_k_top_p_sampling_from_probs
top_k_sampling_from_probs
top_p_renorm_probs
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return get_sampling_module().top_p_sampling_from_probs(
probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
)
def top_k_sampling_from_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Fused GPU kernel for top-k sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 1
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.top_k_sampling_from_probs(norm_prob, top_k)
>>> samples
tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
Note
----
This function expects float32 inputs, and the output is int32.
See Also
--------
top_k_top_p_sampling_from_probs
top_p_sampling_from_probs
top_k_renorm_probs
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return get_sampling_module().top_k_sampling_from_probs(
probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator
)
def min_p_sampling_from_probs(
probs: torch.Tensor,
min_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
min_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f8b3db06df0>
>>> batch_size = 4
>>> vocab_size = 5
>>> min_p = torch.full((batch_size,), 0.05).to(0)
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, min_p)
>>> samples
tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
Note
----
This function expects float32 inputs, and the output is int32.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return get_sampling_module().min_p_sampling_from_probs(
probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
)
def top_k_top_p_sampling_from_logits(
logits: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Fused GPU kernel for top-k and top-p sampling from pre-softmax logits,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
logits: torch.Tensor
Pre-softmax logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of logits. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
filter_apply_order: str
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_p = 0.5
>>> top_k = 3
    >>> logits = torch.randn(batch_size, vocab_size).to(0)
>>> logits
tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581],
[ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866],
[-0.4934, 0.2415, -0.2316, 0.0418, -0.2516],
[ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]], device='cuda:0')
>>> samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(logits, top_k, top_p)
>>> samples
    tensor([0, 2, 1, 3], device='cuda:0', dtype=torch.int32)
>>> probs = torch.softmax(logits, dim=-1)
>>> probs
tensor([[0.4788, 0.3085, 0.1716, 0.0085, 0.0327],
[0.2358, 0.1787, 0.4307, 0.1145, 0.0404],
[0.1358, 0.2831, 0.1764, 0.2318, 0.1729],
[0.3613, 0.1122, 0.1029, 0.3415, 0.0821]], device='cuda:0')
>>> samples
tensor([0, 2, 1, 3], device='cuda:0', dtype=torch.int32)
Note
----
This function expects float32 inputs, and the output is int32.
See Also
--------
top_k_top_p_sampling_from_probs
top_k_mask_logits
top_p_sampling_from_probs
"""
if filter_apply_order == "top_k_first":
masked_logits = top_k_mask_logits(logits, top_k)
probs = torch.softmax(masked_logits, dim=-1)
return top_p_sampling_from_probs(
probs,
top_p,
indices,
deterministic,
check_nan=check_nan,
generator=generator,
)
elif filter_apply_order == "joint":
probs = torch.softmax(logits, dim=-1)
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return get_sampling_module().top_k_top_p_sampling_from_probs(
probs,
indices,
*_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p),
deterministic,
generator,
)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Fused GPU kernel for top-k and top-p sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
filter_apply_order: str
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_p = torch.full((batch_size,), 0.2).to(0)
>>> top_k = torch.full((batch_size,), 2).to(0)
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.top_k_top_p_sampling_from_probs(norm_prob, top_k, top_p)
>>> samples
tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
Note
----
This function expects float32 inputs, and the output is int32.
See Also
--------
top_k_sampling_from_probs
top_p_sampling_from_probs
top_k_renorm_probs
top_p_renorm_probs
top_k_mask_logits
"""
if filter_apply_order == "top_k_first":
renorm_probs = top_k_renorm_probs(probs, top_k)
return top_p_sampling_from_probs(
renorm_probs,
top_p,
indices,
deterministic,
check_nan=check_nan,
generator=generator,
)
elif filter_apply_order == "joint":
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return get_sampling_module().top_k_top_p_sampling_from_probs(
probs,
indices,
*_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p),
deterministic,
generator,
)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
def top_p_renorm_probs(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding.
Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_p: Union[torch.Tensor, float]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold for
        re-normalizing probabilities, should be in ``(0, 1)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We mask out the probabilities below the threshold at which the cumulative sum of
        ``probs[probs >= threshold]`` equals ``top_p``, and renormalize the remaining probabilities.
Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_p = 0.3
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> renormed_probs = flashinfer.sampling.top_p_renorm_probs(prob, top_p)
>>> renormed_probs
tensor([[0.0000, 0.4882, 0.0000, 0.5118, 0.0000],
[0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
[0.5181, 0.0000, 0.4819, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000, 0.0000]], device='cuda:0')
Note
----
This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
``top_p_sampling_from_probs``.
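
    A minimal sketch of that equivalence (equal in distribution, not sample-for-sample), assuming
    ``prob`` and ``top_p`` as in the example above::

        renormed = flashinfer.sampling.top_p_renorm_probs(prob, top_p)
        samples = flashinfer.sampling.sampling_from_probs(renormed)
        # draws from the same distribution as
        samples = flashinfer.sampling.top_p_sampling_from_probs(prob, top_p)
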
See Also
--------
top_p_sampling_from_probs
sampling_from_probs
top_k_renorm_probs
"""
return get_sampling_module().top_p_renorm_probs(
probs, *_to_tensor_scalar_tuple(top_p)
)
top_p_renorm_prob = top_p_renorm_probs
def top_k_renorm_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding.
Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for
        re-normalizing probabilities, should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 3
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> renormed_probs = flashinfer.sampling.top_k_renorm_probs(prob, top_k)
>>> renormed_probs
tensor([[0.3201, 0.3319, 0.0000, 0.3480, 0.0000],
[0.2573, 0.0000, 0.3398, 0.4028, 0.0000],
[0.3672, 0.0000, 0.3416, 0.0000, 0.2912],
[0.0000, 0.4243, 0.2750, 0.0000, 0.3007]], device='cuda:0')
Note
----
This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
``top_k_sampling_from_probs``.
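
    A minimal sketch of that equivalence (equal in distribution, not sample-for-sample), assuming
    ``prob`` and ``top_k`` as in the example above::

        renormed = flashinfer.sampling.top_k_renorm_probs(prob, top_k)
        samples = flashinfer.sampling.sampling_from_probs(renormed)
        # draws from the same distribution as
        samples = flashinfer.sampling.top_k_sampling_from_probs(prob, top_k)
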
See Also
--------
top_k_sampling_from_probs
sampling_from_probs
top_p_renorm_probs
"""
return get_sampling_module().top_k_renorm_probs(
probs, *_to_tensor_scalar_tuple(top_k)
)
top_k_renorm_prob = top_k_renorm_probs
def top_k_mask_logits(
logits: torch.Tensor, top_k: Union[torch.Tensor, int]
) -> torch.Tensor:
r"""Fused GPU kernel for masking logits by top-k thresholding.
Parameters
----------
logits: torch.Tensor
Logits before softmax, shape ``(batch_size, num_classes)``.
top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for
        masking logits, should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k logits and set the rest to negative infinity.
Returns
-------
masked_logits: torch.Tensor
Masked logits, shape ``(batch_size, num_classes)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 3
>>> logits = torch.randn(batch_size, vocab_size).to(0)
>>> logits
tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581],
[ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866],
[-0.4934, 0.2415, -0.2316, 0.0418, -0.2516],
[ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]], device='cuda:0')
>>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
>>> masked_logits
tensor([[ 1.9269, 1.4873, 0.9007, -inf, -inf],
[ 1.0783, 0.8008, 1.6806, -inf, -inf],
[ -inf, 0.2415, -0.2316, 0.0418, -inf],
[ 0.8599, -0.3097, -inf, 0.8034, -inf]], device='cuda:0')
Note
----
The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.
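
    A minimal sketch of that equivalence, assuming ``logits`` and ``top_k`` as in the example above::

        probs_a = torch.softmax(flashinfer.sampling.top_k_mask_logits(logits, top_k), dim=-1)
        probs_b = flashinfer.sampling.top_k_renorm_probs(torch.softmax(logits, dim=-1), top_k)
        # probs_a and probs_b should agree up to floating-point error
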
See Also
--------
top_k_renorm_probs
"""
return get_sampling_module().top_k_mask_logits(
logits, *_to_tensor_scalar_tuple(top_k)
)
def chain_speculative_sampling(
draft_probs,
draft_token_ids,
target_probs,
maybe_output_accepted_token_num: Optional[torch.Tensor] = None,
maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in
paper `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_),
where the draft model generates a sequence(chain) of tokens for each request.
Parameters
----------
draft_probs: torch.Tensor
The probability over vocabulary generated by draft model.
Shape: ``(batch_size, num_speculate_tokens, vocab_size)``
draft_token_ids: torch.Tensor
The draft model's generated token indices.
Shape: ``(batch_size, num_speculate_tokens)``
target_probs: torch.Tensor
The probability over vocabulary generated by target model.
Compared to input :attr:`draft_probs`, the target model's probability has an additional
slot at the end because the target model will generate one more token than the draft model.
Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
maybe_output_accepted_token_num: Optional[torch.Tensor]
The number of tokens that can be accepted if each token is considered independently for each request.
This metric does not consider the fact that rejection sampling will stop at the first token that does not
satisfy the probability requirement r < p/q.
It only evaluates the alignment of draft model and target model.
Shape: ``(batch_size)``
        If specified, the number of accepted tokens will be added to this tensor in place. Default is ``None``.
maybe_output_emitted_draft_token_num: Optional[torch.Tensor]
The number of draft tokens that are finally emitted for each request. Does not include
the bonus token. (Thus the total number of tokens sampled for a given request is
output_emitted_draft_token_num + 1).
Shape: ``(batch_size)``
        If specified, the number of emitted draft tokens will be added to this tensor in place. Default is ``None``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
Returns
-------
output_token_ids: torch.Tensor
The output token indices verified by the target model, rejected samples are
padded with ``-1``.
Compared to input :attr:`draft_token_ids`, the output tensor has an additional
token index at the end for the final token, if all previous tokens are accepted,
another "bonus" token will be sampled from the target model's probability.
        Shape: ``(batch_size, num_speculate_tokens + 1)``
output_accepted_token_num: torch.Tensor
The number of tokens that can be accepted if each token is considered independently for each request.
This metric does not consider the fact that rejection sampling will stop at the first token that does not
satisfy the probability requirement r < p/q.
It only evaluates the alignment of draft model and target model.
Shape: ``(batch_size)``
output_emitted_draft_token_num: torch.Tensor
The number of draft tokens that are finally emitted for each request. Does not include
the bonus token. (Thus the total number of tokens sampled for a given request is
output_emitted_draft_token_num + 1).
Shape: ``(batch_size)``
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 1
>>> num_speculate_tokens = 2
>>> vocab_size = 4
>>> draft_probs = torch.tensor([[[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.1]]]).to(0)
>>> # token 2 was sampled from draft model for the first token, and
>>> # token 1 was sampled from draft model for the second token
>>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0)
>>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0)
>>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\
... flashinfer.sampling.chain_speculative_sampling(
... draft_probs, draft_token_ids, target_probs)
>>> # the first token is accepted, the second token is rejected and sampled from the difference
>>> # between the target model and the draft model, the third token is padded with -1
>>> output_token_ids
tensor([[ 2, 0, -1]], device='cuda:0', dtype=torch.int32)
>>> output_accepted_token_num
tensor([1], device='cuda:0')
>>> output_emitted_draft_token_num
tensor([1], device='cuda:0')
"""
b = draft_probs.size(0)
dev = draft_probs.device
if maybe_output_accepted_token_num is None:
output_accepted_token_num = torch.zeros(b, dtype=torch.int32, device=dev)
else:
output_accepted_token_num = maybe_output_accepted_token_num
if maybe_output_emitted_draft_token_num is None:
output_emitted_draft_token_num = torch.zeros(b, dtype=torch.int32, device=dev)
else:
output_emitted_draft_token_num = maybe_output_emitted_draft_token_num
output_token_ids = get_sampling_module().chain_speculative_sampling(
draft_probs,
draft_token_ids,
target_probs,
output_accepted_token_num,
output_emitted_draft_token_num,
deterministic,
generator,
)
return output_token_ids, output_accepted_token_num, output_emitted_draft_token_num