""" Copyright (c) 2024 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import functools from types import SimpleNamespace from typing import Optional, Union import torch from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec from .utils import ( _get_cache_buf, device_support_pdl, register_custom_op, register_fake_op, ) def gen_sampling_module() -> JitSpec: return gen_jit_spec( "sampling", [ jit_env.FLASHINFER_CSRC_DIR / "sampling.cu", jit_env.FLASHINFER_CSRC_DIR / "renorm.cu", jit_env.FLASHINFER_CSRC_DIR / "flashinfer_sampling_ops.cu", ], ) @functools.cache def get_sampling_module(): module = gen_sampling_module().build_and_load() @register_custom_op("flashinfer::softmax", mutates_args=("workspace_buffer",)) def softmax( workspace_buffer: torch.Tensor, logits: torch.Tensor, maybe_temperature_arr: Optional[torch.Tensor], temperature_val: float, enable_pdl: bool, ) -> torch.Tensor: logits = logits.float() probs = torch.empty_like(logits, device=logits.device) maybe_temperature_arr = ( maybe_temperature_arr.float() if maybe_temperature_arr is not None else None ) module.softmax.default( workspace_buffer, logits, probs, maybe_temperature_arr, temperature_val, enable_pdl, ) return probs @register_fake_op("flashinfer::softmax") def _fake_softmax( workspace_buffer: torch.Tensor, logits: torch.Tensor, maybe_temperature_arr: Optional[torch.Tensor], temperature_val: float, enable_pdl: bool, ) -> torch.Tensor: return torch.empty_like(logits, device=logits.device, dtype=torch.float32) # torch library for sampling_from_logits @register_custom_op("flashinfer::sampling_from_logits", mutates_args=()) def sampling_from_logits( logits: torch.Tensor, indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: device = logits.device # TODO: support more data types in logits to avoid conversion # to float32 logits = logits.float() batch_size = indices.size(0) if indices is not None else logits.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) module.sampling_from_logits.default( logits, samples, indices, deterministic, generator, ) return samples @register_fake_op("flashinfer::sampling_from_logits") def _fake_sampling_from_logits( logits: torch.Tensor, indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: batch_size = indices.size(0) if indices is not None else logits.size(0) return torch.empty(batch_size, dtype=torch.int32, device=logits.device) # torch library for sampling_from_probs @register_custom_op("flashinfer::sampling_from_probs", mutates_args=()) def sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: device = probs.device probs = probs.float() batch_size = indices.size(0) if indices is not None else probs.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) module.sampling_from_probs.default( probs, samples, indices, 
deterministic, generator, ) return samples # torch library for sampling_from_probs @register_fake_op("flashinfer::sampling_from_probs") def _fake_sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: batch_size = indices.size(0) if indices is not None else probs.size(0) return torch.empty(batch_size, dtype=torch.int32, device=probs.device) # torch library for top_p_sampling_from_probs @register_custom_op("flashinfer::top_p_sampling_from_probs", mutates_args=()) def top_p_sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: device = probs.device probs = probs.float() maybe_top_p_arr = ( maybe_top_p_arr.float() if maybe_top_p_arr is not None else None ) batch_size = indices.size(0) if indices is not None else probs.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) module.top_p_sampling_from_probs.default( probs, samples, indices, maybe_top_p_arr, top_p_val, deterministic, generator, ) return samples @register_fake_op("flashinfer::top_p_sampling_from_probs") def _fake_top_p_sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: sample = torch.empty(probs.size(0), dtype=torch.int32, device=probs.device) return sample # torch library for top_k_sampling_from_probs @register_custom_op("flashinfer::top_k_sampling_from_probs", mutates_args=()) def top_k_sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: device = probs.device probs = probs.float() batch_size = indices.size(0) if indices is not None else probs.size(0) maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None samples = torch.empty(batch_size, dtype=torch.int32, device=device) module.top_k_sampling_from_probs.default( probs, samples, indices, maybe_top_k_arr, top_k_val, deterministic, generator, ) return samples @register_fake_op("flashinfer::top_k_sampling_from_probs") def _fake_top_k_sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: batch_size = indices.size(0) if indices is not None else probs.size(0) sample = torch.empty(batch_size, dtype=torch.int32, device=probs.device) return sample # torch library for min_p_sampling_from_probs @register_custom_op("flashinfer::min_p_sampling_from_probs", mutates_args=()) def min_p_sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], maybe_min_p_arr: Optional[torch.Tensor], min_p_val: float, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: device = probs.device probs = probs.float() maybe_min_p_arr = ( maybe_min_p_arr.float() if maybe_min_p_arr is not None else None ) batch_size = indices.size(0) if indices is not None else probs.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) module.min_p_sampling_from_probs.default( probs, samples, indices, maybe_min_p_arr, min_p_val, deterministic, generator, ) return samples # torch library for 
top_k_top_p_sampling_from_probs @register_custom_op("flashinfer::top_k_top_p_sampling_from_probs", mutates_args=()) def top_k_top_p_sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: device = probs.device probs = probs.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None maybe_top_p_arr = ( maybe_top_p_arr.float() if maybe_top_p_arr is not None else None ) batch_size = indices.size(0) if indices is not None else probs.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) module.top_k_top_p_sampling_from_probs.default( probs, samples, indices, maybe_top_k_arr, top_k_val, maybe_top_p_arr, top_p_val, deterministic, generator, ) return samples @register_fake_op("flashinfer::top_k_top_p_sampling_from_probs") def _fake_top_k_top_p_sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor], maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: batch_size = indices.size(0) if indices is not None else probs.size(0) sample = torch.empty(batch_size, dtype=torch.int32, device=probs.device) return sample # torch library for top_p_renorm_probs @register_custom_op("flashinfer::top_p_renorm_probs", mutates_args=()) def top_p_renorm_probs( probs: torch.Tensor, maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, ) -> torch.Tensor: probs = probs.float() maybe_top_p_arr = ( maybe_top_p_arr.float() if maybe_top_p_arr is not None else None ) renorm_probs = torch.empty_like(probs) module.top_p_renorm_probs.default( probs, renorm_probs, maybe_top_p_arr, top_p_val, ) return renorm_probs @register_fake_op("flashinfer::top_p_renorm_probs") def _fake_top_p_renorm_probs( probs: torch.Tensor, maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, ) -> torch.Tensor: return torch.empty_like(probs) # torch library for top_k_renorm_probs @register_custom_op("flashinfer::top_k_renorm_probs", mutates_args=()) def top_k_renorm_probs( probs: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, ) -> torch.Tensor: probs = probs.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None renorm_probs = torch.empty_like(probs) module.top_k_renorm_probs.default( probs, renorm_probs, maybe_top_k_arr, top_k_val, ) return renorm_probs @register_fake_op("flashinfer::top_k_renorm_probs") def _fake_top_k_renorm_probs( probs: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, ) -> torch.Tensor: return torch.empty_like(probs) # torch library for top_k_mask_logits @register_custom_op("flashinfer::top_k_mask_logits", mutates_args=()) def top_k_mask_logits( logits: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, ) -> torch.Tensor: logits = logits.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None mask_logits = torch.empty_like(logits) module.top_k_mask_logits.default( logits, mask_logits, maybe_top_k_arr, top_k_val, ) return mask_logits @register_fake_op("flashinfer::top_k_mask_logits") def _fake_top_k_mask_logits( logits: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, ) -> torch.Tensor: return torch.empty_like(logits) # torch library for 
chain_speculative_sampling @register_custom_op( "flashinfer::chain_speculative_sampling", mutates_args=( "output_accepted_token_num", "output_emitted_draft_token_num", ), ) def chain_speculative_sampling( draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, target_probs: torch.Tensor, output_accepted_token_num: torch.Tensor, output_emitted_draft_token_num: torch.Tensor, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: device = draft_probs.device draft_probs = draft_probs.float() draft_token_ids = draft_token_ids.int() target_probs = target_probs.float() output_accepted_token_num = output_accepted_token_num.int() output_emitted_draft_token_num = output_emitted_draft_token_num.int() b, n = draft_token_ids.shape output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device) module.chain_speculative_sampling.default( draft_probs, draft_token_ids, target_probs, output_token_ids, output_accepted_token_num, output_emitted_draft_token_num, deterministic, generator, ) return output_token_ids @register_fake_op("flashinfer::chain_speculative_sampling") def _fake_chain_speculative_sampling( draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, target_probs: torch.Tensor, output_accepted_token_num: torch.Tensor, output_emitted_draft_token_num: torch.Tensor, deterministic: bool, generator: Optional[torch.Generator], ) -> torch.Tensor: b, n = draft_token_ids.shape device = draft_token_ids.device return torch.empty((b, n + 1), dtype=torch.int32, device=device) # Register the module return SimpleNamespace( softmax=softmax, sampling_from_probs=sampling_from_probs, sampling_from_logits=sampling_from_logits, top_p_sampling_from_probs=top_p_sampling_from_probs, top_k_sampling_from_probs=top_k_sampling_from_probs, min_p_sampling_from_probs=min_p_sampling_from_probs, top_k_top_p_sampling_from_probs=top_k_top_p_sampling_from_probs, top_p_renorm_probs=top_p_renorm_probs, top_k_renorm_probs=top_k_renorm_probs, top_k_mask_logits=top_k_mask_logits, chain_speculative_sampling=chain_speculative_sampling, ) def _to_tensor_scalar_tuple(x): if isinstance(x, torch.Tensor): return (x, 0) else: return (None, x) def softmax( logits: torch.Tensor, temperature: Optional[Union[torch.Tensor, float]] = None, enable_pdl: Optional[bool] = None, ) -> torch.Tensor: r"""Fused GPU kernel for `online safe softmax `_ with temperature scaling. Parameters ---------- logits : torch.Tensor Input tensor of logits. temperature: Optional[Union[torch.Tensor, float]] Either a scalar or a tensor of shape ``(batch_size,)``, representing the temperature for temperature scaling. If a scalar, the same temperature is used for all requests. If a tensor, each request has its own temperature. enable_pdl : Optional[bool] Whether to enable Programmatic Dependent Launch (PDL) for improved performance on supported hardware. If None (default), PDL will be automatically enabled on devices with compute capability >= 9.0. Returns ------- probs : torch.Tensor Tensor of the same shape as input containing the softmax probabilities. 
Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> logits = torch.rand(batch_size, vocab_size).to(0) >>> logits tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904], [0.6009, 0.2566, 0.7936, 0.9408, 0.1332], [0.9346, 0.5936, 0.8694, 0.5677, 0.7411], [0.4294, 0.8854, 0.5739, 0.2666, 0.6274]], device='cuda:0') >>> probs = flashinfer.sampling.softmax(logits, temperature=1.0) >>> probs tensor([[0.2309, 0.2385, 0.1401, 0.2493, 0.1412], [0.2019, 0.1431, 0.2448, 0.2837, 0.1265], [0.2401, 0.1707, 0.2249, 0.1664, 0.1979], [0.1724, 0.2719, 0.1991, 0.1465, 0.2101]], device='cuda:0') """ workspace_buffer = _get_cache_buf("softmax_workspace", 1024 * 1024, logits.device) if temperature is None: temperature = 1.0 # Auto-detect PDL support if not specified if enable_pdl is None: enable_pdl = device_support_pdl(logits.device) return get_sampling_module().softmax( workspace_buffer, logits, *_to_tensor_scalar_tuple(temperature), enable_pdl ) def sampling_from_logits( logits: torch.Tensor, indices: Optional[torch.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> torch.Tensor: r"""Fused GPU kernel for category sampling from logits. It's equivalent to sampling from :attr:`logits` after applying softmax. Parameters ---------- logits: torch.Tensor Logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` and the i-th output will be sampled from the i-th row of logits. When indices is provided, shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. indices: Optional[torch.Tensor] Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in logits. For example, if indices[i] = j, then the i-th output will be sampled from logits[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of logits. deterministic: bool Since the sampling doesn't use cub's BlockScan, the sampling is deterministic. We keep this argument for compatibility with other sampling functions. generator: Optional[torch.Generator] A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`logits`, default is ``False``. Returns ------- samples: torch.Tensor Sampled categories, shape (batch_size,). It's equivalent to sampling from :attr:`logits` after applying softmax. Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> logits = torch.rand(batch_size, vocab_size).to(0) >>> logits tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904], [0.6009, 0.2566, 0.7936, 0.9408, 0.1332], [0.9346, 0.5936, 0.8694, 0.5677, 0.7411], [0.4294, 0.8854, 0.5739, 0.2666, 0.6274]], device='cuda:0') >>> samples = flashinfer.sampling.sampling_from_logits(logits) >>> samples tensor([0, 1, 1, 1], device='cuda:0', dtype=torch.int32) """ if check_nan: if torch.any(torch.isnan(logits)): raise ValueError("Input logits contains NaN.") return get_sampling_module().sampling_from_logits( logits, indices, deterministic, generator ) def sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> torch.Tensor: r"""Fused GPU kernel for category sampling from probabilities. 
Parameters ---------- probs: torch.Tensor Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. indices: Optional[torch.Tensor] Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- samples: torch.Tensor Sampled categories, shape (batch_size,). Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) >>> norm_prob tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106], [0.2205, 0.0942, 0.2912, 0.3452, 0.0489], [0.2522, 0.1602, 0.2346, 0.1532, 0.2000], [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0') >>> samples = flashinfer.sampling.sampling_from_probs(norm_prob) >>> samples tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32) Note ---- This function expects float32 inputs, and the output is int32. """ if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().sampling_from_probs( probs, indices, deterministic, generator ) def top_p_sampling_from_probs( probs: torch.Tensor, top_p: Union[torch.Tensor, float], indices: Optional[torch.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> torch.Tensor: r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. The multiple rounds of rejection sampling are implemented in a single CUDA kernel, which is more efficient than the naive implementation that launches a series of kernels. Parameters ---------- probs: torch.Tensor Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. top_p: Union[torch.Tensor, float] Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. 
deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- samples: torch.Tensor Sampled categories, shape ``(batch_size,)``. Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> top_p = 0.5 >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) >>> norm_prob tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106], [0.2205, 0.0942, 0.2912, 0.3452, 0.0489], [0.2522, 0.1602, 0.2346, 0.1532, 0.2000], [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0') >>> samples = flashinfer.sampling.top_p_sampling_from_probs(norm_prob, top_p) >>> samples tensor([1, 2, 0, 4], device='cuda:0', dtype=torch.int32) Note ---- This function expects float32 inputs, and the output is int32. See Also -------- top_k_top_p_sampling_from_probs top_k_sampling_from_probs top_p_renorm_probs """ if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().top_p_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator ) def top_k_sampling_from_probs( probs: torch.Tensor, top_k: Union[torch.Tensor, int], indices: Optional[torch.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> torch.Tensor: r"""Fused GPU kernel for top-k sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. The multiple rounds of rejection sampling are implemented in a single CUDA kernel, which is more efficient than the naive implementation that launches a series of kernels. Parameters ---------- probs: torch.Tensor Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. top_k: Union[torch.Tensor, int] Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- samples: torch.Tensor Sampled categories, shape ``(batch_size,)``. 
Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> top_k = 1 >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) >>> norm_prob tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106], [0.2205, 0.0942, 0.2912, 0.3452, 0.0489], [0.2522, 0.1602, 0.2346, 0.1532, 0.2000], [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0') >>> samples = flashinfer.sampling.top_k_sampling_from_probs(norm_prob, top_k) >>> samples tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32) Note ---- This function expects float32 inputs, and the output is int32. See Also -------- top_k_top_p_sampling_from_probs top_p_sampling_from_probs top_k_renorm_probs """ if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().top_k_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator ) def min_p_sampling_from_probs( probs: torch.Tensor, min_p: Union[torch.Tensor, float], indices: Optional[torch.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> torch.Tensor: r"""Fused GPU kernel for `min_p sampling `_ from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. The multiple rounds of rejection sampling are implemented in a single CUDA kernel, which is more efficient than the naive implementation that launches a series of kernels. Parameters ---------- probs: torch.Tensor Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. min_p: Union[torch.Tensor, float] Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- samples: torch.Tensor Sampled categories, shape ``(batch_size,)``. 
Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> min_p = torch.full((batch_size,), 0.05).to(0) >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) >>> norm_prob tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106], [0.2205, 0.0942, 0.2912, 0.3452, 0.0489], [0.2522, 0.1602, 0.2346, 0.1532, 0.2000], [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0') >>> samples = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, min_p) >>> samples tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32) Note ---- This function expects float32 inputs, and the output is int32. """ if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().min_p_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator ) def top_k_top_p_sampling_from_logits( logits: torch.Tensor, top_k: Union[torch.Tensor, int], top_p: Union[torch.Tensor, float], indices: Optional[torch.Tensor] = None, filter_apply_order: str = "top_k_first", deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> torch.Tensor: r"""Fused GPU kernel for top-k and top-p sampling from pre-softmax logits, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. The multiple rounds of rejection sampling are implemented in a single CUDA kernel, which is more efficient than the naive implementation that launches a series of kernels. Parameters ---------- logits: torch.Tensor Pre-softmax logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` and the i-th output will be sampled from the i-th row of logits. When indices is provided, shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. top_k: Union[torch.Tensor, int] Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. top_p: Union[torch.Tensor, float] Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. filter_apply_order: str The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``. If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results. If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- samples: torch.Tensor Sampled categories, shape ``(batch_size,)``. 
Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> top_p = 0.5 >>> top_k = 3 >>> logits = torch.randn(batch_size, vocab_size).to(0) >>> logits tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581], [ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866], [-0.4934, 0.2415, -0.2316, 0.0418, -0.2516], [ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]], device='cuda:0') >>> samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(logits, top_k, top_p) >>> samples tensor([0, 2, 1, 3], device='cuda:0', dtype=torch.int32) >>> probs = torch.softmax(logits, dim=-1) >>> probs tensor([[0.4788, 0.3085, 0.1716, 0.0085, 0.0327], [0.2358, 0.1787, 0.4307, 0.1145, 0.0404], [0.1358, 0.2831, 0.1764, 0.2318, 0.1729], [0.3613, 0.1122, 0.1029, 0.3415, 0.0821]], device='cuda:0') >>> samples tensor([0, 2, 1, 3], device='cuda:0', dtype=torch.int32) Note ---- This function expects float32 inputs, and the output is int32. See Also -------- top_k_top_p_sampling_from_probs top_k_mask_logits top_p_sampling_from_probs """ if filter_apply_order == "top_k_first": masked_logits = top_k_mask_logits(logits, top_k) probs = torch.softmax(masked_logits, dim=-1) return top_p_sampling_from_probs( probs, top_p, indices, deterministic, check_nan=check_nan, generator=generator, ) elif filter_apply_order == "joint": probs = torch.softmax(logits, dim=-1) if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().top_k_top_p_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_k), *_to_tensor_scalar_tuple(top_p), deterministic, generator, ) else: raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") def top_k_top_p_sampling_from_probs( probs: torch.Tensor, top_k: Union[torch.Tensor, int], top_p: Union[torch.Tensor, float], indices: Optional[torch.Tensor] = None, filter_apply_order: str = "top_k_first", deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> torch.Tensor: r"""Fused GPU kernel for top-k and top-p sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. The multiple rounds of rejection sampling are implemented in a single CUDA kernel, which is more efficient than the naive implementation that launches a series of kernels. Parameters ---------- probs: torch.Tensor Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. top_k: Union[torch.Tensor, int] Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. top_p: Union[torch.Tensor, float] Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. filter_apply_order: str The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``. If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results. If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- samples: torch.Tensor Sampled categories, shape ``(batch_size,)``. Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> top_p = torch.full((batch_size,), 0.2).to(0) >>> top_k = torch.full((batch_size,), 2).to(0) >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) >>> norm_prob tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106], [0.2205, 0.0942, 0.2912, 0.3452, 0.0489], [0.2522, 0.1602, 0.2346, 0.1532, 0.2000], [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0') >>> samples = flashinfer.sampling.top_k_top_p_sampling_from_probs(norm_prob, top_k, top_p) >>> samples tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32) Note ---- This function expects float32 inputs, and the output is int32. See Also -------- top_k_sampling_from_probs top_p_sampling_from_probs top_k_renorm_probs top_p_renorm_probs top_k_mask_logits """ if filter_apply_order == "top_k_first": renorm_probs = top_k_renorm_probs(probs, top_k) return top_p_sampling_from_probs( renorm_probs, top_p, indices, deterministic, check_nan=check_nan, generator=generator, ) elif filter_apply_order == "joint": if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().top_k_top_p_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_k), *_to_tensor_scalar_tuple(top_p), deterministic, generator, ) else: raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") def top_p_renorm_probs( probs: torch.Tensor, top_p: Union[torch.Tensor, float], ) -> torch.Tensor: r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding. Parameters ---------- probs: torch.Tensor Probabilities, shape ``(batch_size, num_classes)``. top_p: Union[torch.Tensor, float] Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold for re-normalizing probabilities, should be in ``(0, 1)``. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. We mask out the probabilities less than `threshold` where the cumulative sum of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities. Returns ------- renorm_probs: torch.Tensor Renormalized probabilities, shape ``(batch_size, num_classes)``.
Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> top_p = 0.3 >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) >>> prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) >>> prob tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106], [0.2205, 0.0942, 0.2912, 0.3452, 0.0489], [0.2522, 0.1602, 0.2346, 0.1532, 0.2000], [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0') >>> renormed_probs = flashinfer.sampling.top_p_renorm_probs(prob, top_p) >>> renormed_probs tensor([[0.0000, 0.4882, 0.0000, 0.5118, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000, 0.0000], [0.5181, 0.0000, 0.4819, 0.0000, 0.0000], [0.0000, 1.0000, 0.0000, 0.0000, 0.0000]], device='cuda:0') Note ---- The combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to ``top_p_sampling_from_probs``. See Also -------- top_p_sampling_from_probs sampling_from_probs top_k_renorm_probs """ return get_sampling_module().top_p_renorm_probs( probs, *_to_tensor_scalar_tuple(top_p) ) top_p_renorm_prob = top_p_renorm_probs def top_k_renorm_probs( probs: torch.Tensor, top_k: Union[torch.Tensor, int], ) -> torch.Tensor: r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding. Parameters ---------- probs: torch.Tensor Probabilities, shape ``(batch_size, num_classes)``. top_k: Union[torch.Tensor, int] Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for re-normalizing probabilities, should be in ``(0, num_classes)``. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities. Returns ------- renorm_probs: torch.Tensor Renormalized probabilities, shape ``(batch_size, num_classes)``. Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> top_k = 3 >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) >>> prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) >>> prob tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106], [0.2205, 0.0942, 0.2912, 0.3452, 0.0489], [0.2522, 0.1602, 0.2346, 0.1532, 0.2000], [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0') >>> renormed_probs = flashinfer.sampling.top_k_renorm_probs(prob, top_k) >>> renormed_probs tensor([[0.3201, 0.3319, 0.0000, 0.3480, 0.0000], [0.2573, 0.0000, 0.3398, 0.4028, 0.0000], [0.3672, 0.0000, 0.3416, 0.0000, 0.2912], [0.0000, 0.4243, 0.2750, 0.0000, 0.3007]], device='cuda:0') Note ---- The combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to ``top_k_sampling_from_probs``. See Also -------- top_k_sampling_from_probs sampling_from_probs top_p_renorm_probs """ return get_sampling_module().top_k_renorm_probs( probs, *_to_tensor_scalar_tuple(top_k) ) top_k_renorm_prob = top_k_renorm_probs def top_k_mask_logits( logits: torch.Tensor, top_k: Union[torch.Tensor, int] ) -> torch.Tensor: r"""Fused GPU kernel for masking logits by top-k thresholding. Parameters ---------- logits: torch.Tensor Logits before softmax, shape ``(batch_size, num_classes)``. top_k: Union[torch.Tensor, int] Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for masking logits, should be in ``(0, num_classes)``. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold.
We keep the top-k logits and set the rest to negative infinity. Returns ------- masked_logits: torch.Tensor Masked logits, shape ``(batch_size, num_classes)``. Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 4 >>> vocab_size = 5 >>> top_k = 3 >>> logits = torch.randn(batch_size, vocab_size).to(0) >>> logits tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581], [ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866], [-0.4934, 0.2415, -0.2316, 0.0418, -0.2516], [ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]], device='cuda:0') >>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k) >>> masked_logits tensor([[ 1.9269, 1.4873, 0.9007, -inf, -inf], [ 1.0783, 0.8008, 1.6806, -inf, -inf], [ -inf, 0.2415, -0.2316, 0.0418, -inf], [ 0.8599, -0.3097, -inf, 0.8034, -inf]], device='cuda:0') Note ---- The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``. See Also -------- top_k_renorm_probs """ return get_sampling_module().top_k_mask_logits( logits, *_to_tensor_scalar_tuple(top_k) ) def chain_speculative_sampling( draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, target_probs: torch.Tensor, maybe_output_accepted_token_num: Optional[torch.Tensor] = None, maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]": r"""Fused GPU kernel for speculative sampling for sequence generation (proposed in the paper `Accelerating Large Language Model Decoding with Speculative Sampling `_), where the draft model generates a sequence (chain) of tokens for each request. Parameters ---------- draft_probs: torch.Tensor The probability over the vocabulary generated by the draft model. Shape: ``(batch_size, num_speculate_tokens, vocab_size)`` draft_token_ids: torch.Tensor The draft model's generated token indices. Shape: ``(batch_size, num_speculate_tokens)`` target_probs: torch.Tensor The probability over the vocabulary generated by the target model. Compared to input :attr:`draft_probs`, the target model's probability has an additional slot at the end because the target model will generate one more token than the draft model. Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)`` maybe_output_accepted_token_num: Optional[torch.Tensor] The number of tokens that can be accepted if each token is considered independently for each request. This metric does not consider the fact that rejection sampling will stop at the first token that does not satisfy the probability requirement r < p/q. It only evaluates the alignment of the draft model and the target model. Shape: ``(batch_size)`` If specified, the number of accepted tokens will be added to this tensor in place. Default is ``None``. maybe_output_emitted_draft_token_num: Optional[torch.Tensor] The number of draft tokens that are finally emitted for each request. Does not include the bonus token. (Thus the total number of tokens sampled for a given request is output_emitted_draft_token_num + 1). Shape: ``(batch_size)`` If specified, the number of emitted draft tokens will be added to this tensor in place. Default is ``None``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. Returns ------- output_token_ids: torch.Tensor The output token indices verified by the target model, rejected samples are padded with ``-1``.
Compared to input :attr:`draft_token_ids`, the output tensor has an additional token index at the end for the final token, if all previous tokens are accepted, another "bonus" token will be sampled from the target model's probability. Shape: (batch_size, num_speculate_tokens + 1) output_accepted_token_num: torch.Tensor The number of tokens that can be accepted if each token is considered independently for each request. This metric does not consider the fact that rejection sampling will stop at the first token that does not satisfy the probability requirement r < p/q. It only evaluates the alignment of draft model and target model. Shape: ``(batch_size)`` output_emitted_draft_token_num: torch.Tensor The number of draft tokens that are finally emitted for each request. Does not include the bonus token. (Thus the total number of tokens sampled for a given request is output_emitted_draft_token_num + 1). Shape: ``(batch_size)`` Examples -------- >>> import torch >>> import flashinfer >>> torch.manual_seed(42) >>> batch_size = 1 >>> num_speculate_tokens = 2 >>> vocab_size = 4 >>> draft_probs = torch.tensor([[[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.1]]]).to(0) >>> # token 2 was sampled from draft model for the first token, and >>> # token 1 was sampled from draft model for the second token >>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0) >>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0) >>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\ ... flashinfer.sampling.chain_speculative_sampling( ... draft_probs, draft_token_ids, target_probs) >>> # the first token is accepted, the second token is rejected and sampled from the difference >>> # between the target model and the draft model, the third token is padded with -1 >>> output_token_ids tensor([[ 2, 0, -1]], device='cuda:0', dtype=torch.int32) >>> output_accepted_token_num tensor([1], device='cuda:0') >>> output_emitted_draft_token_num tensor([1], device='cuda:0') """ b = draft_probs.size(0) dev = draft_probs.device if maybe_output_accepted_token_num is None: output_accepted_token_num = torch.zeros(b, dtype=torch.int32, device=dev) else: output_accepted_token_num = maybe_output_accepted_token_num if maybe_output_emitted_draft_token_num is None: output_emitted_draft_token_num = torch.zeros(b, dtype=torch.int32, device=dev) else: output_emitted_draft_token_num = maybe_output_emitted_draft_token_num output_token_ids = get_sampling_module().chain_speculative_sampling( draft_probs, draft_token_ids, target_probs, output_accepted_token_num, output_emitted_draft_token_num, deterministic, generator, ) return output_token_ids, output_accepted_token_num, output_emitted_draft_token_num
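

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the public API). The ``indices``
# argument and the per-request threshold tensors documented above have no
# doctest of their own, so the block below shows one possible way to use
# them. It assumes a CUDA device is available; all sizes, dtypes (int32
# indices), and values are made up for demonstration only.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        device = torch.device("cuda:0")
        unique_batch_size, batch_size, vocab_size = 2, 6, 8

        # Two unique distributions shared by six outputs: indices[i] selects
        # the row of `probs` that the i-th output is sampled from.
        probs = torch.softmax(
            torch.randn(unique_batch_size, vocab_size, device=device), dim=-1
        )
        indices = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device)
        samples = top_p_sampling_from_probs(probs, 0.9, indices=indices)
        print("samples with shared distributions:", samples)

        # Per-request thresholds: one top_k / top_p value per output row.
        logits = torch.randn(batch_size, vocab_size, device=device)
        top_k = torch.full((batch_size,), 3, dtype=torch.int32, device=device)
        top_p = torch.full((batch_size,), 0.8, dtype=torch.float32, device=device)
        samples = top_k_top_p_sampling_from_logits(logits, top_k, top_p)
        print("samples with per-request thresholds:", samples)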
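

# ---------------------------------------------------------------------------
# Illustrative consistency check (not part of the public API). A minimal
# sketch, assuming a CUDA device, of the equivalences stated in the notes
# above: ``top_k_mask_logits`` followed by softmax should match
# ``top_k_renorm_probs``, and ``top_p_renorm_probs`` followed by
# ``sampling_from_probs`` samples from the same truncated distribution as
# the fused ``top_p_sampling_from_probs``.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        device = torch.device("cuda:0")
        logits = torch.randn(4, 16, device=device)
        probs = torch.softmax(logits, dim=-1)

        # Top-k: mask logits then re-normalize vs. renormalize probabilities.
        masked = top_k_mask_logits(logits, 5)
        renormed_k = top_k_renorm_probs(probs, 5)
        print(
            "top-k mask + softmax matches top-k renorm:",
            torch.allclose(torch.softmax(masked, dim=-1), renormed_k, atol=1e-5),
        )

        # Top-p: renormalize first, then sample from the truncated distribution.
        renormed_p = top_p_renorm_probs(probs, 0.8)
        print("top-p renorm + sample:", sampling_from_probs(renormed_p))
        print("fused top-p sampling: ", top_p_sampling_from_probs(probs, 0.8))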