""" Copyright (c) 2023 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import functools import logging import math from types import SimpleNamespace from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload import torch from .jit import ( gen_batch_prefill_module, gen_customize_batch_prefill_module, gen_fmha_cutlass_sm100a_module, gen_single_prefill_module, get_batch_prefill_uri, get_single_prefill_uri, setup_cubin_loader, gen_trtllm_gen_fmha_module, ) from .cudnn import cudnn_batch_prefill_with_kv_cache from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens from .quantization import packbits, segment_packbits from .utils import ( FP4Tensor, MaskMode, PosEncodingMode, TensorLayout, _check_cached_qkv_data_type, _check_kv_layout, _check_pos_encoding_mode, check_shape_dtype_device, _get_cache_alibi_slopes_buf, _get_cache_buf, _unpack_paged_kv_cache, canonicalize_torch_dtype, determine_attention_backend, device_support_pdl, get_device_sm_count, is_float8, is_sm100a_supported, is_sm110a_supported, register_custom_op, register_fake_op, ceil_div, round_up, ) @functools.cache def get_fmha_module( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, dtype_idx: torch.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, device: torch.device, use_fp16_qk_reduction: bool = False, ): if is_sm100a_supported(device) or is_sm110a_supported(device): return 
gen_fmha_cutlass_sm100a_module( dtype_q, dtype_kv, dtype_o, dtype_idx, head_dim_qk, head_dim_vo, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, ).build_and_load() else: raise ValueError("SM100A is not supported on this device") def make_hashable_cache(func): """ Decorator that converts unhashable arguments (like lists) to hashable ones (tuples) before applying functools.cache. """ @functools.cache def cached_wrapper(*args, **kwargs): return func(*args, **kwargs) @functools.wraps(func) def wrapper(*args, **kwargs): # Convert unhashable arguments to hashable ones hashable_args = [] for arg in args: if isinstance(arg, list): hashable_args.append(tuple(arg)) else: hashable_args.append(arg) hashable_kwargs = {} for key, value in kwargs.items(): if isinstance(value, list): hashable_kwargs[key] = tuple(value) else: hashable_kwargs[key] = value return cached_wrapper(*hashable_args, **hashable_kwargs) return wrapper @make_hashable_cache def get_customize_batch_prefill_module( backend: str, uri: str, dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, idtype: torch.dtype, head_dim_qk: int, head_dim_vo: int, additional_tensor_names: List[str], additional_tensor_dtypes: List[str], additional_scalar_names: List[str], additional_scalar_dtypes: List[str], variant_name: str, variant_decl: str, pos_encoding_mode: int = 0, use_sliding_window: bool = False, use_logits_soft_cap: bool = False, use_fp16_qk_reduction: bool = False, fp8_enabled: bool = False, ): return gen_customize_batch_prefill_module( backend, uri, dtype_q, dtype_kv, dtype_o, idtype, head_dim_qk, head_dim_vo, additional_tensor_names, additional_tensor_dtypes, additional_scalar_names, additional_scalar_dtypes, variant_name, variant_decl, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, use_fp16_qk_reduction, fp8_enabled, ).build_and_load() @functools.cache def get_trtllm_gen_prefill_module(): mod = gen_trtllm_gen_fmha_module() op = mod.build_and_load() 
setup_cubin_loader(mod.get_library_path()) def _paged_run( query: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float, bmm2_scale: float, batch_size: int, cum_seq_lens_q: torch.Tensor, cum_seq_lens_kv: torch.Tensor, enable_pdl: bool, workspace_size: int, window_left: int = -1, out: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: sm_count = get_device_sm_count(query.device) if out is None: out = torch.empty_like(query) op.trtllm_paged_attention_context( out, None, # fp4 output not supported in wrapper api yet. query, k_cache, v_cache, workspace_buffer, block_tables, seq_lens, max_q_len, max_kv_len, bmm1_scale, bmm2_scale, -1, # o_sf_scale -1, # o_sf_vec_size 0, # o_sf_start_index batch_size, window_left, cum_seq_lens_q, cum_seq_lens_kv, sm_count, enable_pdl, workspace_size, sinks, ) return out def _ragged_run(*args, **kwargs): # TODO(Zihao): trtllm-gen backend already supports variable length attention, # but not integrated into flashinfer yet. raise NotImplementedError( "Variable length is not implemented for trtllm-gen backend yet." 
) def _plan(*args, **kwargs): pass return SimpleNamespace( paged_run=_paged_run, ragged_run=_ragged_run, plan=_plan, ) @functools.cache def get_single_prefill_module(backend, *args): uri = get_single_prefill_uri(backend, *args) module = gen_single_prefill_module(backend, *args).build_and_load() run_func = module.run.default # torch library for single_prefill_with_kv_cache @register_custom_op( f"flashinfer::{uri}_run", mutates_args=("tmp", "o", "maybe_lse") ) def run_single_prefill( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tmp: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, maybe_packed_custom_mask: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, scale_q: Optional[torch.Tensor], scale_k: Optional[torch.Tensor], scale_v: Optional[torch.Tensor], rope_scale: float, rope_theta: float, ) -> None: if backend == "fa3": if not is_float8(q): run_func( q, k, v, tmp, o, maybe_lse, mask_mode, layout, window_left, logits_soft_cap, sm_scale, ) else: # FP8 enabled run_func( q, k, v, tmp, o, maybe_lse, mask_mode, layout, window_left, scale_q, scale_k, scale_v, sm_scale, ) else: run_func( q, k, v, tmp, o, maybe_lse, mask_mode, layout, window_left, maybe_packed_custom_mask, maybe_alibi_slopes, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale 1.0 / rope_theta, # rope_rcp_theta ) return o @register_fake_op(f"flashinfer::{uri}_run") def _fake_run_single_prefill( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tmp: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, maybe_packed_custom_mask: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, rope_theta: float, ) -> None: pass # Register the module return SimpleNamespace(run=run_single_prefill) @functools.cache def get_batch_prefill_module(backend, 
*args): if backend == "trtllm-gen": uri = "trtllm_gen_context" module = get_trtllm_gen_prefill_module() plan_func = module.plan ragged_run_func = module.ragged_run paged_run_func = module.paged_run else: uri = get_batch_prefill_uri(backend, *args) module = gen_batch_prefill_module(backend, *args).build_and_load() plan_func = module.plan.default ragged_run_func = module.ragged_run.default paged_run_func = module.paged_run.default # torch library for ragged_run @register_custom_op( f"flashinfer::{uri}_ragged_run", mutates_args=( "float_workspace_buffer", "int_workspace_buffer", "o", "maybe_lse", ), ) def ragged_run( float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, plan_info_vec: List[int], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, maybe_custom_mask: Optional[torch.Tensor], maybe_mask_indptr: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], maybe_prefix_len_ptr: Optional[torch.Tensor], maybe_token_pos_in_items_ptr: Optional[torch.Tensor], maybe_max_item_len_ptr: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, rope_theta: float, token_pos_in_items_len: int, ) -> None: if backend == "fa2": ragged_run_func( float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v, qo_indptr, kv_indptr, o, maybe_lse, mask_mode, layout, window_left, enable_pdl, maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes, maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr, maybe_max_item_len_ptr, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale 1.0 / rope_theta, # rope_rcp_theta token_pos_in_items_len, ) else: ragged_run_func( float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v, qo_indptr, kv_indptr, o, maybe_lse, mask_mode, layout, window_left, enable_pdl, maybe_prefix_len_ptr, 
maybe_token_pos_in_items_ptr, maybe_max_item_len_ptr, logits_soft_cap, sm_scale, token_pos_in_items_len, ) return o @register_fake_op(f"flashinfer::{uri}_ragged_run") def _fake_ragged_run( float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, plan_info_vec: List[int], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, maybe_custom_mask: Optional[torch.Tensor], maybe_mask_indptr: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], maybe_prefix_len_ptr: Optional[torch.Tensor], maybe_token_pos_in_items_ptr: Optional[torch.Tensor], maybe_max_item_len_ptr: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, rope_theta: float, token_pos_in_items_len: int, ) -> None: pass # torch library for paged_run @register_custom_op( f"flashinfer::{uri}_paged_run", mutates_args=( "float_workspace_buffer", "int_workspace_buffer", "paged_k_cache", "paged_v_cache", "o", "maybe_lse", ), ) def paged_run( float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, plan_info_vec: List[int], q: torch.Tensor, paged_k_cache: torch.Tensor, paged_v_cache: torch.Tensor, qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, maybe_custom_mask: Optional[torch.Tensor], maybe_mask_indptr: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], maybe_prefix_len_ptr: Optional[torch.Tensor], maybe_token_pos_in_items_ptr: Optional[torch.Tensor], maybe_max_item_len_ptr: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, scale_q: Optional[torch.Tensor], scale_k: Optional[torch.Tensor], scale_v: Optional[torch.Tensor], rope_scale: float, rope_theta: 
float, token_pos_in_items_len: int, workspace_size: int, num_qo_heads: Optional[int] = None, num_kv_heads: Optional[int] = None, block_tables: Optional[torch.Tensor] = None, kv_lens_buffer: Optional[torch.Tensor] = None, page_size: Optional[int] = None, max_q_len: Optional[int] = None, max_kv_len: Optional[int] = None, batch_size: Optional[int] = None, cum_seq_lens_q: Optional[torch.Tensor] = None, cum_seq_lens_kv: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, ) -> None: if backend == "trtllm-gen": assert maybe_lse is None assert num_qo_heads is not None assert num_kv_heads is not None assert block_tables is not None assert kv_lens_buffer is not None assert page_size is not None assert max_kv_len is not None assert batch_size is not None assert cum_seq_lens_q is not None assert cum_seq_lens_kv is not None assert enable_pdl is not None assert workspace_size > 0, "workspace_size must be greater than 0" o = paged_run_func( q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect paged_k_cache, paged_v_cache, int_workspace_buffer, block_tables, kv_lens_buffer, max_q_len, max_kv_len, sm_scale, 1.0, # NOTE(Siyuan): update this to expose bmm2 scale batch_size, cum_seq_lens_q, cum_seq_lens_kv, enable_pdl, workspace_size, window_left, out=o, sinks=sinks, ) elif backend == "fa2": assert not is_float8(q) paged_run_func( float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache, paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, o, maybe_lse, mask_mode, layout, window_left, enable_pdl, maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes, maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr, maybe_max_item_len_ptr, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale 1.0 / rope_theta, # rope_rcp_theta token_pos_in_items_len, ) else: if not is_float8(q): paged_run_func( float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache, paged_v_cache, qo_indptr, 
paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, o, maybe_lse, mask_mode, layout, window_left, enable_pdl, maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr, maybe_max_item_len_ptr, logits_soft_cap, sm_scale, token_pos_in_items_len, ) else: paged_run_func( float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache, paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, o, maybe_lse, mask_mode, layout, window_left, enable_pdl, scale_q, scale_k, scale_v, sm_scale, ) return o @register_fake_op(f"flashinfer::{uri}_paged_run") def _fake_paged_run( float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, plan_info_vec: List[int], q: torch.Tensor, paged_k_cache: torch.Tensor, paged_v_cache: torch.Tensor, qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, maybe_custom_mask: Optional[torch.Tensor], maybe_mask_indptr: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], maybe_prefix_len_ptr: Optional[torch.Tensor], maybe_token_pos_in_items_ptr: Optional[torch.Tensor], maybe_max_item_len_ptr: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, rope_theta: float, token_pos_in_items_len: int, workspace_size: int, num_qo_heads: Optional[int] = None, num_kv_heads: Optional[int] = None, block_tables: Optional[torch.Tensor] = None, kv_lens_buffer: Optional[torch.Tensor] = None, page_size: Optional[int] = None, max_q_len: Optional[int] = None, max_kv_len: Optional[int] = None, batch_size: Optional[int] = None, cum_seq_lens_q: Optional[torch.Tensor] = None, cum_seq_lens_kv: Optional[torch.Tensor] = None, ) -> None: pass # Register the module. # # Note that plan is not part of model logic. It should not be included in # Cuda Graph or torch.compile. 
So, we don't provide a torch library for plan. return SimpleNamespace( plan=plan_func, ragged_run=ragged_run, paged_run=paged_run, ) @functools.cache def get_batch_prefill_jit_module(module_name: str, jit_module: Any): plan_func = jit_module.plan.default ragged_run_func = jit_module.ragged_run.default paged_run_func = jit_module.paged_run.default # torch library for ragged_run @register_custom_op( f"flashinfer::{module_name}_ragged_run", mutates_args=( "float_workspace_buffer", "int_workspace_buffer", "o", "maybe_lse", ), ) def ragged_run( float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, plan_info_vec: List[int], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, *args, ) -> None: ragged_run_func( float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v, qo_indptr, kv_indptr, o, maybe_lse, mask_mode, layout, window_left, *args, ) @register_fake_op(f"flashinfer::{module_name}_ragged_run") def _fake_ragged_run( float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, plan_info_vec: List[int], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, *args, ) -> None: pass # torch library for paged_run @register_custom_op( f"flashinfer::{module_name}_paged_run", mutates_args=( "float_workspace_buffer", "int_workspace_buffer", "paged_k_cache", "paged_v_cache", "o", "maybe_lse", ), ) def paged_run( float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, plan_info_vec: List[int], q: torch.Tensor, paged_k_cache: torch.Tensor, paged_v_cache: torch.Tensor, qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, o: torch.Tensor, maybe_lse: 
Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, *args, ) -> None: paged_run_func( float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache, paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, o, maybe_lse, mask_mode, layout, window_left, *args, ) @register_fake_op(f"flashinfer::{module_name}_paged_run") def _fake_paged_run( float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, plan_info_vec: List[int], q: torch.Tensor, paged_k_cache: torch.Tensor, paged_v_cache: torch.Tensor, qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode: int, layout: int, window_left: int, *args, ) -> None: pass # Register the module. # # Note that plan is not part of model logic. It should not be included in # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. 
return SimpleNamespace( plan=plan_func, ragged_run=ragged_run, paged_run=paged_run, ) def single_prefill_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, kv_layout: str = "NHD", mask_mode: int = MaskMode.NON_CAUSAL.value, window_left: int = -1, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: device = q.device tmp = _get_cache_buf( "single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, device=device ) o = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=device) lse = None if return_lse: lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=device) jit_module.run.default( q, k, v, tmp, o, lse, mask_mode, TensorLayout[kv_layout].value, window_left, *args, ) return (o, lse) if return_lse else o @overload def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale_q: Optional[torch.Tensor] = None, scale_k: Optional[torch.Tensor] = None, scale_v: Optional[torch.Tensor] = None, o_dtype: Optional[torch.dtype] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, window_left: int = -1, logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, backend: str = "auto", return_lse: Literal[False] = False, ) -> torch.Tensor: ... 
@overload def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale_q: Optional[torch.Tensor] = None, scale_k: Optional[torch.Tensor] = None, scale_v: Optional[torch.Tensor] = None, o_dtype: Optional[torch.dtype] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, window_left: int = -1, logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, backend: str = "auto", return_lse: Literal[True] = True, ) -> Tuple[torch.Tensor, torch.Tensor]: ... def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale_q: Optional[torch.Tensor] = None, scale_k: Optional[torch.Tensor] = None, scale_v: Optional[torch.Tensor] = None, o_dtype: Optional[torch.dtype] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, window_left: int = -1, logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, backend: str = "auto", return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Prefill/Append attention with KV cache for single request, return the attention output. Parameters ---------- q : torch.Tensor The query tensor, shape: ``[qo_len, num_qo_heads, head_dim_qk]``. k : torch.Tensor The key tensor, shape: ``[kv_len, num_kv_heads, head_dim_qk]`` if :attr:`kv_layout` is ``NHD``, or ``[num_kv_heads, kv_len, head_dim_qk]`` if :attr:`kv_layout` is ``HND``. 
v : torch.Tensor The key tensor, shape: ``[kv_len, num_kv_heads, head_dim_vo]`` if :attr:`kv_layout` is ``NHD``, ``[num_kv_heads, kv_len, head_dim_vo]`` if :attr:`kv_layout` is ``HND``. scale_q : Optional[torch.Tensor] The scale tensor for query, per-head quantization with shape: ``[num_qo_heads]``. Used with FP8 Quantization. If not provided, will be set to ``1.0``. scale_k : Optional[torch.Tensor] The scale tensor for key, per-head quantization with shape: ``[num_kv_heads]``. Used with FP8 Quantization. If not provided, will be set to ``1.0``. scale_v : Optional[torch.Tensor] The scale tensor for value, per-head quantization with shape: ``[num_kv_heads]``. Used with FP8 Quantization. If not provided, will be set to ``1.0``. o_dtype : Optional[torch.dtype] The output tensor data type, if not provided, will be set to the same as the q. This is necessary as output dtype cannot be automatically inferred in quant. custom_mask : Optional[torch.Tensor] The custom boolean mask tensor, shape: ``[qo_len, kv_len]``. The elements in the mask tensor should be either ``True`` or ``False``, where ``False`` means the corresponding element in the attention matrix will be masked out. When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the function will pack the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead. packed_custom_mask : Optional[torch.Tensor] The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored. The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. causal : bool Whether to apply causal mask to the attention matrix. This is only effective when :attr:`custom_mask` is not provided. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. pos_encoding_mode : str The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. Default is ``NONE``. 
use_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). window_left : int The left (inclusive) window size for the attention window, when set to ``-1``, the window size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim_qk)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to 1.0. rope_theta : Optional[float] The theta used in RoPE, if not provided, will be set to 1e4. backend : str The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. If set to ``auto``, the function will automatically choose the backend based on the device architecture and kernel availability. return_lse : bool Whether to return the log sum exp value of the attention logits. Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_len, num_qo_heads, head_dim_vo]``. If :attr:`return_lse` is ``True``, a tuple of two tensors: * The attention output, shape: ``[qo_len, num_qo_heads, head_dim_vo]``. * The log sum exp value, shape: ``[qo_len, num_qo_heads]``. 
Examples -------- >>> import torch >>> import flashinfer >>> qo_len = 128 >>> kv_len = 4096 >>> num_qo_heads = 32 >>> num_kv_heads = 4 >>> head_dim = 128 >>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0") >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") >>> o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, use_fp16_qk_reduction=True) >>> o.shape torch.Size([128, 32, 128]) >>> mask = torch.tril( >>> torch.full((qo_len, kv_len), True, device="cuda:0"), >>> diagonal=(kv_len - qo_len), >>> ) >>> mask tensor([[ True, True, True, ..., False, False, False], [ True, True, True, ..., False, False, False], [ True, True, True, ..., False, False, False], ..., [ True, True, True, ..., True, False, False], [ True, True, True, ..., True, True, False], [ True, True, True, ..., True, True, True]], device='cuda:0') >>> o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask) >>> torch.allclose(o, o_custom, rtol=1e-3, atol=1e-3) True Note ---- The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. 
""" _check_pos_encoding_mode(pos_encoding_mode) _check_kv_layout(kv_layout) tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 if custom_mask is not None and packed_custom_mask is None: # create packed custom mask from custom mask packed_custom_mask = packbits( custom_mask.contiguous().view(-1), bitorder="little" ) if packed_custom_mask is not None: mask_mode = MaskMode.CUSTOM.value else: if causal: mask_mode = MaskMode.CAUSAL.value else: mask_mode = MaskMode.NON_CAUSAL.value lse = None if return_lse: lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) if is_float8(q): # FP8 quant enabled, do sanity check: # 1. unsupported feature # 2. dtype check assert window_left == -1 assert q.dtype == k.dtype == v.dtype assert q.shape[-1] == k.shape[-1] == v.shape[-1] if scale_q is None: scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device) if scale_k is None: scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device) if scale_v is None: scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device) if backend == "auto": backend = determine_attention_backend( q.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, packed_custom_mask is not None, # use_custom_mask q.dtype, k.dtype, ) # o_dtype should be provided for FP8 attention if o_dtype is None: o_dtype = q.dtype out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.device) module = get_single_prefill_module( backend, q.dtype, k.dtype, out.dtype, q.shape[-1], # head_dim_qk v.shape[-1], # head_dim_vo PosEncodingMode[pos_encoding_mode].value, window_left >= 0, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap use_fp16_qk_reduction, ) module.run( q, k, v, tmp, out, lse, mask_mode, 
TensorLayout[kv_layout].value, window_left, packed_custom_mask, _get_cache_alibi_slopes_buf(q.shape[1], q.device), logits_soft_cap, sm_scale, scale_q, scale_k, scale_v, rope_scale, rope_theta, ) return (out, lse) if return_lse else out single_prefill_with_kv_cache_return_lse = functools.partial( single_prefill_with_kv_cache, return_lse=True ) def _compute_page_mask_indptr( qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_last_page_len: torch.Tensor, page_size: int, ) -> torch.Tensor: if len(qo_indptr) != len(paged_kv_indptr): raise ValueError( "The length of qo_indptr and paged_kv_indptr should be the same." ) mask_indptr = torch.empty_like(qo_indptr) mask_indptr[0] = 0 mask_indptr[1:] = torch.cumsum( (qo_indptr[1:] - qo_indptr[:-1]) * ( (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) * page_size + paged_kv_last_page_len ), 0, ) return mask_indptr class BatchPrefillWithPagedKVCacheWrapper: r"""Wrapper class for prefill/append attention with paged kv-cache for batch of requests. Check :ref:`our tutorial ` for page table layout. Example ------- >>> import torch >>> import flashinfer >>> num_layers = 32 >>> num_qo_heads = 64 >>> num_kv_heads = 16 >>> head_dim = 128 >>> max_num_pages = 128 >>> page_size = 16 >>> # allocate 128MB workspace buffer >>> workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") >>> prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( ... workspace_buffer, "NHD" ... ) >>> batch_size = 7 >>> nnz_qo = 100 >>> qo_indptr = torch.tensor( ... [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0" ... ) >>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0") >>> paged_kv_indptr = torch.tensor( ... [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0" ... ) >>> # 1 <= paged_kv_last_page_len <= page_size >>> paged_kv_last_page_len = torch.tensor( ... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0" ... 
) >>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0") >>> kv_cache_at_layer = torch.randn( ... num_layers, max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" ... ) >>> # create auxiliary data structures for batch prefill attention >>> prefill_wrapper.plan( ... qo_indptr, ... paged_kv_indptr, ... paged_kv_indices, ... paged_kv_last_page_len, ... num_qo_heads, ... num_kv_heads, ... head_dim, ... page_size, ... causal=True, ... ) >>> outputs = [] >>> for i in range(num_layers): ... q = q_at_layer[i] ... kv_cache = kv_cache_at_layer[i] ... # compute batch prefill attention, reuse auxiliary data structures ... o = prefill_wrapper.run(q, kv_cache) ... outputs.append(o) ... >>> outputs[0].shape torch.Size([100, 64, 128]) >>> >>> # below is another example of creating custom mask for batch prefill attention >>> mask_arr = [] >>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist() >>> kv_len = (page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) + paged_kv_last_page_len).cpu().tolist() >>> for i in range(batch_size): ... mask_i = torch.tril( ... torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"), ... diagonal=(kv_len[i] - qo_len[i]), ... ) ... mask_arr.append(mask_i.flatten()) ... >>> mask = torch.cat(mask_arr, dim=0) >>> prefill_wrapper.plan( ... qo_indptr, ... paged_kv_indptr, ... paged_kv_indices, ... paged_kv_last_page_len, ... num_qo_heads, ... num_kv_heads, ... head_dim, ... page_size, ... custom_mask=mask, ... ) >>> for i in range(num_layers): ... q = q_at_layer[i] ... kv_cache = kv_cache_at_layer[i] ... # compute batch prefill attention, reuse auxiliary data structures ... o_custom = prefill_wrapper.run(q, kv_cache) ... assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3) ... 
Note ---- To accelerate computation, FlashInfer's batch prefill/append attention operators create some auxiliary data structures, these data structures can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures. """ def __init__( self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, qo_indptr_buf: Optional[torch.Tensor] = None, paged_kv_indptr_buf: Optional[torch.Tensor] = None, paged_kv_indices_buf: Optional[torch.Tensor] = None, paged_kv_last_page_len_buf: Optional[torch.Tensor] = None, custom_mask_buf: Optional[torch.Tensor] = None, mask_indptr_buf: Optional[torch.Tensor] = None, backend: str = "auto", jit_args: Optional[List[Any]] = None, jit_kwargs: Optional[Dict[str, Any]] = None, ) -> None: r"""Constructor of :class:`BatchPrefillWithPagedKVCacheWrapper`. Parameters ---------- float_workspace_buffer : torch.Tensor The user reserved workspace buffer used to store intermediate attention results in split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. use_cuda_graph : bool Whether to enable CUDA graph capture for the prefill kernels, if enabled, the auxiliary data structures will be stored in provided buffers. The ``batch_size`` cannot change during the lifecycle of this wrapper when CUDAGraph is enabled. qo_indptr_buf : Optional[torch.Tensor] The user reserved buffer to store the ``qo_indptr`` array, the size of the buffer should be ``[batch_size + 1]``. This argument is only effective when ``use_cuda_graph`` is ``True``. paged_kv_indptr_buf : Optional[torch.Tensor] The user reserved buffer to store the ``paged_kv_indptr`` array, the size of this buffer should be ``[batch_size + 1]``. 
This argument is only effective when ``use_cuda_graph`` is ``True``. paged_kv_indices_buf : Optional[torch.Tensor] The user reserved buffer to store the ``paged_kv_indices`` array, should be large enough to store the maximum possible size of the ``paged_kv_indices`` array during the lifetime of the wrapper. This argument is only effective when ``use_cuda_graph`` is ``True``. paged_kv_last_page_len_buf : Optional[torch.Tensor] The user reserved buffer to store the ``paged_kv_last_page_len`` array, the size of the buffer should be ``[batch_size]``. This argument is only effective when ``use_cuda_graph`` is ``True``. custom_mask_buf : Optional[torch.Tensor] The user reserved buffer to store the custom mask tensor, should be large enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when ``use_cuda_graph`` is set to ``True`` and the custom mask will be used in attention computation. mask_indptr_buf : Optional[torch.Tensor] The user reserved buffer to store the ``mask_indptr`` array, the size of the buffer should be ``[batch_size + 1]``. This argument is only effective when ``use_cuda_graph`` is ``True`` and the custom mask will be used in attention computation. backend : str The implementation backend, could be ``auto``/``fa2``,``fa3`` or ``cudnn``. Defaults to ``auto``. If set to ``auto``, the wrapper will automatically choose the backend based on the device architecture and kernel availability. jit_args : Optional[List[Any]] If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation. jit_kwargs : Optional[Dict[str, Any]] The keyword arguments to create the JIT module, defaults to None. 
""" _check_kv_layout(kv_layout) if jit_args is not None: if jit_kwargs is None: jit_kwargs = {} self._jit_module = get_batch_prefill_jit_module( jit_args[0], get_customize_batch_prefill_module(backend, *jit_args, **jit_kwargs), ) else: self._jit_module = None self._kv_layout = kv_layout if backend == "cudnn": assert kv_layout == "NHD", "CUDNN backend only supports NHD layout" self._float_workspace_buffer = float_workspace_buffer self._workspace_size = ( self._float_workspace_buffer.numel() * self._float_workspace_buffer.element_size() ) self.device = float_workspace_buffer.device self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None if backend in ["fa3", "auto", "trtllm-gen"]: # NOTE(Zihao): assume maximum accumulate kv length is 16M self._vector_sparse_indices_buffer = torch.empty( (16 * 1024 * 1024,), dtype=torch.int32, device=self.device ) # NOTE(Zihao): assume maximum batch size is 32768 self._vector_sparse_indptr_buffer = torch.empty( (32768,), dtype=torch.int32, device=self.device ) self._kv_lens_buffer = torch.empty( (32768,), dtype=torch.int32, device=self.device ) self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, device="cpu", pin_memory=True, ) self._use_cuda_graph = use_cuda_graph if use_cuda_graph: if not torch.is_tensor(qo_indptr_buf): raise ValueError( "qo_indptr_buf should be a torch.Tensor in CUDA graph mode" ) if not torch.is_tensor(paged_kv_indptr_buf): raise ValueError( "paged_kv_indptr_buf should be a torch.Tensor in CUDA graph mode" ) if not torch.is_tensor(paged_kv_indices_buf): raise ValueError( "paged_kv_indices_buf should be a torch.Tensor in CUDA graph mode" ) if not torch.is_tensor(paged_kv_last_page_len_buf): raise ValueError( "paged_kv_last_page_len_buf should be a torch.Tensor in CUDA graph mode" ) self._fixed_batch_size = len(qo_indptr_buf) - 1 if 
len(paged_kv_indptr_buf) != self._fixed_batch_size + 1: raise ValueError( "The length of paged_kv_indptr_buf should be batch_size + 1." ) if len(paged_kv_last_page_len_buf) != self._fixed_batch_size: raise ValueError( "The length of paged_kv_last_page_len_buf should be batch_size." ) # NOTE(Zihao): do not check custom_mask_buf and mask_indptr_buf here, as they are optional else: self._fixed_batch_size = 0 self._qo_indptr_buf = qo_indptr_buf self._paged_kv_indptr_buf = paged_kv_indptr_buf self._paged_kv_indices_buf = paged_kv_indices_buf self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf self._custom_mask_buf = custom_mask_buf self._mask_indptr_buf = mask_indptr_buf self._max_total_num_rows = None self._backend = backend self._plan_info = None self._cached_module = None self._seq_lens_kv = None self._seq_lens_q = None self._block_tables = None @property def is_cuda_graph_enabled(self) -> bool: return self._use_cuda_graph def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor ) -> None: r"""Reset the workspace buffer. Parameters ---------- float_workspace_buffer : torch.Tensor The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors. int_workspace_buffer : torch.Tensor The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. 
""" self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, device="cpu", pin_memory=True, ) def plan( self, qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, page_size: int, head_dim_vo: Optional[int] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, window_left: int = -1, logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, q_data_type: Union[str, torch.dtype] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, non_blocking: bool = True, prefix_len_ptr: Optional[torch.Tensor] = None, token_pos_in_items_ptr: Optional[torch.Tensor] = None, token_pos_in_items_len: int = 0, max_item_len_ptr: Optional[torch.Tensor] = None, seq_lens: Optional[torch.Tensor] = None, seq_lens_q: Optional[torch.Tensor] = None, block_tables: Optional[torch.Tensor] = None, max_token_per_sequence: Optional[int] = None, max_sequence_kv: Optional[int] = None, ) -> None: r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification. Parameters ---------- qo_indptr : torch.Tensor The indptr of the query/output tensor, shape: ``[batch_size + 1]``. paged_kv_indptr : torch.Tensor The indptr of the paged kv-cache, shape: ``[batch_size + 1]``. paged_kv_indices : torch.Tensor The page indices of the paged kv-cache, shape: ``[qo_indptr[-1]]``. paged_kv_last_page_len : torch.Tensor The number of entries in the last page of each request in the paged kv-cache, shape: ``[batch_size]``. 
num_qo_heads : int The number of query/output heads. num_kv_heads : int The number of key/value heads. head_dim_qk : int The dimension of the query/key heads. page_size : int The size of each page in the paged kv-cache. head_dim_vo : Optional[int] The dimension of the value/output heads, if not provided, will be set to ``head_dim_qk``. custom_mask : Optional[torch.Tensor] The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. The elements in the mask tensor should be either ``True`` or ``False``, where ``False`` means the corresponding element in the attention matrix will be masked out. Please refer to the :ref:`mask layout ` for more details about flattened layout of mask tensor. When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the function will pack the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead. packed_custom_mask : Optional[torch.Tensor] The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored. The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. causal : bool Whether to apply causal mask to the attention matrix. This is only effective when :attr:`custom_mask` is not provided in :meth:`plan`. pos_encoding_mode : str The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. Default is ``NONE``. use_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). window_left : int The left (inclusive) window size for the attention window, when set to ``-1``, the window size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. 
If greater than 0, the logits will be capped according to formula: :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. rope_theta : Optional[float] The theta used in RoPE, if not provided, will be set to ``1e4``. q_data_type : Union[str, torch.dtype] The data type of the query tensor, defaults torch.float16. kv_data_type : Optional[Union[str, torch.dtype]] The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. non_blocking : bool Whether to copy the input tensors to the device asynchronously, defaults to ``True``. prefix_len_ptr :Optional[torch.Tensor] prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size. token_pos_in_items_ptr : Optional[float] A uint16 1D tensor (it will be converted to uint16 in flashinfer) indicating the token position of each item and started from 0 (delimiter) for each item. E.g., if we have 3 items of length 3, 2, 4 respectively for this member. This vector will be looking like `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]` with 4 delimiters indexed as 0. For batch size > 1, we will concat them as 1D with zero paddings to make sure each has the same length, the padding length is defined by `token_pos_in_items_len` - length of the raw `token_pos_in_items_ptr` for each prompt. token_pos_in_items_len : int zero padding length for `token_pos_in_items_ptr` to better handle the bsz > 1 case. Still using the above 3,2,4 example. If we set `token_pos_in_items_len` to be 20, it will be `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]` with 7 padded zeros. 
(note there're 8 zeros in the end where the first one is the delimiter token 0 in the end of the prompt) max_item_len_ptr : Optional[float] a uint16 vector contains the max token length of all items for each prompt seq_lens: Optional[torch.Tensor] A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``. seq_lens_q: Optional[torch.Tensor] A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``. If not provided, will be set to the same value as ``seq_lens``. block_tables: Optional[torch.Tensor] A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``. max_token_per_sequence: Optional[int], Required for cudnn backend. This is the scalar max token length of each sequence. max_sequence_kv: Optional[int], Required for cudnn backend. This is the scalar max sequence length of each sequence in kv cache. Note ---- The :meth:`plan` method should be called before any :meth:`run` or :meth:`run_return_lse` calls, auxiliary data structures will be created during this call and cached for multiple kernel runs. The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. 
""" q_data_type = canonicalize_torch_dtype(q_data_type) if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) if logits_soft_cap is None: logits_soft_cap = 0.0 if head_dim_vo is None: head_dim_vo = head_dim_qk batch_size = len(qo_indptr) - 1 self._batch_size = batch_size self._num_qo_heads = num_qo_heads self._num_kv_heads = num_kv_heads if custom_mask is not None or packed_custom_mask is not None: mask_indptr = _compute_page_mask_indptr( qo_indptr, paged_kv_indptr, paged_kv_last_page_len, page_size, ) if packed_custom_mask is None and custom_mask is not None: # create packed custom mask from custom mask packed_custom_mask, mask_indptr = segment_packbits( custom_mask.contiguous().view(-1), mask_indptr, bitorder="little", ) self._prefix_len_ptr = prefix_len_ptr self._token_pos_in_items_ptr = token_pos_in_items_ptr self._token_pos_in_items_len = token_pos_in_items_len self._max_item_len_ptr = max_item_len_ptr # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors if max_token_per_sequence is not None: self._max_q_len = max_token_per_sequence else: qo_indptr_host = qo_indptr.to("cpu") self._max_q_len = max(qo_indptr_host).item() total_num_rows = qo_indptr_host[-1] if max_sequence_kv is not None: self._max_kv_len = max_sequence_kv else: paged_kv_indptr_host = paged_kv_indptr.to("cpu") paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu") if seq_lens is None: kv_lens_arr_host = get_seq_lens( paged_kv_indptr_host, paged_kv_last_page_len_host, page_size ) else: kv_lens_arr_host = seq_lens.cpu().flatten() self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( kv_lens_arr_host, non_blocking=non_blocking ) self._max_kv_len = max(kv_lens_arr_host).item() if self.is_cuda_graph_enabled: if self._max_total_num_rows is None: self._max_total_num_rows = total_num_rows elif total_num_rows > self._max_total_num_rows: raise ValueError( "The total number of rows in qo_indptr {} in cuda graph mode cannot " 
"exceed the number of rows set during initialization {}.".format( total_num_rows, self._max_total_num_rows ) ) if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed during the lifecycle of the wrapper in " "cuda graph mode, the runtime batch size {} mismatches the batch size {} " " set during initialization.".format( batch_size, self._fixed_batch_size ) ) if len(paged_kv_indices) > len(self._paged_kv_indices_buf): raise ValueError( "The length of paged_kv_indices exceeds the allocated buffer size." ) self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking) self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=non_blocking) self._paged_kv_last_page_len_buf.copy_( paged_kv_last_page_len, non_blocking=non_blocking ) self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_( paged_kv_indices, non_blocking=(paged_kv_indices.device == self.device) and non_blocking, ) if packed_custom_mask is not None: if not torch.is_tensor(self._custom_mask_buf): raise ValueError( "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." ) if not torch.is_tensor(self._mask_indptr_buf): raise ValueError( "mask_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." 
) self._custom_mask_buf[: len(packed_custom_mask)].copy_( packed_custom_mask, non_blocking=(packed_custom_mask.device == self.device) and non_blocking, ) # NOTE(Zihao): mask_indptr has the same length as qo_indptr self._mask_indptr_buf.copy_(mask_indptr, non_blocking=non_blocking) else: self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking) self._paged_kv_indptr_buf = paged_kv_indptr.to( self.device, non_blocking=non_blocking ) self._paged_kv_indices_buf = paged_kv_indices.to( self.device, non_blocking=non_blocking ) self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to( self.device, non_blocking=non_blocking ) if packed_custom_mask is not None: self._custom_mask_buf = packed_custom_mask.to( self.device, non_blocking=non_blocking ) self._mask_indptr_buf = mask_indptr.to( self.device, non_blocking=non_blocking ) else: self._custom_mask_buf = None self._mask_indptr_buf = None self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type if self._jit_module is not None: self._cached_module = self._jit_module else: if self._backend == "auto": self._backend = determine_attention_backend( self.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, self._custom_mask_buf is not None, # use_custom_mask q_data_type, kv_data_type, ) if self._backend != "cudnn": get_module_args = ( q_data_type, kv_data_type, q_data_type, paged_kv_indptr.dtype, head_dim_qk, head_dim_vo, PosEncodingMode[pos_encoding_mode].value, window_left >= 0, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap use_fp16_qk_reduction, ) self._cached_module = get_batch_prefill_module( self._backend, *get_module_args ) if self._backend == "fa3" or self._backend == "trtllm-gen": if page_size != 1: vector_sparse_indptr_host = torch.cat( [ torch.tensor( [0], dtype=torch.int32, device=kv_lens_arr_host.device ), torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), ], dim=0, ) self._vector_sparse_indptr_buffer[ : 
len(vector_sparse_indptr_host) ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) paged_kv_indptr_host = vector_sparse_indptr_host self._block_tables = block_tables if self._backend == "trtllm-gen": assert self._kv_layout == "HND" assert logits_soft_cap == 0.0 if self._block_tables is None: blocks_per_seq = [ (seq_len + page_size - 1) // page_size for seq_len in kv_lens_arr_host ] max_num_blocks_per_seq = max(blocks_per_seq) self._block_tables = torch.zeros( (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=self.device, ) block_id = paged_kv_indptr_host[0] for i in range(batch_size): num_blocks_needed = blocks_per_seq[i] assert self._block_tables is not None, ( "block_tables is not initialized" ) self._block_tables[i, :num_blocks_needed] = paged_kv_indices[ block_id : block_id + num_blocks_needed ] block_id += num_blocks_needed if self._cached_module is not None: self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, qo_indptr_host, paged_kv_indptr_host, kv_lens_arr_host, self._max_total_num_rows or total_num_rows, batch_size, num_qo_heads, num_kv_heads, page_size, self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, causal, ) self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._window_left = window_left self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta self._seq_lens_kv = seq_lens self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens begin_forward = plan def forward( self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, k_scale: Optional[float] = None, v_scale: Optional[float] = None, window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: 
Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ) -> torch.Tensor: r"""Warning: This function is deprecated, please use :meth:`run` instead.""" self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._window_left = window_left self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta return self.run(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale) @overload def run( self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], *args, k_scale: Optional[float] = None, v_scale: Optional[float] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: Literal[False] = False, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, ) -> torch.Tensor: ... @overload def run( self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], *args, k_scale: Optional[float] = None, v_scale: Optional[float] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: Literal[True] = True, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... def run( self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], *args, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, sinks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and paged kv-cache. 
Parameters ---------- q : torch.Tensor The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] The paged KV-Cache stored as a tuple of tensors or a single tensor: * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. * a single 5-D tensor with shape: ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, and ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and ``paged_kv_cache[:, 1]`` is the value-cache. *args Additional arguments for custom kernels. k_scale : Optional[float] The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. v_scale : Optional[float] The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. lse : Optional[torch.Tensor] The log-sum-exp of attention logits, if not provided, will be allocated internally. return_lse : bool Whether to return the logsumexp of attention output enable_pdl : bool Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode. Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. If :attr:`return_lse` is ``True``, a tuple of two tensors: * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. 
* The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. """ if enable_pdl is None: enable_pdl = device_support_pdl(q.device) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) _check_cached_qkv_data_type( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) stride_block = k_cache.stride(0) if self._kv_layout == "NHD": page_size = k_cache.shape[1] stride_n = k_cache.stride(1) else: page_size = k_cache.shape[2] stride_n = k_cache.stride(2) window_left = self._window_left if window_left is None else window_left if self._backend != "trtllm-gen": # NOTE(Siyuan): since window_left is appeared in the plan function, we need to make sure it is the same as the one in the plan function. # Remove this check if the backend supports dynamic window_left. assert window_left == self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale rope_scale = self._rope_scale rope_theta = self._rope_theta if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if q_scale is not None: sm_scale *= q_scale if k_scale is not None: sm_scale *= k_scale if rope_scale is None: rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 if return_lse: if lse is None: lse = torch.empty( (q.size(0), q.size(1)), dtype=torch.float32, device=q.device ) else: check_shape_dtype_device( lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" ) if out is None: out = torch.empty( q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device ) else: check_shape_dtype_device( out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out" ) if self._custom_mask_buf is not None: mask_mode = MaskMode.CUSTOM.value else: if self._causal: mask_mode = MaskMode.CAUSAL.value else: mask_mode = MaskMode.NON_CAUSAL.value if self._prefix_len_ptr is not None: mask_mode = MaskMode.MULTIITEMSCORING.value if self._backend == "fa3": # NOTE(Zihao): we divide both stride_block and stride_n 
by stride_n # because we will multiply stride_n back in the kernel sparse_indices = block_sparse_indices_to_vector_sparse_offsets( self._paged_kv_indices_buf, self._paged_kv_indptr_buf, self._vector_sparse_indices_buffer, # output self._vector_sparse_indptr_buffer, self._kv_lens_buffer, stride_block // stride_n, 1, # stride_n // stride_n page_size, ) sparse_indptr = self._vector_sparse_indptr_buffer else: sparse_indices = self._paged_kv_indices_buf sparse_indptr = self._paged_kv_indptr_buf if self._backend == "cudnn": if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1: self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1) if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1: self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1) cudnn_batch_prefill_with_kv_cache( q, k_cache, # Need to be changed v_cache, # Need to be changed self._sm_scale, self._float_workspace_buffer, actual_seq_lens_q=self._seq_lens_q, actual_seq_lens_kv=self._seq_lens_kv, max_token_per_sequence=self._max_q_len, max_sequence_kv=self._max_kv_len, block_tables=self._block_tables, causal=self._causal, return_lse=return_lse, batch_offsets_q=self._qo_indptr_buf, batch_offsets_o=self._qo_indptr_buf, out=out, lse=lse, ) else: if self._backend != "trtllm-gen": assert self._plan_info is not None, "plan info is not initialized" run_args = [ self._float_workspace_buffer, self._int_workspace_buffer, self._plan_info, q, k_cache, v_cache, self._qo_indptr_buf, sparse_indptr, sparse_indices, self._paged_kv_last_page_len_buf, out, lse, mask_mode, TensorLayout[self._kv_layout].value, window_left, enable_pdl, ] if self._jit_module is not None: run_args.extend(list(args)) else: run_args += [ self._custom_mask_buf, self._mask_indptr_buf, _get_cache_alibi_slopes_buf(q.shape[1], q.device), self._prefix_len_ptr, self._token_pos_in_items_ptr, self._max_item_len_ptr, logits_soft_cap, sm_scale, None, # scale_q, not supported yet None, # scale_k None, # scale_v 
rope_scale, rope_theta, self._token_pos_in_items_len, self._workspace_size, self._num_qo_heads, self._num_kv_heads, self._block_tables, self._kv_lens_buffer, page_size, self._max_q_len, self._max_kv_len, self._batch_size, self._qo_indptr_buf, self._vector_sparse_indptr_buffer, sinks, ] assert self._cached_module is not None, "cached module is not initialized" self._cached_module.paged_run(*run_args) if v_scale is not None: # TODO(Zihao): fused into kernel if is_float8(out): out = (out.to(torch.float32) * v_scale).to(out.dtype) else: out *= v_scale return (out, lse) if return_lse else out run_return_lse = functools.partialmethod(run, return_lse=True) def forward_return_lse( self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, k_scale: Optional[float] = None, v_scale: Optional[float] = None, window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Warning: This function is deprecated, please use :meth:`run_return_lse` instead.""" self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._window_left = window_left self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta return self.run_return_lse(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale) def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" pass def _compute_mask_indptr( qo_indptr: torch.Tensor, kv_indptr: torch.Tensor ) -> torch.Tensor: if len(qo_indptr) != len(kv_indptr): raise ValueError("The length of qo_indptr and kv_indptr should be the same.") mask_indptr = torch.empty_like(qo_indptr) mask_indptr[0] = 0 mask_indptr[1:] = torch.cumsum( 
(qo_indptr[1:] - qo_indptr[:-1]) * (kv_indptr[1:] - kv_indptr[:-1]), 0, ) return mask_indptr class BatchPrefillWithRaggedKVCacheWrapper: r"""Wrapper class for prefill/append attention with ragged (tensor) kv-cache for batch of requests. Check :ref:`our tutorial ` for ragged kv-cache layout. Example ------- >>> import torch >>> import flashinfer >>> num_layers = 32 >>> num_qo_heads = 64 >>> num_kv_heads = 16 >>> head_dim = 128 >>> # allocate 128MB workspace buffer >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") >>> prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( ... workspace_buffer, "NHD" ... ) >>> batch_size = 7 >>> nnz_kv = 100 >>> nnz_qo = 100 >>> qo_indptr = torch.tensor( ... [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0" ... ) >>> kv_indptr = qo_indptr.clone() >>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0") >>> k_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0") >>> v_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0") >>> # create auxiliary data structures for batch prefill attention >>> prefill_wrapper.plan( ... qo_indptr, ... kv_indptr, ... num_qo_heads, ... num_kv_heads, ... head_dim, ... causal=True, ... ) >>> outputs = [] >>> for i in range(num_layers): ... q = q_at_layer[i] ... k = k_at_layer[i] ... v = v_at_layer[i] ... # compute batch prefill attention, reuse auxiliary data structures ... o = prefill_wrapper.run(q, k, v) ... outputs.append(o) ... >>> outputs[0].shape torch.Size([100, 64, 128]) >>> >>> # below is another example of creating custom mask for batch prefill attention >>> mask_arr = [] >>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist() >>> kv_len = (kv_indptr[1:] - kv_indptr[:-1]).cpu().tolist() >>> for i in range(batch_size): ... mask_i = torch.tril( ... torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"), ... 
diagonal=(kv_len[i] - qo_len[i]), ... ) ... mask_arr.append(mask_i.flatten()) ... >>> mask = torch.cat(mask_arr, dim=0) >>> prefill_wrapper.plan( ... qo_indptr, ... kv_indptr, ... num_qo_heads, ... num_kv_heads, ... head_dim, ... custom_mask=mask ... ) >>> outputs_custom_mask = [] >>> for i in range(num_layers): ... q = q_at_layer[i] ... k = k_at_layer[i] ... v = v_at_layer[i] ... # compute batch prefill attention, reuse auxiliary data structures ... o_custom = prefill_wrapper.run(q, k, v) ... assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3) ... >>> outputs_custom_mask[0].shape torch.Size([100, 64, 128]) Note ---- To accelerate computation, FlashInfer's batch prefill/append attention operators create some auxiliary data structures, these data structures can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures. """ def __init__( self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, qo_indptr_buf: Optional[torch.Tensor] = None, kv_indptr_buf: Optional[torch.Tensor] = None, custom_mask_buf: Optional[torch.Tensor] = None, mask_indptr_buf: Optional[torch.Tensor] = None, backend: str = "auto", jit_args: Optional[List[Any]] = None, jit_kwargs: Optional[Dict[str, Any]] = None, ) -> None: r"""Constructor of :class:`BatchPrefillWithRaggedKVCacheWrapper`. Parameters ---------- float_workspace_buffer : torch.Tensor The user reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. use_cuda_graph : bool Whether to enable CUDA graph capture for the prefill kernels, if enabled, the auxiliary data structures will be stored as the provided buffers. 
qo_indptr_buf : Optional[torch.Tensor] The user reserved GPU buffer to store the ``qo_indptr`` array, the size of the buffer should be ``[batch_size + 1]``. This argument is only effective when ``use_cuda_graph`` is ``True``. kv_indptr_buf : Optional[torch.Tensor] The user reserved GPU buffer to store the ``kv_indptr`` array, the size of the buffer should be ``[batch_size + 1]``. This argument is only effective when ``use_cuda_graph`` is ``True``. custom_mask_buf : Optional[torch.Tensor] The user reserved GPU buffer to store the custom mask tensor, should be large enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when ``use_cuda_graph`` is ``True`` and custom mask will be used in attention computation. mask_indptr_buf : Optional[torch.Tensor] The user reserved GPU buffer to store the ``mask_indptr`` array, the size of the buffer should be ``[batch_size]``. This argument is only effective when ``use_cuda_graph`` is ``True`` and custom mask will be used in attention computation. backend : str The implementation backend, could be ``auto``/``fa2``/``fa3`` or ``trtllm-gen``. Defaults to ``auto``. If set to ``auto``, the wrapper will automatically choose the backend based on the device architecture and kernel availability. jit_args : Optional[List[Any]] If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation. jit_kwargs : Optional[Dict[str, Any]] The keyword arguments to create the JIT module, defaults to None. 
""" _check_kv_layout(kv_layout) if jit_args is not None: if jit_kwargs is None: jit_kwargs = {} self._jit_module = get_batch_prefill_jit_module( jit_args[0], get_customize_batch_prefill_module(backend, *jit_args, **jit_kwargs), ) else: self._jit_module = None self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True, device="cpu", ) self._use_cuda_graph = use_cuda_graph if use_cuda_graph: if not torch.is_tensor(qo_indptr_buf): raise ValueError( "qo_indptr_buf should be a torch.Tensor in cuda graph mode" ) if not torch.is_tensor(kv_indptr_buf): raise ValueError( "kv_indptr_buf should be a torch.Tensor in cuda graph mode" ) self._fixed_batch_size = len(qo_indptr_buf) - 1 if len(kv_indptr_buf) != self._fixed_batch_size + 1: raise ValueError( "The length of kv_indptr_buf ({}) should be the same as qo_indptr_buf ({}).".format( len(kv_indptr_buf), self._fixed_batch_size ) ) # NOTE(Zihao): do not check custom_mask_buf and mask_indptr_buf here, # as they may not be used. self._qo_indptr_buf = qo_indptr_buf self._kv_indptr_buf = kv_indptr_buf self._custom_mask_buf = custom_mask_buf self._mask_indptr_buf = mask_indptr_buf self._max_total_num_rows = None self._backend = backend self._cached_module = None @property def is_cuda_graph_enabled(self) -> bool: return self._use_cuda_graph def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer ) -> None: r"""Reset the workspace buffer. Parameters ---------- float_workspace_buffer : torch.Tensor The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors. 
int_workspace_buffer : torch.Tensor The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, device="cpu", pin_memory=True, ) def plan( self, qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, head_dim_vo: Optional[int] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, q_data_type: Union[str, torch.dtype] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, non_blocking: bool = True, prefix_len_ptr: Optional[torch.Tensor] = None, token_pos_in_items_ptr: Optional[torch.Tensor] = None, token_pos_in_items_len: int = 0, max_item_len_ptr: Optional[torch.Tensor] = None, ) -> None: r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification. Parameters ---------- qo_indptr : torch.Tensor The indptr of the query/output tensor, shape: ``[batch_size + 1]``. kv_indptr : torch.Tensor The indptr of the key/value tensor, shape: ``[batch_size + 1]``. num_qo_heads : int The number of query/output heads. num_kv_heads : int The number of key/value heads. head_dim_qk : int The dimension of the heads on query/key tensor. head_dim_vo : Optional[int] The dimension of the heads on value/output tensor. If not provided, will be set to ``head_dim_vo``. 
custom_mask : Optional[torch.Tensor] The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. The elements in the mask tensor should be either ``True`` or ``False``, where ``False`` means the corresponding element in the attention matrix will be masked out. Please refer to the :ref:`mask layout ` for more details about flattened layout of mask tensor. When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the function will pack the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead. packed_custom_mask : Optional[torch.Tensor] The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored. The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. If provided, the custom mask will be added to the attention matrix before softmax and after scaling. The mask tensor should be in the same device as the input tensors. causal : bool Whether to apply causal mask to the attention matrix. This argument is ignored if ``mask`` is provided in :meth:`plan`. pos_encoding_mode : str The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. Default is ``NONE``. use_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). window_left : int The left (inclusive) window size for the attention window, when set to ``-1``, the window size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, where :math:`x` is the input logits. 
sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim_qk)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. rope_theta : Optional[float] The theta used in RoPE, if not provided, will be set to ``1e4``. q_data_type : Union[str, torch.dtype] The data type of the query tensor, defaults to torch.float16. kv_data_type : Optional[Union[str, torch.dtype]] The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. non_blocking : bool Whether to copy the input tensors to the device asynchronously, defaults to ``True``. prefix_len_ptr :Optional[torch.Tensor] prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size. token_pos_in_items_ptr : Optional[float] A uint16 1D tensor (it will be converted to uint16 in flashinfer) indicating the token position of each item and started from 0 (delimiter) for each item. E.g., if we have 3 items of length 3, 2, 4 respectively for this member. This vector will be looking like `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]` with 4 delimiters indexed as 0. For batch size > 1, we will concat them as 1D with zero paddings to make sure each has the same length, the padding length is defined by `token_pos_in_items_len` - length of the raw `token_pos_in_items_ptr` for each prompt. token_pos_in_items_len : int zero padding length for `token_pos_in_items_ptr` to better handle the bsz > 1 case. Still using the above 3,2,4 example. If we set `token_pos_in_items_len` to be 20, it will be `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]` with 7 padded zeros. 
(note there're 8 zeros in the end where the first one is the delimiter token 0 in the end of the prompt) max_item_len_ptr : Optional[float] a uint16 vector contains the max token length of all items for each prompt Note ---- The :meth:`plan` method should be called before any :meth:`run` or :meth:`run_return_lse` calls, auxiliary data structures will be created during this plan call and cached for multiple kernel runs. The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. """ q_data_type = canonicalize_torch_dtype(q_data_type) if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) if head_dim_vo is None: head_dim_vo = head_dim_qk if logits_soft_cap is None: logits_soft_cap = 0.0 batch_size = len(qo_indptr) - 1 if len(kv_indptr) != batch_size + 1: raise ValueError( "The kv_indptr length should be equal to mask_indptr length." 
) if custom_mask is not None or packed_custom_mask is not None: mask_indptr = _compute_mask_indptr(qo_indptr, kv_indptr) if packed_custom_mask is None and custom_mask is not None: # create packed custom mask from custom mask packed_custom_mask, mask_indptr = segment_packbits( custom_mask.contiguous().view(-1), mask_indptr, bitorder="little", ) # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors qo_indptr_host = qo_indptr.to("cpu") kv_indptr_host = kv_indptr.to("cpu") total_num_rows = qo_indptr_host[-1] if self.is_cuda_graph_enabled: if self._max_total_num_rows is None: self._max_total_num_rows = total_num_rows elif total_num_rows > self._max_total_num_rows: raise ValueError( "The total number of rows in qo_indptr {} in cuda graph mode cannot " "exceed the number of rows set during initialization {}.".format( total_num_rows, self._max_total_num_rows ) ) if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime batch size {} " " mismatches the batch size set during initialization {}.".format( batch_size, self._fixed_batch_size ) ) self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking) self._kv_indptr_buf.copy_(kv_indptr, non_blocking=non_blocking) if packed_custom_mask is not None: if not torch.is_tensor(self._custom_mask_buf): raise ValueError( "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." ) if not torch.is_tensor(self._mask_indptr_buf): raise ValueError( "mask_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in the attention computation." 
) self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask self._mask_indptr_buf.copy_(mask_indptr, non_blocking=non_blocking) else: self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking) self._kv_indptr_buf = kv_indptr.to(self.device, non_blocking=non_blocking) if packed_custom_mask is not None: self._custom_mask_buf = packed_custom_mask.to( self.device, non_blocking=non_blocking ) self._mask_indptr_buf = mask_indptr.to( self.device, non_blocking=non_blocking ) self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1] self._prefix_len_ptr = prefix_len_ptr self._token_pos_in_items_ptr = token_pos_in_items_ptr self._token_pos_in_items_len = token_pos_in_items_len self._max_item_len_ptr = max_item_len_ptr if self._jit_module is not None: self._cached_module = self._jit_module else: if self._backend == "auto": self._backend = determine_attention_backend( self.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, self._custom_mask_buf is not None, # use_custom_mask q_data_type, kv_data_type, ) get_module_args = ( q_data_type, kv_data_type, q_data_type, kv_indptr.dtype, head_dim_qk, head_dim_vo, PosEncodingMode[pos_encoding_mode].value, window_left >= 0, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap use_fp16_qk_reduction, ) if self._backend == "cutlass": # insert qo_indptr.device to 9th position (0-indexed) of get_module_args new_get_module_args = ( get_module_args[:9] + (qo_indptr.device,) + get_module_args[9:] ) self._cached_module = get_fmha_module(*new_get_module_args) else: self._cached_module = get_batch_prefill_module( self._backend, *get_module_args ) if self._backend == "cutlass": self._plan_info = fmha_varlen_plan( self._cached_module, qo_indptr, kv_indptr, num_qo_heads, causal ) self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item() else: assert self._cached_module is not None, "cached module is not 
initialized" self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, qo_indptr_host, kv_indptr_host, kv_len_arr, self._max_total_num_rows or total_num_rows, batch_size, num_qo_heads, num_kv_heads, 1, # page_size self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, causal, ) self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._window_left = window_left self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta begin_forward = plan def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ) -> torch.Tensor: r"""Warning: This function is deprecated, please use :meth:`run` instead.""" self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._window_left = window_left self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta return self.run(q, k, v) @overload def run( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: Literal[False] = False, enable_pdl: Optional[bool] = None, ) -> torch.Tensor: ... @overload def run( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: Literal[True] = True, enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... 
def run( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and kv-cache stored as ragged tensor. Parameters ---------- q : torch.Tensor The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_qk]`` k : torch.Tensor The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim_qk]`` v : torch.Tensor The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim_vo]`` *args Additional arguments for the custom kernel. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. lse : Optional[torch.Tensor] The log-sum-exp of attention logits, if not provided, will be allocated internally. return_lse : bool Whether to return the logsumexp of attention output enable_pdl : bool Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode. Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_vo]``. If :attr:`return_lse` is ``True``, a tuple of two tensors: * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_vo]``. * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. 
""" if enable_pdl is None: enable_pdl = device_support_pdl(q.device) _check_cached_qkv_data_type( q, k, self._cached_q_data_type, self._cached_kv_data_type ) window_left = self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale rope_scale = self._rope_scale rope_theta = self._rope_theta if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 if return_lse: if lse is None: lse = torch.empty( (q.size(0), q.size(1)), dtype=torch.float32, device=q.device ) else: check_shape_dtype_device( lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" ) if out is None: out = torch.empty( q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device ) else: check_shape_dtype_device( out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out" ) if self._backend == "cutlass": out, lse = fmha_varlen( q, k, v, self._qo_indptr_buf, self._kv_indptr_buf, plan_info=self._plan_info, causal=self._causal, sm_scale=sm_scale, max_qo_len=self._max_qo_len, out=out, lse=lse, ) return (out, lse) if return_lse else out if is_float8(q): logging.warning( "Our current prefill kernel implementation needs f16 input, the f8 inputs " " are casted to f16, which could result in performance degradation." 
) q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) if self._custom_mask_buf is not None: mask_mode = MaskMode.CUSTOM.value else: if self._causal: mask_mode = MaskMode.CAUSAL.value else: mask_mode = MaskMode.NON_CAUSAL.value run_args = [ self._float_workspace_buffer, self._int_workspace_buffer, self._plan_info, q, k, v, self._qo_indptr_buf, self._kv_indptr_buf, out, lse, mask_mode, TensorLayout[self._kv_layout].value, window_left, enable_pdl, ] if self._jit_module is not None: run_args.extend(list(args)) else: run_args += [ self._custom_mask_buf, self._mask_indptr_buf, _get_cache_alibi_slopes_buf(q.shape[1], self.device), self._prefix_len_ptr, self._token_pos_in_items_ptr, self._max_item_len_ptr, logits_soft_cap, sm_scale, rope_scale, rope_theta, self._token_pos_in_items_len, ] assert self._cached_module is not None, "cached module is not initialized" self._cached_module.ragged_run(*run_args) return (out, lse) if return_lse else out run_return_lse = functools.partialmethod(run, return_lse=True) def forward_return_lse( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Warning: This function is deprecated, please use :meth:`run_return_lse` instead.""" self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._window_left = window_left self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta return self.run_return_lse(q, k, v) def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" pass def fmha_varlen_plan( module, qo_segment_offsets: torch.Tensor, kv_segment_offsets: 
torch.Tensor, num_qo_heads: int, causal: bool, ): num_ctas = torch.cuda.get_device_properties( qo_segment_offsets.device ).multi_processor_count work_indptr = torch.empty( num_ctas + 1, device=qo_segment_offsets.device, dtype=torch.int32 ) qo_tile_indices = torch.empty( 131072, device=qo_segment_offsets.device, dtype=torch.int32 ) head_indices = torch.empty( 131072, device=qo_segment_offsets.device, dtype=torch.int32 ) batch_indices = torch.empty( 131072, device=qo_segment_offsets.device, dtype=torch.int32 ) module.plan( qo_segment_offsets, kv_segment_offsets, work_indptr, qo_tile_indices, head_indices, batch_indices, 256, # qo_tile_size num_qo_heads, num_ctas, causal, ) return ( work_indptr, qo_tile_indices, head_indices, batch_indices, ) @overload def fmha_varlen( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qo_segment_offsets: torch.Tensor, kv_segment_offsets: torch.Tensor, plan_info: Optional[List[torch.Tensor]] = None, max_qo_len: Optional[int] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, causal: bool = False, sm_scale: Optional[float] = None, return_lse: Literal[False] = False, ) -> torch.Tensor: ... @overload def fmha_varlen( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qo_segment_offsets: torch.Tensor, kv_segment_offsets: torch.Tensor, plan_info: Optional[List[torch.Tensor]] = None, max_qo_len: Optional[int] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, causal: bool = False, sm_scale: Optional[float] = None, return_lse: Literal[True] = True, ) -> Tuple[torch.Tensor, torch.Tensor]: ... 
def fmha_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qo_segment_offsets: torch.Tensor,
    kv_segment_offsets: torch.Tensor,
    plan_info: Optional[List[torch.Tensor]] = None,
    max_qo_len: Optional[int] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Variable-length (ragged) attention using the cutlass FMHA module.

    Parameters
    ----------
    q : torch.Tensor
        Query tensor, unpacked as ``[nnz_qo, num_qo_heads, head_dim_qk]``.
    k : torch.Tensor
        Key tensor.
    v : torch.Tensor
        Value tensor, unpacked as ``[nnz_kv, num_kv_heads, head_dim_vo]``.
    qo_segment_offsets : torch.Tensor
        The indptr of the query/output tensor, shape ``[batch_size + 1]``.
    kv_segment_offsets : torch.Tensor
        The indptr of the key/value tensor, shape ``[batch_size + 1]``.
    plan_info : Optional[List[torch.Tensor]]
        The schedule returned by :func:`fmha_varlen_plan`; computed here if ``None``.
    max_qo_len : Optional[int]
        The longest query segment; computed from ``qo_segment_offsets`` if ``None``
        (this reads the value back to the host).
    out : Optional[torch.Tensor]
        Output buffer. If ``None``, a buffer with ``max(max_qo_len, 128)`` extra leading
        rows is allocated and the padded prefix is sliced off.
    lse : Optional[torch.Tensor]
        Log-sum-exp buffer, allocated as float32 ``[nnz_qo, num_qo_heads]`` when
        ``return_lse`` is ``True`` and none is given.
    causal : bool
        Whether to apply a causal mask.
    sm_scale : Optional[float]
        Softmax scale, defaults to ``1 / sqrt(head_dim_qk)``.
    return_lse : bool
        Whether to allocate ``lse`` when it is not provided.

    Returns
    -------
    Tuple[torch.Tensor, Optional[torch.Tensor]]
        Always the pair ``(out, lse)``; ``lse`` is ``None`` when ``return_lse`` is
        ``False`` and no buffer was passed. Callers unpack this tuple unconditionally.
    """
    workspace_buffer = _get_cache_buf(
        "fmha_varlen_cutlass_workspace", 32 * 1024 * 1024, q.device
    )
    module = get_fmha_module(
        q.dtype,
        k.dtype,
        v.dtype,
        torch.int32,
        q.shape[2],
        v.shape[2],
        PosEncodingMode.NONE.value,
        False,  # use_sliding_window
        False,  # use_logits_soft_cap
        q.device,
    )

    nnz_qo, num_qo_heads, head_dim_qk = q.shape
    _, num_kv_heads, head_dim_vo = v.shape

    # NOTE(review): 1 = causal, 0 = non-causal, presumably matching MaskMode values.
    mask_mode_code = 1 if causal else 0
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim_qk)

    qo_total_len = nnz_qo
    if max_qo_len is None:
        max_qo_len = torch.max(qo_segment_offsets[1:] - qo_segment_offsets[:-1]).item()

    if plan_info is None:
        plan_info = fmha_varlen_plan(
            module, qo_segment_offsets, kv_segment_offsets, num_qo_heads, causal
        )

    (
        work_indptr,
        qo_tile_indices,
        head_indices,
        batch_indices,
    ) = plan_info

    if out is None:
        # NOTE(review): the leading padding rows presumably give the kernel room to
        # write whole tiles before the first row — confirm against the kernel.
        pad_rows = max(max_qo_len, 128)
        out = torch.empty(
            qo_total_len + pad_rows,
            num_qo_heads,
            head_dim_vo,
            device=q.device,
            dtype=q.dtype,
        )[pad_rows:]

    if lse is None and return_lse:
        lse = torch.empty(
            qo_total_len, num_qo_heads, device=q.device, dtype=torch.float32
        )

    module.run(
        workspace_buffer,
        q,
        k,
        v,
        qo_segment_offsets,
        kv_segment_offsets,
        work_indptr,
        qo_tile_indices,
        head_indices,
        batch_indices,
        out,
        lse,
        mask_mode_code,
        sm_scale,
        num_qo_heads,
        num_kv_heads,
        head_dim_qk,
        head_dim_vo,
        max_qo_len,
    )

    return out, lse


