""" Copyright (c) 2024 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math from typing import Optional, Tuple, Union import torch from .decode import get_batch_decode_module from .page import block_sparse_indices_to_vector_sparse_offsets from .prefill import _compute_page_mask_indptr, get_batch_prefill_module from .quantization import segment_packbits from .utils import ( MaskMode, PosEncodingMode, TensorLayout, _check_pos_encoding_mode, check_shape_dtype_device, _get_cache_alibi_slopes_buf, canonicalize_torch_dtype, determine_attention_backend, device_support_pdl, is_float8, ) def convert_bsr_mask_layout(mask: torch.Tensor, indptr: torch.Tensor) -> torch.Tensor: r"""Convert mask from BSR data layout to flashinfer's flattened mask layout. Parameters ---------- mask : torch.Tensor A boolean mask tensor with shape ``(nnz, R, C)``. indptr : torch.Tensor The indptr tensor in BSR format. Returns ------- flattened_mask : torch.Tensor A flattenedd mask tensor with shape ``(nnz * R * C,)``. """ nnz, R, C = mask.shape MB = len(indptr) - 1 mask_flashinfer = torch.empty((nnz * R * C,), dtype=mask.dtype, device=mask.device) for i in range(MB): mask_flashinfer[indptr[i] * R * C : indptr[i + 1] * R * C] = ( mask[indptr[i] : indptr[i + 1]].transpose(0, 1).reshape(-1) ) return mask_flashinfer class BlockSparseAttentionWrapper: r"""Wrapper class for attention computation with a block-sparse matrix as attention mask. The definition of block sparse matrix can be found at `bsr_matrix `_ in SciPy. This API supports any block size ``(R, C)``. Example ------- >>> import torch >>> import flashinfer >>> num_qo_heads = 32 >>> num_kv_heads = 8 >>> head_dim = 128 >>> # allocate 128MB workspace buffer >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") >>> bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer) >>> # sparse mask: [[0, 0, 1], [1, 0, 1], [0, 1, 1]] >>> M = 3 >>> N = 3 >>> indptr = torch.tensor([0, 1, 3, 5], dtype=torch.int32, device="cuda:0") >>> indices = torch.tensor([2, 0, 2, 1, 2], dtype=torch.int32, device="cuda:0") >>> bsr_wrapper.plan( ... indptr, ... indices, ... M, ... N, ... 1, # R(block_rows)=1 ... 1, # C(block_columns)=1 ... num_qo_heads, ... num_kv_heads, ... head_dim, ... ) >>> q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device="cuda:0") >>> k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0") >>> v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0") >>> o = bsr_wrapper.run(q, k, v) >>> # use dense implementation with attention mask for comparison >>> mask = torch.tensor([[0, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=torch.bool, device="cuda:0") >>> o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask) >>> torch.allclose(o, o_ref) True """ def __init__( self, float_workspace_buffer: torch.Tensor, backend: str = "auto", ) -> None: r"""Constructs of :class:`BlockSparseAttentionWrapper`. 
Parameters ---------- float_workspace_buffer : torch.Tensor The user reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. backend : str The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. If set to ``auto``, the function will automatically choose the backend based on the device architecture and kernel availability. """ self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device self._workspace_size = ( float_workspace_buffer.numel() * float_workspace_buffer.element_size() ) self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) if backend in ["fa3", "auto"]: # NOTE(Zihao): assume maximum accumulate kv length is 128M # NOTE(Yilong): 128M is required by video DiT models self._vector_sparse_indices_buffer = torch.empty( (128 * 1024 * 1024,), dtype=torch.int32, device=self.device ) # NOTE(Zihao): assume maximum batch size is 32768 self._vector_sparse_indptr_buffer = torch.empty( (32768,), dtype=torch.int32, device=self.device ) self._kv_lens_buffer = torch.empty( (32768,), dtype=torch.int32, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True, device="cpu", ) self._use_cuda_graph = False self._kv_layout = "NHD" self._qo_indptr: Optional[torch.Tensor] = None self._paged_kv_indptr_buf: Optional[torch.Tensor] = None self._paged_kv_indices_buf: Optional[torch.Tensor] = None self._paged_kv_last_page_len: Optional[torch.Tensor] = None self._packed_mask_buf: Optional[torch.Tensor] = None self._mask_indptr_buf: Optional[torch.Tensor] = None self.R: Optional[int] = None self.C: Optional[int] = None self.M: Optional[int] = None self.N: Optional[int] = None self._backend = backend def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, vector_sparse_indices_buffer: Optional[torch.Tensor] = None, vector_sparse_indptr_buffer: Optional[torch.Tensor] = None, ) -> None: r"""Reset the workspace buffer. Parameters ---------- float_workspace_buffer : torch.Tensor The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors. int_workspace_buffer : torch.Tensor The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. 
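        vector_sparse_indices_buffer : Optional[torch.Tensor]
            The new buffer for the vector-sparse indices used by the FA3 backend.
            If not provided, the buffer allocated at construction time is kept.
        vector_sparse_indptr_buffer : Optional[torch.Tensor]
            The new buffer for the vector-sparse indptr used by the FA3 backend.
            If not provided, the buffer allocated at construction time is kept.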
""" self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer self._workspace_size = ( float_workspace_buffer.numel() * float_workspace_buffer.element_size() ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, pin_memory=True, ) # Enable user-defined size if vector_sparse_indices_buffer is not None: self._vector_sparse_indices_buffer = vector_sparse_indices_buffer if vector_sparse_indptr_buffer is not None: self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer def plan( self, indptr: torch.Tensor, indices: torch.Tensor, M: int, N: int, R: int, C: int, num_qo_heads: int, num_kv_heads: int, head_dim: int, mask: Optional[torch.Tensor] = None, packed_mask: Optional[torch.Tensor] = None, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, q_data_type: Union[str, torch.dtype] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, o_data_type: Union[str, torch.dtype] = "float16", non_blocking: bool = True, ) -> None: r"""Create auxiliary data structures for block sparse attention. Parameters ---------- indptr : torch.Tensor The block index pointer of the block-sparse matrix on row dimension, shape ``(MB + 1,)``, where ``MB`` is the number of blocks in the row dimension. indices: torch.Tensor The block indices of the block-sparse matrix on column dimension, shape ``(nnz,)``, where ``nnz`` is the number of non-zero blocks. The elements in ``indices`` array should be less then ``NB``: the number of blocks in the column dimension. M : int The number of rows of the block-sparse matrix, ``MB = ceil_div(M, R)``. N : int The number of columns of the block-sparse matrix, ``NB = N // C``, ``N`` should be divisible by ``C``. R : int The number of rows in each block. C : int The number of columns in each block. num_qo_heads : int The number of heads in the query/output tensor. num_kv_heads : int The number of heads in the key/value tensor. head_dim : int The dimension of each head. mask : torch.Tensor, optional The mask tensor with shape ``(nnz, R, C,)``, where nnz is the number of non-zero blocks. If every block is full, then we don't need to provide the mask tensor. packed_mask : torch.Tensor, optional The 1D packed mask tensor, if provided, the :attr:`custom_mask` will be ignored. The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. causal : bool Whether to apply causal mask to the attention matrix. This is only effective when :attr:`custom_mask` is not provided in :meth:`plan`. pos_encoding_mode : str, optional The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. Default is ``NONE``. use_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. 
        rope_scale : Optional[float]
            The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
        rope_theta : Optional[float]
            The theta used in RoPE, if not provided, will be set to ``1e4``.
        q_data_type : str, optional
            The data type of the query tensor.
        kv_data_type : Optional[Union[str, torch.dtype]]
            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
        o_data_type : str, optional
            The data type of the output tensor. Default is ``half``. This must be specified
            explicitly because the output dtype cannot be inferred from the input dtype when
            the inputs are quantized.
        non_blocking : bool
            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.

        The :meth:`plan` method should be called before any :meth:`run` or
        :meth:`run_return_lse` calls; auxiliary data structures are created during this call
        and cached for multiple kernel runs.

        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
        not equal to ``num_kv_heads``, the function will use
        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
        """
        q_data_type = canonicalize_torch_dtype(q_data_type)
        if kv_data_type is None:
            kv_data_type = q_data_type
        kv_data_type = canonicalize_torch_dtype(kv_data_type)
        self._o_dtype = canonicalize_torch_dtype(o_data_type)

        if logits_soft_cap is None:
            logits_soft_cap = 0.0

        num_blocks_row = len(indptr) - 1
        qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
        qo_indptr_host[-1] = M
        qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=non_blocking)
        if indices.max().item() * C > N:
            raise ValueError("indices out of bound")
        last_block_len = torch.full(
            (num_blocks_row,), C, dtype=torch.int32, device=indptr.device
        )

        if mask is not None or packed_mask is not None:
            mask_indptr = _compute_page_mask_indptr(
                qo_indptr,
                indptr,  # paged_kv_indptr
                last_block_len,  # paged_kv_last_page_len
                C,  # page_size
            )
        if packed_mask is None and mask is not None:
            # first convert BSR mask to flashinfer layout
            mask = convert_bsr_mask_layout(mask, indptr)
            # create packed mask from mask
            packed_mask, mask_indptr = segment_packbits(
                mask.contiguous().view(-1), mask_indptr, bitorder="little"
            )

        self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking)
        self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=non_blocking)
        self._paged_kv_indices_buf = indices.to(self.device, non_blocking=non_blocking)
        self._paged_kv_last_page_len = last_block_len.to(
            self.device, non_blocking=non_blocking
        )
        if packed_mask is not None:
            self._packed_mask_buf = packed_mask.to(
                self.device, non_blocking=non_blocking
            )
            self._mask_indptr_buf = mask_indptr.to(
                self.device, non_blocking=non_blocking
            )
            mask_mode = MaskMode.CUSTOM.value
        else:
            self._packed_mask_buf = None
            self._mask_indptr_buf = None
            mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value
        self._mask_mode = mask_mode

        self.M = M
        self.N = N
        self.R = R
        self.C = C

        kv_indptr_host = indptr.to("cpu")

        # NOTE(Zihao): masks are not yet supported in the cuda-core implementation, but it
        # should be easy to add if needed; left as future work.
# at this moment, when mask is provided, we use the tensor-core implementation if ( R * (num_qo_heads // num_kv_heads) < 4 and mask_mode != MaskMode.CUSTOM.value and q_data_type not in [torch.float8_e4m3fn, torch.float8_e5m2] ): # If the operation is not compute-bound, we use the cuda-core implementation self._use_tensor_cores = False self._cached_module = get_batch_decode_module( q_data_type, kv_data_type, self._o_dtype, indptr.dtype, head_dim, head_dim, PosEncodingMode[pos_encoding_mode].value, False, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap ) self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, kv_indptr_host, num_blocks_row, num_qo_heads, num_kv_heads, C, False, # is_cuda_graph_enabled -1, # window_left logits_soft_cap, # logits_soft_cap head_dim, head_dim, torch.empty(0, dtype=q_data_type), torch.empty(0, dtype=kv_data_type), ) else: # if the operation is compute-bound, we use the tensor-core implementation self._use_tensor_cores = True if self._backend == "auto": self._backend = determine_attention_backend( self.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, mask_mode == MaskMode.CUSTOM.value, # use_custom_mask q_data_type, kv_data_type, ) get_module_args = ( q_data_type, kv_data_type, self._o_dtype, indptr.dtype, head_dim, # head_dim_qk head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, False, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap use_fp16_qk_reduction, ) self._cached_module = get_batch_prefill_module( self._backend, *get_module_args ) kv_lens_arr_host = (kv_indptr_host[1:] - kv_indptr_host[:-1]) * self.C self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( kv_lens_arr_host, ) if self._backend == "fa3": if self.C != 1: vector_sparse_indptr_host = torch.cat( [ torch.tensor([0], dtype=torch.int32), torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), ], dim=0, ) self._vector_sparse_indptr_buffer[ : len(vector_sparse_indptr_host) ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) kv_indptr_host = vector_sparse_indptr_host self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, qo_indptr_host, kv_indptr_host, kv_lens_arr_host, M, # total_num_rows num_blocks_row, # batch_size num_qo_heads, num_kv_heads, self.C, # page_size False, # is_cuda_graph_enabled, head_dim, head_dim, causal, ) self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta begin_forward = plan def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale_q: Optional[torch.Tensor] = None, scale_k: Optional[torch.Tensor] = None, scale_v: Optional[torch.Tensor] = None, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ) -> torch.Tensor: r"""Warning: This method is deprecated, please use :meth:`run` instead.""" self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta return self.run(q, k, v, scale_q, scale_k, scale_v) def run( self, q: torch.Tensor, 
k: torch.Tensor, v: torch.Tensor, scale_q: Optional[torch.Tensor] = None, scale_k: Optional[torch.Tensor] = None, scale_v: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute block-sparse attention between Q/K/V tensors. Parameters ---------- q : torch.Tensor The query tensor with shape ``(M, num_qo_heads, head_dim)``. k : torch.Tensor The key tensor with shape ``(N, num_kv_heads, head_dim)``. v : torch.Tensor The value tensor with shape ``(N, num_kv_heads, head_dim)``. scale_q : Optional[torch.Tensor] The scale tensor for query, per-head quantization with shape: ``[num_qo_heads]``. Used with FP8 Quantization. If not provided, will be set to ``1.0``. scale_k : Optional[torch.Tensor] The scale tensor for key, per-head quantization with shape: ``[num_kv_heads]``. Used with FP8 Quantization. If not provided, will be set to ``1.0``. scale_v : Optional[torch.Tensor] The scale tensor for value, per-head quantization with shape: ``[num_kv_heads]``. Used with FP8 Quantization. If not provided, will be set to ``1.0``. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. lse : Optional[torch.Tensor] The log-sum-exp of attention logits, if not provided, will be allocated internally. return_lse : bool Whether to return the log-sum-exp of attention logits enable_pdl : bool Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode. Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] If :attr:`return_lse` is ``False``, the attention output, shape: ``[M, num_qo_heads, head_dim]``. If :attr:`return_lse` is ``True``, a tuple of two tensors: * The attention output, shape: ``[M, num_qo_heads, head_dim]``. * The logsumexp of attention output, shape: ``[M, num_qo_heads]``. 
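
        Example
        -------
        An illustrative sketch, assuming :meth:`plan` has already been called as in the
        class-level example above (so ``q``/``k``/``v`` have the shapes shown there):

        >>> o = bsr_wrapper.run(q, k, v)
        >>> o, lse = bsr_wrapper.run(q, k, v, return_lse=True)
        >>> lse.shape  # one logsumexp value per (row, qo_head)
        torch.Size([3, 32])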
""" if enable_pdl is None: enable_pdl = device_support_pdl(q.device) pos_encoding_mode = self._pos_encoding_mode logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale rope_scale = self._rope_scale rope_theta = self._rope_theta _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 k = k.reshape(-1, self.C, *k.shape[-2:]) v = v.reshape(-1, self.C, *v.shape[-2:]) stride_block = k.stride(0) stride_n = k.stride(1) if return_lse: if lse is None: lse = torch.empty( (q.size(0), q.size(1)), dtype=torch.float32, device=q.device ) else: check_shape_dtype_device( lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" ) if out is None: out = torch.empty_like(q, dtype=self._o_dtype) else: check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out") if is_float8(q): assert q.dtype == k.dtype == v.dtype assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert self._backend == "fa3" and self._use_tensor_cores if scale_q is None: scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device) if scale_k is None: scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device) if scale_v is None: scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device) if self._use_tensor_cores: if self._backend == "fa3": if ( self._vector_sparse_indices_buffer.numel() <= self._paged_kv_indices_buf.numel() * self.C ): raise ValueError( "_vector_sparse_indices_buffer is not large enough. Please increase the size." ) sparse_indices = block_sparse_indices_to_vector_sparse_offsets( self._paged_kv_indices_buf, self._paged_kv_indptr_buf, self._vector_sparse_indices_buffer, # output self._vector_sparse_indptr_buffer, self._kv_lens_buffer, stride_block // stride_n, 1, # stride_n // stride_n self.C, # block_size ) sparse_indptr = self._vector_sparse_indptr_buffer else: sparse_indices = self._paged_kv_indices_buf sparse_indptr = self._paged_kv_indptr_buf self._cached_module.paged_run( self._float_workspace_buffer, self._int_workspace_buffer, self._plan_info, q, k, v, self._qo_indptr, sparse_indptr, sparse_indices, self._paged_kv_last_page_len, out, lse, self._mask_mode, TensorLayout[self._kv_layout].value, -1, # window_left enable_pdl, # ADDITIONAL_FUNC_PARAMS self._packed_mask_buf, self._mask_indptr_buf, _get_cache_alibi_slopes_buf(q.shape[1], self.device), None, # maybe_prefix_len_ptr None, # maybe_token_pos_in_items_ptr None, # maybe_max_item_len_ptr logits_soft_cap, sm_scale, scale_q, scale_k, scale_v, rope_scale, rope_theta, 0, # token_pos_in_items_len self._workspace_size, # workspace_size ) else: self._cached_module.run( self._float_workspace_buffer, self._int_workspace_buffer, self._plan_info, q, k, v, self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len, out, lse, TensorLayout[self._kv_layout].value, -1, # window_left enable_pdl, # ADDITIONAL_FUNC_PARAMS _get_cache_alibi_slopes_buf(q.shape[1], self.device), logits_soft_cap, sm_scale, rope_scale, rope_theta, ) return (out, lse) if return_lse else out def end_forward(self) -> None: r"""Warning: This method is deprecated and has no effect.""" pass class VariableBlockSparseAttentionWrapper: r"""Wrapper class for attention computation with a block-sparse matrix as attention mask. This API supports variable block sizes provided by ``block_row_sz`` and ``block_col_sz``. 
    In addition, each ``kv_head_idx`` can specify its own sparse pattern rather than
    sharing a single mask.

    Example
    -------
    >>> import torch
    >>> import flashinfer
    >>> num_qo_heads = 1
    >>> num_kv_heads = 1
    >>> head_dim = 128
    >>> seq_len = 6  # This corresponds to the `block_row_sz` and `block_col_sz`
    >>> # allocate 128MB workspace buffer
    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    >>> wrapper = flashinfer.VariableBlockSparseAttentionWrapper(workspace_buffer)
    >>> block_mask_map = torch.tensor([[[0, 0, 1], [1, 0, 1], [0, 1, 1]]], dtype=torch.bool, device="cuda:0")
    >>> block_row_sz = torch.tensor([[1, 2, 3]], dtype=torch.int32, device="cuda:0")
    >>> block_col_sz = torch.tensor([[3, 1, 2]], dtype=torch.int32, device="cuda:0")
    >>> wrapper.plan(
    ...     block_mask_map,
    ...     block_row_sz,
    ...     block_col_sz,
    ...     num_qo_heads,
    ...     num_kv_heads,
    ...     head_dim,
    ... )
    >>> q = torch.randn((num_qo_heads, seq_len, head_dim), dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn((num_kv_heads, seq_len, head_dim), dtype=torch.float16, device="cuda:0")
    >>> v = torch.randn((num_kv_heads, seq_len, head_dim), dtype=torch.float16, device="cuda:0")
    >>> o = wrapper.run(q, k, v)
    """

    def __init__(
        self,
        float_workspace_buffer: torch.Tensor,
        backend: str = "auto",
    ) -> None:
        r"""Constructor of :class:`VariableBlockSparseAttentionWrapper`.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The user reserved float workspace buffer used to store intermediate attention
            results in the split-k algorithm. The recommended size is 128MB; the device of
            the workspace buffer should be the same as the device of the input tensors.
        backend : str
            The implementation backend, could be ``auto``, ``fa2``, or ``fa3``. Defaults to
            ``auto``. If set to ``auto``, the wrapper will automatically choose the backend
            based on the device architecture and kernel availability.
        """
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        self._workspace_size = (
            float_workspace_buffer.numel() * float_workspace_buffer.element_size()
        )
        self._int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
        )

        if backend in ["fa3", "auto"]:
            self._vector_sparse_indices_buffer = torch.empty(
                (128 * 1024 * 1024,), dtype=torch.int32, device=self.device
            )
            self._vector_sparse_indptr_buffer = torch.empty(
                (32768,), dtype=torch.int32, device=self.device
            )

        self._kv_lens_buffer = torch.empty(
            (32768,), dtype=torch.int32, device=self.device
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=torch.uint8,
            pin_memory=True,
            device="cpu",
        )
        self._use_cuda_graph = False
        self._kv_layout = "NHD"
        self._qo_indptr: Optional[torch.Tensor] = None
        self._paged_kv_indptr_buf: Optional[torch.Tensor] = None
        self._paged_kv_indices_buf: Optional[torch.Tensor] = None
        self._paged_kv_last_page_len: Optional[torch.Tensor] = None
        self._backend = backend

    def reset_workspace_buffer(
        self,
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        vector_sparse_indices_buffer: Optional[torch.Tensor] = None,
        vector_sparse_indptr_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        r"""Reset the workspace buffer.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The new float workspace buffer; the device of the new float workspace buffer
            should be the same as the device of the input tensors.
        int_workspace_buffer : torch.Tensor
            The new int workspace buffer; the device of the new int workspace buffer should
            be the same as the device of the input tensors.
        vector_sparse_indices_buffer : Optional[torch.Tensor]
            The new buffer for the vector-sparse indices used by the FA3 backend.
            If not provided, the buffer allocated at construction time is kept.
        vector_sparse_indptr_buffer : Optional[torch.Tensor]
            The new buffer for the vector-sparse indptr used by the FA3 backend.
            If not provided, the buffer allocated at construction time is kept.
        """
        self._float_workspace_buffer = float_workspace_buffer
        self._int_workspace_buffer = int_workspace_buffer
        self._workspace_size = (
            float_workspace_buffer.numel() * float_workspace_buffer.element_size()
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=self._int_workspace_buffer.dtype,
            pin_memory=True,
        )
        # Enable user-defined size
        if vector_sparse_indices_buffer is not None:
            self._vector_sparse_indices_buffer = vector_sparse_indices_buffer
        if vector_sparse_indptr_buffer is not None:
            self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer

    def plan(
        self,
        block_mask_map: torch.Tensor,
        block_row_sz: torch.Tensor,
        block_col_sz: torch.Tensor,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        causal: bool = False,
        pos_encoding_mode: str = "NONE",
        use_fp16_qk_reduction: bool = False,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
        non_blocking: bool = True,
        q_data_type: Union[str, torch.dtype] = "float16",
        kv_data_type: Optional[Union[str, torch.dtype]] = None,
    ) -> None:
        r"""Create auxiliary data structures for block sparse attention.

        Parameters
        ----------
        block_mask_map : torch.Tensor
            The block mask map (boolean), shape ``(num_kv_heads, MB, NB)``, where ``MB`` is
            the number of blocks in the row dimension and ``NB`` is the number of blocks in
            the column dimension.
        block_row_sz : torch.Tensor
            The block row sizes, shape ``(num_kv_heads, MB)``.
        block_col_sz : torch.Tensor
            The block column sizes, shape ``(num_kv_heads, NB)``.
        num_qo_heads : int
            The number of heads in the query/output tensor.
        num_kv_heads : int
            The number of heads in the key/value tensor. Note that each group of ``qo_heads``
            shares the sparse pattern of its corresponding ``kv_head``.
        head_dim : int
            The dimension of each head.
        causal : bool
            Whether to apply a causal mask to the attention matrix.
        pos_encoding_mode : str, optional
            The position encoding applied inside attention kernels, could be
            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding)/``ALIBI``.
            Default is ``NONE``.
        use_fp16_qk_reduction : bool
            Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
        logits_soft_cap : Optional[float]
            The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.),
            if not provided, will be set to ``0``. If greater than 0, the logits will be capped
            according to the formula:
            :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
            where :math:`x` is the input logits.
        sm_scale : Optional[float]
            The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
        rope_scale : Optional[float]
            The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
        rope_theta : Optional[float]
            The theta used in RoPE, if not provided, will be set to ``1e4``.
        non_blocking : bool
            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
        q_data_type : Union[str, torch.dtype], optional
            The data type of the query tensor, defaults to ``float16``.
        kv_data_type : Optional[Union[str, torch.dtype]]
            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.

        The :meth:`plan` method should be called before any :meth:`run` or
        :meth:`run_return_lse` calls; auxiliary data structures are created during this call
        and cached for multiple kernel runs.

        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
        not equal to ``num_kv_heads``, the function will use
        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
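
        Example
        -------
        A small sketch reusing the tensors from the class-level example: per head, the row
        block sizes must sum to the query length and the column block sizes to the key/value
        length covered by ``block_mask_map``.

        >>> int(block_row_sz.sum()), int(block_col_sz.sum())
        (6, 6)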
""" q_data_type = canonicalize_torch_dtype(q_data_type) if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) self._o_dtype = q_data_type if logits_soft_cap is None: logits_soft_cap = 0.0 # num_blocks are constant across kv_heads num_blocks_row = block_row_sz.shape[-1] num_blocks_col = block_col_sz.shape[-1] # q layout: [seq_len, num_kv_heads, gqa_group_size, head_dim] # padded into: [seq_len * num_kv_heads, 1, gqa_group_size, head_dim] qo_indptr = torch.cat( [ torch.zeros(1, dtype=torch.int32, device=block_row_sz.device), torch.cumsum(block_row_sz.flatten(), dim=0, dtype=torch.int32), ], dim=0, ) qo_indptr_host = qo_indptr.to("cpu", non_blocking=non_blocking) last_block_len = torch.full( (num_blocks_row * num_kv_heads,), 1, dtype=torch.int32, device=block_mask_map.device, ) # We use page_size == 1 for variable length support # HND kv layout: [num_kv_heads, num_blocks, block_size, head_dim] # padded into: [num_kv_heads * num_blocks, block_size, 1, head_dim] # for customized attention mask for each kv_head # NOTE(Yilong): This could be perf bottleneck. Consider Triton implementation. def _block_mask_map_to_expanded_indices( block_mask_map: torch.Tensor, # [H, R, C] bool / {0,1} block_col_sz: torch.Tensor, # [H, C] int ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: block_mask_map: bool/int [num_kv_heads, num_blocks_row, num_blocks_col] block_col_sz: int32/64 [num_kv_heads, num_blocks_col] Returns: kv_indptr: [H*R + 1] int32 — CSR indptr kv_indices: [nnz] int32 — token indices per (head, row) """ device = block_mask_map.device dtype_i = torch.int32 # 1) Calculate the total length of each row (head, row) row_lengths = (block_mask_map * block_col_sz[:, None, :]).sum(-1) # [H,R] kv_indptr = torch.cat( [ torch.zeros(1, dtype=dtype_i, device=device), torch.cumsum(row_lengths.flatten(), 0), ], dim=0, ) # 2) Calculate the offset of each column block within its head col_offset = ( torch.cumsum(block_col_sz.to(dtype_i), 1) - block_col_sz ) # [H,C] head_len = block_col_sz.sum(1, dtype=dtype_i) head_offset = torch.cumsum(head_len, 0) - head_len # 3) Find all selected (h,r,c) h_idx, r_idx, c_idx = block_mask_map.nonzero(as_tuple=True) lengths = block_col_sz[h_idx, c_idx].to(dtype_i) # [N] base = head_offset[h_idx] + col_offset[h_idx, c_idx] # [N] # 4) Expand variable-length column blocks into token-level indices cum = torch.cumsum(lengths, 0) starts = torch.repeat_interleave(cum - lengths, lengths) # [total] offsets_within = torch.arange(cum[-1], device=device) - starts kv_indices = torch.repeat_interleave(base, lengths) + offsets_within return kv_indptr.to(dtype=dtype_i, device=device), kv_indices.to( dtype=dtype_i, device=device ) kv_indptr, kv_indices = _block_mask_map_to_expanded_indices( block_mask_map, block_col_sz ) kv_indptr_host = kv_indptr.to("cpu", non_blocking=non_blocking) kv_indices_host = kv_indices.to("cpu", non_blocking=non_blocking) self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking) self._paged_kv_indptr_buf = kv_indptr.to(self.device, non_blocking=non_blocking) self._paged_kv_indices_buf = kv_indices.to( self.device, non_blocking=non_blocking ) self._paged_kv_last_page_len = last_block_len.to( self.device, non_blocking=non_blocking ) torch.cuda.synchronize() # for non-blocking copy self._mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value # Sanity check assert num_qo_heads % num_kv_heads == 0, ( "num_qo_heads must be a multiple of num_kv_heads" ) assert num_blocks_row * num_kv_heads + 1 == 
kv_indptr_host.shape[0] assert kv_indptr_host[-1].item() == kv_indices_host.shape[0], ( f"{kv_indptr_host[-1].item()} != {kv_indices_host.shape[0]}" ) assert num_kv_heads == block_mask_map.shape[0] assert num_kv_heads == block_row_sz.shape[0] assert num_kv_heads == block_col_sz.shape[0] assert num_blocks_row == block_mask_map.shape[1] assert num_blocks_col == block_mask_map.shape[2] if self._backend == "auto": self._backend = determine_attention_backend( self.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, self._mask_mode == MaskMode.CUSTOM.value, # use_custom_mask q_data_type, kv_data_type, ) get_module_args = ( q_data_type, kv_data_type, self._o_dtype, kv_indptr_host.dtype, head_dim, # head_dim_qk head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, False, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap use_fp16_qk_reduction, ) self._cached_module = get_batch_prefill_module(self._backend, *get_module_args) kv_lens_arr_host = kv_indptr_host[1:] - kv_indptr_host[:-1] # page_size == 1 self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( kv_lens_arr_host, ) if self._backend == "fa3": if self._vector_sparse_indptr_buffer.numel() <= kv_indptr.numel(): raise ValueError( "_vector_sparse_indptr_buffer is not large enough. Please increase the buffer size." ) self._vector_sparse_indptr_buffer[: len(kv_indptr)].copy_( kv_indptr, non_blocking=non_blocking ) self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, qo_indptr_host, kv_indptr_host, kv_lens_arr_host, qo_indptr_host[-1].item(), # total_num_rows num_blocks_row * num_kv_heads, # batch_size num_qo_heads // num_kv_heads, # num_qo_heads (gqa_group_size) 1, # num_kv_heads, 1, # page_size False, # is_cuda_graph_enabled, head_dim, head_dim, causal, ) self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta self._num_kv_heads = num_kv_heads self._gqa_group_size = num_qo_heads // num_kv_heads def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ) -> torch.Tensor: r"""Warning: This method is deprecated, please use :meth:`run` instead.""" self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta return self.run(q, k, v) def run( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute block-sparse attention between Q/K/V tensors. Parameters ---------- q : torch.Tensor The query tensor with shape ``(num_qo_heads, qo_len, head_dim)``. k : torch.Tensor The key tensor with shape ``(num_kv_heads, kv_len, head_dim)``. v : torch.Tensor The value tensor with shape ``(num_kv_heads, kv_len, head_dim)``. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. 
lse : Optional[torch.Tensor] The log-sum-exp of attention logits, if not provided, will be allocated internally. return_lse : bool Whether to return the log-sum-exp of attention logits enable_pdl : bool Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode. Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] If :attr:`return_lse` is ``False``, the attention output, shape: ``[M, num_qo_heads, head_dim]``. If :attr:`return_lse` is ``True``, a tuple of two tensors: * The attention output, shape: ``[M, num_qo_heads, head_dim]``. * The logsumexp of attention output, shape: ``[M, num_qo_heads]``. """ # NOTE(Zihao): defer import of einops import einops if enable_pdl is None: enable_pdl = device_support_pdl(q.device) pos_encoding_mode = self._pos_encoding_mode logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale rope_scale = self._rope_scale rope_theta = self._rope_theta _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 # reshape to pad num_kv_heads into seq_len # input [num_qo_heads, qo_len, head_dim] # kernel layout is NHD -> [qo_len * num_kv_heads, gqa_group_size, head_dim] q = einops.rearrange( q, "(num_kv_heads gqa_group_size) qo_len head_dim -> (num_kv_heads qo_len) gqa_group_size head_dim", num_kv_heads=self._num_kv_heads, ).contiguous() # HND -> [kv_len * num_kv_heads (num_pages), 1 (page_size), 1 (new_num_kv_heads), head_dim] k = einops.rearrange( k, "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim", ).contiguous() v = einops.rearrange( v, "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim", ).contiguous() stride_block = k.stride(0) stride_n = k.stride(1) if return_lse: if lse is None: lse = torch.empty( (q.size(0), q.size(1)), dtype=torch.float32, device=q.device ) else: check_shape_dtype_device( lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" ) if out is None: out = torch.empty_like(q, dtype=self._o_dtype) else: check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out") if self._backend == "fa3": if ( self._vector_sparse_indices_buffer.numel() <= self._paged_kv_indices_buf.numel() ): raise ValueError( "_vector_sparse_indices_buffer is not large enough. Please increase the buffer size." 
) sparse_indices = block_sparse_indices_to_vector_sparse_offsets( self._paged_kv_indices_buf, self._paged_kv_indptr_buf, self._vector_sparse_indices_buffer, # output self._vector_sparse_indptr_buffer, self._kv_lens_buffer, stride_block // stride_n, 1, # stride_n // stride_n 1, # block_size ) sparse_indptr = self._vector_sparse_indptr_buffer else: sparse_indices = self._paged_kv_indices_buf sparse_indptr = self._paged_kv_indptr_buf self._cached_module.paged_run( self._float_workspace_buffer, self._int_workspace_buffer, self._plan_info, q, k, v, self._qo_indptr, sparse_indptr, sparse_indices, self._paged_kv_last_page_len, out, lse, self._mask_mode, TensorLayout[self._kv_layout].value, -1, # window_left enable_pdl, # ADDITIONAL_FUNC_PARAMS # Not supported yet None, # packed_mask_buf None, # mask_indptr_buf None, # alibi_slopes_buf None, None, None, logits_soft_cap, sm_scale, None, # scale_q None, # scale_k None, # scale_v rope_scale, rope_theta, 0, # token_pos_in_items_len self._workspace_size, ) # [qo_len * num_kv_heads, gqa_group_size, head_dim] -> HND out = einops.rearrange( out, "(num_kv_heads qo_len) gqa_group_size head_dim -> (num_kv_heads gqa_group_size) qo_len head_dim", num_kv_heads=self._num_kv_heads, ).contiguous() if return_lse: lse = einops.rearrange( lse, "(num_kv_heads qo_len) gqa_group_size -> (num_kv_heads gqa_group_size) qo_len", num_kv_heads=self._num_kv_heads, ).contiguous() return (out, lse) if return_lse else out
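

if __name__ == "__main__":
    # Illustrative sketch only (not part of the FlashInfer API): derive the ``indptr`` /
    # ``indices`` arrays expected by ``BlockSparseAttentionWrapper.plan`` from a dense
    # boolean block mask. Uses NumPy only; the values match the class-level example above.
    import numpy as np

    block_mask = np.array([[0, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=bool)  # (MB, NB)
    # CSR-style indptr over block rows: indptr[i + 1] - indptr[i] is the number of
    # non-zero blocks in block-row i.
    indptr = np.concatenate(([0], np.cumsum(block_mask.sum(axis=1)))).astype(np.int32)
    # Column indices of the non-zero blocks, listed row by row.
    indices = np.concatenate([np.nonzero(row)[0] for row in block_mask]).astype(np.int32)
    print(indptr)  # [0 1 3 5]
    print(indices)  # [2 0 2 1 2]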