@functools.cache
def get_trtllm_gen_fmha_module():
    """Build (once per process) and return the trtllm-gen FMHA extension module."""
    mod = gen_trtllm_gen_fmha_module()
    op = mod.build_and_load()
    # Point the cubin loader at this module's shared library before first use.
    setup_cubin_loader(mod.get_library_path())
    return op


def trtllm_ragged_attention_deepseek(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    workspace_buffer: torch.Tensor,
    seq_lens: torch.Tensor,
    max_q_len: int,
    max_kv_len: int,
    bmm1_scale: float,
    bmm2_scale: float,
    o_sf_scale: float,
    batch_size: int,
    window_left: int,
    cum_seq_lens_q: torch.Tensor,
    cum_seq_lens_kv: torch.Tensor,
    enable_pdl: Optional[bool],
    is_causal: bool,
    return_lse: bool,
    attention_sinks: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Parameters
    ----------
    query : torch.Tensor
        query tensor with shape [num_tokens, num_heads, head_dim]; head_dim must be 192
    key : torch.Tensor
        key tensor with shape [num_tokens, num_heads, head_dim]; head_dim must be 192
    value : torch.Tensor
        value tensor with shape [num_tokens, num_heads, head_dim]; head_dim must be 128
    workspace_buffer : torch.Tensor
        workspace buffer
    seq_lens : torch.Tensor
        sequence lengths
    max_q_len : int
        max query length
    max_kv_len : int
        max key/value length
    bmm1_scale : float
        scale for bmm1, scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5)
    bmm2_scale : float
        scale for bmm2, scale_v
    o_sf_scale : float
        scale for output
    batch_size : int
        batch size
    window_left : int
        window left
    cum_seq_lens_q : torch.Tensor
        cumulative sequence lengths for query
    cum_seq_lens_kv : torch.Tensor
        cumulative sequence lengths for key/value
    enable_pdl : Optional[bool]
        enable pdl; if ``None``, decided by :func:`device_support_pdl` for ``query.device``
    is_causal : bool
        is causal
    return_lse : bool
        whether to also return the lse tensor (allocated if ``lse`` is not provided)
    attention_sinks : Optional[torch.Tensor]
        attention sinks
    out : Optional[torch.Tensor]
        output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
    lse : Optional[torch.Tensor]
        lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]]

    Returns
    -------
    out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        output torch.Tensor or Tuple[torch.Tensor, torch.Tensor].
        If return_lse is True, the output will be a tuple of two tensors, the first is the output tensor, the second is the lse tensor.
        If return_lse is False, the output will be a single tensor.
    """
    assert query.shape[2] == 192 and key.shape[2] == 192 and value.shape[2] == 128, (
        "currently only support deepseek r1 192 query and 128 value"
    )

    if enable_pdl is None:
        enable_pdl = device_support_pdl(query.device)

    run_func = get_trtllm_gen_fmha_module().trtllm_ragged_attention
    sm_count = get_device_sm_count(query.device)
    num_tokens, num_heads = query.shape[0], query.shape[1]
    if out is None:
        out = torch.empty(
            num_tokens,
            num_heads,
            value.shape[2],
            device=query.device,
            dtype=query.dtype,
        )
    if return_lse and lse is None:
        lse = torch.empty(
            num_tokens,
            num_heads,
            device=query.device,
            dtype=torch.float32,
        )

    workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
    run_func(
        out,
        query,
        key,
        value,
        workspace_buffer,
        seq_lens,
        max_q_len,
        max_kv_len,
        bmm1_scale,
        bmm2_scale,
        o_sf_scale,
        batch_size,
        window_left,
        cum_seq_lens_q,
        cum_seq_lens_kv,
        sm_count,
        enable_pdl,
        is_causal,
        workspace_size,
        attention_sinks,
        lse,
    )
    return (out, lse) if return_lse else out


def trtllm_batch_context_with_kv_cache(
    query: torch.Tensor,
    kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    workspace_buffer: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    max_q_len: int,
    max_kv_len: int,
    bmm1_scale: float,
    bmm2_scale: float,
    batch_size: int,
    cum_seq_lens_q: torch.Tensor,
    cum_seq_lens_kv: torch.Tensor,
    window_left: int = -1,
    out: Optional[Union[torch.Tensor, FP4Tensor]] = None,
    out_dtype: Optional[Union[torch.dtype, str]] = None,
    o_sf_scale: Optional[float] = None,
    o_sf_vec_size: Optional[int] = None,
    enable_pdl: Optional[bool] = None,
    sinks: Optional[List[torch.Tensor]] = None,
) -> Union[torch.Tensor, FP4Tensor]:
    """
    Parameters
    ----------
    query : torch.Tensor
        query tensor with shape [num_tokens, num_heads, head_dim]
    kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim]
        If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim]
    workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
        workspace
    block_tables : torch.Tensor
        page_table of kv cache, [batch_size, num_pages]
    seq_lens : torch.Tensor
        A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
    max_q_len : int
        max sequence length for query
    max_kv_len : int
        max sequence length for kv_cache
    bmm1_scale : float
        fused scale for bmm1 input.
    bmm2_scale : float
        fused scale for bmm2 input.
    batch_size : int
        batch size
    cum_seq_lens_q : torch.Tensor
        cumulative sequence length for query. shape: ``[batch_size + 1]``
    cum_seq_lens_kv : torch.Tensor
        cumulative sequence length for kv_cache. shape: ``[batch_size + 1]``
    window_left : int = -1
        The left (inclusive) window size for the attention window, when set to ``-1``, the window
        size will be set to the full length of the sequence. Defaults to ``-1``.
    out : Optional[Union[torch.Tensor, FP4Tensor]] = None
        output tensor, if not provided, will be allocated with ``out_dtype``, if ``out_dtype`` is not provided, will use the type of ``query``.
        An :class:`FP4Tensor` is only accepted together with ``out_dtype`` ``"nvfp4"`` or ``None``.
    out_dtype : Optional[Union[torch.dtype, str]] = None
        output dtype, if not provided, will use the type of ``out``. For nvfp4, use string ``nvfp4``.
    o_sf_scale : Optional[float] = None
        scale for nvfp4 output tensor scale factor.
    o_sf_vec_size : Optional[int] = None
        vector size for nvfp4 output tensor scale factor.
    enable_pdl : Optional[bool] = None
        enable pdl; if ``None``, decided by :func:`device_support_pdl` for ``query.device``.
    sinks : Optional[List[torch.Tensor]] = None
        additional value per head in the denominator of the softmax.

    Returns
    -------
    out: Union[torch.Tensor, FP4Tensor]
        output torch.Tensor or FP4Tensor.

    Raises
    ------
    ValueError
        If ``out``/``out_dtype`` are invalid or inconsistent, or if the scale-factor
        start index does not fit in the scale-factor tensor.
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(query.device)

    if isinstance(kv_cache, tuple):
        k_cache, v_cache = kv_cache
    elif kv_cache.shape[1] == 1:
        # A single shared K/V plane.
        k_cache, v_cache = kv_cache, kv_cache
    else:
        assert kv_cache.shape[1] == 2, (
            "When kv_cache is a single tensor, the second dimension must be 1 or 2"
        )
        # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
        # it doesn't change underlying storage
        k_cache, v_cache = kv_cache.unbind(dim=1)

    run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context
    sm_count = get_device_sm_count(query.device)

    if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)):
        assert query.dtype == torch.float8_e4m3fn, (
            "query must be fp8 when out_dtype is nvfp4."
        )
        assert o_sf_scale is not None
        assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported"
        o_sf_vec_size = o_sf_vec_size or 16

        # Two fp4 values are packed per uint8 along the last dimension.
        fp4_out_shape = query.shape[:-1] + (ceil_div(query.shape[-1], 2),)
        sf_cols = round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4)

        if isinstance(out, FP4Tensor):
            fp4_out_scale_shape = (out.scale.shape[0], sf_cols)
            out_scale_factor = out.scale
            o_sf_start_index = out.scale_start_index
            out = out.data
            # out_dtype may be None
            out_dtype = out_dtype or "nvfp4"
        elif out is None:
            fp4_out_scale_shape = (round_up(query.shape[0], 128), sf_cols)
            out_scale_factor = torch.empty(
                fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
            )
            o_sf_start_index = 0
            out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
        else:
            raise ValueError(f"Invalid out: {out}")

        assert out_dtype == "nvfp4"
        assert isinstance(out, torch.Tensor)

        # Use uint8 as the container dtype to be compatible with the following fp4 GEMM.
        check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")

        check_shape_dtype_device(
            out_scale_factor,
            fp4_out_scale_shape,
            torch.float8_e4m3fn,
            query.device,
            "out_scale_factor",
        )

        # Check o_sf_start_index is valid
        if (
            o_sf_start_index < 0
            or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0]
        ):
            raise ValueError(
                f"o_sf_start_index is out of the valid range of out_scale_factor. "
                f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, "
                f"out_scale_factor.shape[0]={out_scale_factor.shape[0]}"
            )

    elif isinstance(out_dtype, torch.dtype) or out_dtype is None:
        assert o_sf_scale is None
        assert o_sf_vec_size is None
        if isinstance(out, FP4Tensor):
            # Reaching here means out_dtype is a torch.dtype; a packed fp4 buffer
            # cannot be validated or written as a dense tensor.
            raise ValueError(
                f"out is an FP4Tensor but out_dtype is {out_dtype}; use out_dtype='nvfp4'."
            )
        out_scale_factor = None
        o_sf_start_index = 0
        if out_dtype is None:
            out_dtype = out.dtype if out is not None else query.dtype
        if out is None:
            out = torch.empty_like(query, dtype=out_dtype)
        if out_dtype not in (query.dtype, torch.float16, torch.bfloat16):
            raise ValueError(f"Unsupported out_dtype: {out_dtype}")
        check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out")
    else:
        raise ValueError(f"Invalid out_dtype: {out_dtype}")

    workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
    run_func(
        out,
        out_scale_factor,
        query,
        k_cache,
        v_cache,
        workspace_buffer,
        block_tables,
        seq_lens,
        max_q_len,
        max_kv_len,
        bmm1_scale,
        bmm2_scale,
        o_sf_scale or -1.0,
        o_sf_vec_size or -1,
        o_sf_start_index,
        batch_size,
        window_left,
        cum_seq_lens_q,
        cum_seq_lens_kv,
        sm_count,
        enable_pdl,
        workspace_size,
        sinks,
    )
    return (
        out
        if out_dtype != "nvfp4"
        else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
    )