"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import math
from typing import Optional, Tuple, Union

import torch

from .decode import get_batch_decode_module
from .page import block_sparse_indices_to_vector_sparse_offsets
from .prefill import _compute_page_mask_indptr, get_batch_prefill_module
from .quantization import segment_packbits
from .utils import (
    MaskMode,
    PosEncodingMode,
    TensorLayout,
    _check_pos_encoding_mode,
    check_shape_dtype_device,
    _get_cache_alibi_slopes_buf,
    canonicalize_torch_dtype,
    determine_attention_backend,
    device_support_pdl,
    is_float8,
)


def convert_bsr_mask_layout(mask: torch.Tensor, indptr: torch.Tensor) -> torch.Tensor:
    r"""Convert a mask from the BSR data layout to flashinfer's flattened mask layout.

    Parameters
    ----------
    mask : torch.Tensor
        A boolean mask tensor with shape ``(nnz, R, C)``.
    indptr : torch.Tensor
        The indptr tensor in BSR format.

    Returns
    -------
    flattened_mask : torch.Tensor
        A flattened mask tensor with shape ``(nnz * R * C,)``.
    """
    nnz, R, C = mask.shape
    MB = len(indptr) - 1
    mask_flashinfer = torch.empty((nnz * R * C,), dtype=mask.dtype, device=mask.device)
    for i in range(MB):
        mask_flashinfer[indptr[i] * R * C : indptr[i + 1] * R * C] = (
            mask[indptr[i] : indptr[i + 1]].transpose(0, 1).reshape(-1)
        )
    return mask_flashinfer
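

# --- Editor's illustrative sketch (hypothetical helper, not part of the original module) ---
# A minimal example of what `convert_bsr_mask_layout` does: within each row-block, the
# per-block (R, C) masks are transposed so that all key columns visible to one query row
# become contiguous before flattening, which matches flashinfer's ragged mask layout.
def _example_convert_bsr_mask_layout() -> None:  # pragma: no cover
    # Two row-blocks (MB = 2) with R = C = 2; row-block 0 owns 1 block, row-block 1 owns 2.
    indptr = torch.tensor([0, 1, 3], dtype=torch.int32)
    mask = torch.ones((3, 2, 2), dtype=torch.bool)
    flattened = convert_bsr_mask_layout(mask, indptr)
    assert flattened.shape == (3 * 2 * 2,)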


class BlockSparseAttentionWrapper:
    r"""Wrapper class for attention computation with a block-sparse matrix as attention mask.
    The definition of block sparse matrix can be found at
    `bsr_matrix <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.bsr_matrix.html>`_
    in SciPy.

    This API supports any block size ``(R, C)``.

    Example
    -------
    >>> import torch
    >>> import flashinfer
    >>> num_qo_heads = 32
    >>> num_kv_heads = 8
    >>> head_dim = 128
    >>> # allocate 128MB workspace buffer
    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    >>> bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer)
    >>> # sparse mask: [[0, 0, 1], [1, 0, 1], [0, 1, 1]]
    >>> M = 3
    >>> N = 3
    >>> indptr = torch.tensor([0, 1, 3, 5], dtype=torch.int32, device="cuda:0")
    >>> indices = torch.tensor([2, 0, 2, 1, 2], dtype=torch.int32, device="cuda:0")
    >>> bsr_wrapper.plan(
    ...     indptr,
    ...     indices,
    ...     M,
    ...     N,
    ...     1,  # R(block_rows)=1
    ...     1,  # C(block_columns)=1
    ...     num_qo_heads,
    ...     num_kv_heads,
    ...     head_dim,
    ... )
    >>> q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0")
    >>> v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0")
    >>> o = bsr_wrapper.run(q, k, v)
    >>> # use dense implementation with attention mask for comparison
    >>> mask = torch.tensor([[0, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=torch.bool, device="cuda:0")
    >>> o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
    >>> torch.allclose(o, o_ref)
    True
    """

    def __init__(
        self,
        float_workspace_buffer: torch.Tensor,
        backend: str = "auto",
    ) -> None:
        r"""Constructor of :class:`BlockSparseAttentionWrapper`.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The user reserved float workspace buffer used to store intermediate attention results
            in the split-k algorithm. The recommended size is 128MB; the device of the workspace
            buffer should be the same as the device of the input tensors.
        backend : str
            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
            If set to ``auto``, the function will automatically choose the backend based on the
            device architecture and kernel availability.
        """
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        self._workspace_size = (
            float_workspace_buffer.numel() * float_workspace_buffer.element_size()
        )
        self._int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
        )
        if backend in ["fa3", "auto"]:
            # NOTE(Zihao): assume maximum accumulate kv length is 128M
            # NOTE(Yilong): 128M is required by video DiT models
            self._vector_sparse_indices_buffer = torch.empty(
                (128 * 1024 * 1024,), dtype=torch.int32, device=self.device
            )
            # NOTE(Zihao): assume maximum batch size is 32768
            self._vector_sparse_indptr_buffer = torch.empty(
                (32768,), dtype=torch.int32, device=self.device
            )

        self._kv_lens_buffer = torch.empty(
            (32768,), dtype=torch.int32, device=self.device
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=torch.uint8,
            pin_memory=True,
            device="cpu",
        )
        self._use_cuda_graph = False
        self._kv_layout = "NHD"
        self._qo_indptr: Optional[torch.Tensor] = None
        self._paged_kv_indptr_buf: Optional[torch.Tensor] = None
        self._paged_kv_indices_buf: Optional[torch.Tensor] = None
        self._paged_kv_last_page_len: Optional[torch.Tensor] = None
        self._packed_mask_buf: Optional[torch.Tensor] = None
        self._mask_indptr_buf: Optional[torch.Tensor] = None
        self.R: Optional[int] = None
        self.C: Optional[int] = None
        self.M: Optional[int] = None
        self.N: Optional[int] = None
        self._backend = backend

    def reset_workspace_buffer(
        self,
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        vector_sparse_indices_buffer: Optional[torch.Tensor] = None,
        vector_sparse_indptr_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        r"""Reset the workspace buffer.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The new float workspace buffer; its device should be the same as the device of the
            input tensors.
        int_workspace_buffer : torch.Tensor
            The new int workspace buffer; its device should be the same as the device of the
            input tensors.
        vector_sparse_indices_buffer : Optional[torch.Tensor]
            If provided, replaces the internal vector-sparse indices buffer, allowing a
            user-defined size.
        vector_sparse_indptr_buffer : Optional[torch.Tensor]
            If provided, replaces the internal vector-sparse indptr buffer, allowing a
            user-defined size.
        """
        self._float_workspace_buffer = float_workspace_buffer
        self._int_workspace_buffer = int_workspace_buffer
        self._workspace_size = (
            float_workspace_buffer.numel() * float_workspace_buffer.element_size()
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=self._int_workspace_buffer.dtype,
            pin_memory=True,
        )

        # Enable user-defined size
        if vector_sparse_indices_buffer is not None:
            self._vector_sparse_indices_buffer = vector_sparse_indices_buffer
        if vector_sparse_indptr_buffer is not None:
            self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer

    def plan(
        self,
        indptr: torch.Tensor,
        indices: torch.Tensor,
        M: int,
        N: int,
        R: int,
        C: int,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        mask: Optional[torch.Tensor] = None,
        packed_mask: Optional[torch.Tensor] = None,
        causal: bool = False,
        pos_encoding_mode: str = "NONE",
        use_fp16_qk_reduction: bool = False,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
        q_data_type: Union[str, torch.dtype] = "float16",
        kv_data_type: Optional[Union[str, torch.dtype]] = None,
        o_data_type: Union[str, torch.dtype] = "float16",
        non_blocking: bool = True,
    ) -> None:
        r"""Create auxiliary data structures for block sparse attention.

        Parameters
        ----------
        indptr : torch.Tensor
            The block index pointer of the block-sparse matrix on the row dimension, shape ``(MB + 1,)``,
            where ``MB`` is the number of blocks in the row dimension.
        indices : torch.Tensor
            The block indices of the block-sparse matrix on the column dimension, shape ``(nnz,)``, where
            ``nnz`` is the number of non-zero blocks. The elements in ``indices`` should be less than ``NB``,
            the number of blocks in the column dimension.
        M : int
            The number of rows of the block-sparse matrix, ``MB = ceil_div(M, R)``.
        N : int
            The number of columns of the block-sparse matrix, ``NB = N // C``; ``N`` should be divisible by ``C``.
        R : int
            The number of rows in each block.
        C : int
            The number of columns in each block.
        num_qo_heads : int
            The number of heads in the query/output tensor.
        num_kv_heads : int
            The number of heads in the key/value tensor.
        head_dim : int
            The dimension of each head.
        mask : torch.Tensor, optional
            The mask tensor with shape ``(nnz, R, C,)``, where ``nnz`` is the number of non-zero blocks.
            If every block is full, the mask tensor does not need to be provided.
        packed_mask : torch.Tensor, optional
            The 1D packed mask tensor; if provided, :attr:`mask` will be ignored.
            The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`.
        causal : bool
            Whether to apply a causal mask to the attention matrix.
            This is only effective when :attr:`mask` is not provided in :meth:`plan`.
        pos_encoding_mode : str, optional
            The position encoding applied inside attention kernels, could be
            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
            Default is ``NONE``.
        use_fp16_qk_reduction : bool
            Whether to use f16 for qk reduction (faster at the cost of slight precision
            loss).
        logits_soft_cap : Optional[float]
            The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.); if not
            provided, will be set to ``0``. If greater than 0, the logits will be capped according to
            the formula:
            :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
            where :math:`x` is the input logits.
        sm_scale : Optional[float]
            The scale used in softmax, if not provided, will be set to
            ``1.0 / sqrt(head_dim)``.
        rope_scale : Optional[float]
            The scale used in RoPE interpolation, if not provided, will be set to
            ``1.0``.
        rope_theta : Optional[float]
            The theta used in RoPE, if not provided, will be set to ``1e4``.
        q_data_type : str, optional
            The data type of the query tensor.
        kv_data_type : Optional[Union[str, torch.dtype]]
            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
        o_data_type : str, optional
            The data type of the output tensor. Defaults to ``float16``. This must be given
            explicitly because the output dtype cannot be inferred from the input dtype when
            the inputs are quantized.
        non_blocking : bool
            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.


        The :meth:`plan` method should be called before any :meth:`run` or
        :meth:`run_return_lse` calls; auxiliary data structures will be created
        during this call and cached for multiple kernel runs.

        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
        is not equal to ``num_kv_heads``, the function will use
        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
        """
        q_data_type = canonicalize_torch_dtype(q_data_type)
        if kv_data_type is None:
            kv_data_type = q_data_type
        kv_data_type = canonicalize_torch_dtype(kv_data_type)
        self._o_dtype = canonicalize_torch_dtype(o_data_type)

        if logits_soft_cap is None:
            logits_soft_cap = 0.0

        num_blocks_row = len(indptr) - 1
        qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
        qo_indptr_host[-1] = M
        qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=non_blocking)
        if indices.max().item() * C > N:
            raise ValueError("indices out of bound")
        last_block_len = torch.full(
            (num_blocks_row,), C, dtype=torch.int32, device=indptr.device
        )

        if mask is not None or packed_mask is not None:
            mask_indptr = _compute_page_mask_indptr(
                qo_indptr,
                indptr,  # paged_kv_indptr
                last_block_len,  # paged_kv_last_page_len
                C,  # page_size
            )
        if packed_mask is None and mask is not None:
            # first convert BSR mask to flashinfer layout
            mask = convert_bsr_mask_layout(mask, indptr)
            # create packed mask from mask
            packed_mask, mask_indptr = segment_packbits(
                mask.contiguous().view(-1), mask_indptr, bitorder="little"
            )

        self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking)
        self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=non_blocking)
        self._paged_kv_indices_buf = indices.to(self.device, non_blocking=non_blocking)
        self._paged_kv_last_page_len = last_block_len.to(
            self.device, non_blocking=non_blocking
        )
        if packed_mask is not None:
            self._packed_mask_buf = packed_mask.to(
                self.device, non_blocking=non_blocking
            )
            self._mask_indptr_buf = mask_indptr.to(
                self.device, non_blocking=non_blocking
            )
            mask_mode = MaskMode.CUSTOM.value
        else:
            self._packed_mask_buf = None
            self._mask_indptr_buf = None
            mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value
        self._mask_mode = mask_mode

        self.M = M
        self.N = N
        self.R = R
        self.C = C

        kv_indptr_host = indptr.to("cpu")

        # NOTE(Zihao): we haven't supported mask in cuda-core implementations but it should
        # be easy to add support for it if needed, leave it as a future work.
        # at this moment, when mask is provided, we use the tensor-core implementation
        if (
            R * (num_qo_heads // num_kv_heads) < 4
            and mask_mode != MaskMode.CUSTOM.value
            and q_data_type not in [torch.float8_e4m3fn, torch.float8_e5m2]
        ):
            # If the operation is not compute-bound, we use the cuda-core implementation
            self._use_tensor_cores = False
            self._cached_module = get_batch_decode_module(
                q_data_type,
                kv_data_type,
                self._o_dtype,
                indptr.dtype,
                head_dim,
                head_dim,
                PosEncodingMode[pos_encoding_mode].value,
                False,  # use_sliding_window
                logits_soft_cap > 0,  # use_logits_soft_cap
            )

            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._pin_memory_int_workspace_buffer,
                kv_indptr_host,
                num_blocks_row,
                num_qo_heads,
                num_kv_heads,
                C,
                False,  # is_cuda_graph_enabled
                -1,  # window_left
                logits_soft_cap,  # logits_soft_cap
                head_dim,
                head_dim,
                torch.empty(0, dtype=q_data_type),
                torch.empty(0, dtype=kv_data_type),
            )
        else:
            # if the operation is compute-bound, we use the tensor-core implementation
            self._use_tensor_cores = True

            if self._backend == "auto":
                self._backend = determine_attention_backend(
                    self.device,
                    PosEncodingMode[pos_encoding_mode].value,
                    use_fp16_qk_reduction,
                    mask_mode == MaskMode.CUSTOM.value,  # use_custom_mask
                    q_data_type,
                    kv_data_type,
                )

            get_module_args = (
                q_data_type,
                kv_data_type,
                self._o_dtype,
                indptr.dtype,
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                PosEncodingMode[pos_encoding_mode].value,
                False,  # use_sliding_window
                logits_soft_cap > 0,  # use_logits_soft_cap
                use_fp16_qk_reduction,
            )
            self._cached_module = get_batch_prefill_module(
                self._backend, *get_module_args
            )

            kv_lens_arr_host = (kv_indptr_host[1:] - kv_indptr_host[:-1]) * self.C
            self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
                kv_lens_arr_host,
            )

            if self._backend == "fa3":
                if self.C != 1:
                    vector_sparse_indptr_host = torch.cat(
                        [
                            torch.tensor([0], dtype=torch.int32),
                            torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32),
                        ],
                        dim=0,
                    )
                    self._vector_sparse_indptr_buffer[
                        : len(vector_sparse_indptr_host)
                    ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
                    kv_indptr_host = vector_sparse_indptr_host

            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._pin_memory_int_workspace_buffer,
                qo_indptr_host,
                kv_indptr_host,
                kv_lens_arr_host,
                M,  # total_num_rows
                num_blocks_row,  # batch_size
                num_qo_heads,
                num_kv_heads,
                self.C,  # page_size
                False,  # is_cuda_graph_enabled,
                head_dim,
                head_dim,
                causal,
            )

        self._pos_encoding_mode = pos_encoding_mode
        self._use_fp16_qk_reduction = use_fp16_qk_reduction
        self._logits_soft_cap = logits_soft_cap
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta

    begin_forward = plan

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        scale_q: Optional[torch.Tensor] = None,
        scale_k: Optional[torch.Tensor] = None,
        scale_v: Optional[torch.Tensor] = None,
        pos_encoding_mode: str = "NONE",
        use_fp16_qk_reduction: bool = False,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
    ) -> torch.Tensor:
        r"""Warning: This method is deprecated, please use :meth:`run` instead."""
        self._pos_encoding_mode = pos_encoding_mode
        self._use_fp16_qk_reduction = use_fp16_qk_reduction
        self._logits_soft_cap = logits_soft_cap
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta
        return self.run(q, k, v, scale_q, scale_k, scale_v)

    def run(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        scale_q: Optional[torch.Tensor] = None,
        scale_k: Optional[torch.Tensor] = None,
        scale_v: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        lse: Optional[torch.Tensor] = None,
        return_lse: bool = False,
        enable_pdl: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        r"""Compute block-sparse attention between Q/K/V tensors.

        Parameters
        ----------
        q : torch.Tensor
            The query tensor with shape ``(M, num_qo_heads, head_dim)``.
        k : torch.Tensor
            The key tensor with shape ``(N, num_kv_heads, head_dim)``.
        v : torch.Tensor
            The value tensor with shape ``(N, num_kv_heads, head_dim)``.
        scale_q : Optional[torch.Tensor]
            The scale tensor for query, per-head quantization with shape: ``[num_qo_heads]``.
            Used with FP8 Quantization. If not provided, will be set to ``1.0``.
        scale_k : Optional[torch.Tensor]
            The scale tensor for key, per-head quantization with shape: ``[num_kv_heads]``.
            Used with FP8 Quantization. If not provided, will be set to ``1.0``.
        scale_v : Optional[torch.Tensor]
            The scale tensor for value, per-head quantization with shape: ``[num_kv_heads]``.
            Used with FP8 Quantization. If not provided, will be set to ``1.0``.
        out : Optional[torch.Tensor]
            The output tensor, if not provided, will be allocated internally.
        lse : Optional[torch.Tensor]
            The log-sum-exp of attention logits, if not provided, will be allocated internally.
        return_lse : bool
            Whether to return the log-sum-exp of attention logits.
        enable_pdl : bool
            Whether to enable Programmatic Dependent Launch (PDL). See
            https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
            Only supported for >= sm90, and currently only for FA2 and CUDA core decode.

        Returns
        -------
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            If :attr:`return_lse` is ``False``, the attention output, shape: ``[M, num_qo_heads, head_dim]``.
            If :attr:`return_lse` is ``True``, a tuple of two tensors:

            * The attention output, shape: ``[M, num_qo_heads, head_dim]``.
            * The logsumexp of attention output, shape: ``[M, num_qo_heads]``.
        """
        if enable_pdl is None:
            enable_pdl = device_support_pdl(q.device)

        pos_encoding_mode = self._pos_encoding_mode
        logits_soft_cap = self._logits_soft_cap
        sm_scale = self._sm_scale
        rope_scale = self._rope_scale
        rope_theta = self._rope_theta
        _check_pos_encoding_mode(pos_encoding_mode)
        if logits_soft_cap is None:
            logits_soft_cap = 0.0
        if sm_scale is None:
            sm_scale = 1.0 / math.sqrt(q.size(-1))
        if rope_scale is None:
            rope_scale = 1.0
        if rope_theta is None:
            rope_theta = 1e4
        k = k.reshape(-1, self.C, *k.shape[-2:])
        v = v.reshape(-1, self.C, *v.shape[-2:])

        stride_block = k.stride(0)
        stride_n = k.stride(1)

        if return_lse:
            if lse is None:
                lse = torch.empty(
                    (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
                )
            else:
                check_shape_dtype_device(
                    lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
                )

        if out is None:
            out = torch.empty_like(q, dtype=self._o_dtype)
        else:
            check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")

        if is_float8(q):
            assert q.dtype == k.dtype == v.dtype
            assert q.shape[-1] == k.shape[-1] == v.shape[-1]
            assert self._backend == "fa3" and self._use_tensor_cores

            if scale_q is None:
                scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device)
            if scale_k is None:
                scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device)
            if scale_v is None:
                scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device)

        if self._use_tensor_cores:
            if self._backend == "fa3":
                if (
                    self._vector_sparse_indices_buffer.numel()
                    <= self._paged_kv_indices_buf.numel() * self.C
                ):
                    raise ValueError(
                        "_vector_sparse_indices_buffer is not large enough. Please increase the size."
                    )

                sparse_indices = block_sparse_indices_to_vector_sparse_offsets(
                    self._paged_kv_indices_buf,
                    self._paged_kv_indptr_buf,
                    self._vector_sparse_indices_buffer,  # output
                    self._vector_sparse_indptr_buffer,
                    self._kv_lens_buffer,
                    stride_block // stride_n,
                    1,  # stride_n // stride_n
                    self.C,  # block_size
                )
                sparse_indptr = self._vector_sparse_indptr_buffer
            else:
                sparse_indices = self._paged_kv_indices_buf
                sparse_indptr = self._paged_kv_indptr_buf

            self._cached_module.paged_run(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._plan_info,
                q,
                k,
                v,
                self._qo_indptr,
                sparse_indptr,
                sparse_indices,
                self._paged_kv_last_page_len,
                out,
                lse,
                self._mask_mode,
                TensorLayout[self._kv_layout].value,
                -1,  # window_left
                enable_pdl,
                # ADDITIONAL_FUNC_PARAMS
                self._packed_mask_buf,
                self._mask_indptr_buf,
                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
                None,  # maybe_prefix_len_ptr
                None,  # maybe_token_pos_in_items_ptr
                None,  # maybe_max_item_len_ptr
                logits_soft_cap,
                sm_scale,
                scale_q,
                scale_k,
                scale_v,
                rope_scale,
                rope_theta,
                0,  # token_pos_in_items_len
                self._workspace_size,  # workspace_size
            )
        else:
            self._cached_module.run(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._plan_info,
                q,
                k,
                v,
                self._paged_kv_indptr_buf,
                self._paged_kv_indices_buf,
                self._paged_kv_last_page_len,
                out,
                lse,
                TensorLayout[self._kv_layout].value,
                -1,  # window_left
                enable_pdl,
                # ADDITIONAL_FUNC_PARAMS
                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
                logits_soft_cap,
                sm_scale,
                rope_scale,
                rope_theta,
            )

        return (out, lse) if return_lse else out

    def end_forward(self) -> None:
        r"""Warning: This method is deprecated and has no effect."""
        pass
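

# --- Editor's illustrative sketch (hypothetical helper, not part of the original API) ---
# Shows how the `indptr` / `indices` arguments expected by `BlockSparseAttentionWrapper.plan`
# can be derived from a dense boolean block mask, matching the example in the class
# docstring above (mask [[0, 0, 1], [1, 0, 1], [0, 1, 1]] -> indptr [0, 1, 3, 5],
# indices [2, 0, 2, 1, 2]).
def _example_bsr_arrays_from_block_mask(
    block_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:  # pragma: no cover
    # block_mask: (MB, NB) boolean tensor marking the non-zero blocks.
    nnz_per_row = block_mask.sum(dim=1).to(torch.int32)
    indptr = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=block_mask.device),
            torch.cumsum(nnz_per_row, dim=0, dtype=torch.int32),
        ]
    )
    # Column indices of the non-zero blocks, enumerated row by row.
    indices = block_mask.nonzero(as_tuple=True)[1].to(torch.int32)
    return indptr, indices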


class VariableBlockSparseAttentionWrapper:
    r"""Wrapper class for attention computation with a block-sparse matrix as attention mask.
    This API supports variable block sizes, provided by ``block_row_sz`` and ``block_col_sz``.
    In addition, each ``kv_head_idx`` can specify its own sparse pattern rather than sharing
    the same mask.

    Example
    -------
    >>> import torch
    >>> import flashinfer
    >>> num_qo_heads = 1
    >>> num_kv_heads = 1
    >>> head_dim = 128
    >>> seq_len = 6  # This corresponds to the `block_row_sz` and `block_col_sz`
    >>> # allocate 128MB workspace buffer
    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    >>> wrapper = flashinfer.VariableBlockSparseAttentionWrapper(workspace_buffer)
    >>> block_mask_map = torch.tensor([[[0, 0, 1], [1, 0, 1], [0, 1, 1]]], dtype=torch.bool, device="cuda:0")
    >>> block_row_sz = torch.tensor([[1, 2, 3]], dtype=torch.int32, device="cuda:0")
    >>> block_col_sz = torch.tensor([[3, 1, 2]], dtype=torch.int32, device="cuda:0")
    >>> wrapper.plan(
    ...     block_mask_map,
    ...     block_row_sz,
    ...     block_col_sz,
    ...     num_qo_heads,
    ...     num_kv_heads,
    ...     head_dim,
    ... )
    >>> q = torch.randn((num_qo_heads, seq_len, head_dim), dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn((num_kv_heads, seq_len, head_dim), dtype=torch.float16, device="cuda:0")
    >>> v = torch.randn((num_kv_heads, seq_len, head_dim), dtype=torch.float16, device="cuda:0")
    >>> o = wrapper.run(q, k, v)
    """

    def __init__(
        self,
        float_workspace_buffer: torch.Tensor,
        backend: str = "auto",
    ) -> None:
        r"""Constructor of :class:`VariableBlockSparseAttentionWrapper`.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The user reserved float workspace buffer used to store intermediate attention results
            in the split-k algorithm. The recommended size is 128MB; the device of the workspace
            buffer should be the same as the device of the input tensors.
        backend : str
            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
            If set to ``auto``, the function will automatically choose the backend based on the
            device architecture and kernel availability.
        """
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        self._workspace_size = (
            float_workspace_buffer.numel() * float_workspace_buffer.element_size()
        )
        self._int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
        )
        if backend in ["fa3", "auto"]:
            self._vector_sparse_indices_buffer = torch.empty(
                (128 * 1024 * 1024,), dtype=torch.int32, device=self.device
            )
            self._vector_sparse_indptr_buffer = torch.empty(
                (32768,), dtype=torch.int32, device=self.device
            )

        self._kv_lens_buffer = torch.empty(
            (32768,), dtype=torch.int32, device=self.device
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=torch.uint8,
            pin_memory=True,
            device="cpu",
        )
        self._use_cuda_graph = False
        self._kv_layout = "NHD"
        self._qo_indptr: Optional[torch.Tensor] = None
        self._paged_kv_indptr_buf: Optional[torch.Tensor] = None
        self._paged_kv_indices_buf: Optional[torch.Tensor] = None
        self._paged_kv_last_page_len: Optional[torch.Tensor] = None
        self._backend = backend

    def reset_workspace_buffer(
        self,
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        vector_sparse_indices_buffer: Optional[torch.Tensor] = None,
        vector_sparse_indptr_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        r"""Reset the workspace buffer.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The new float workspace buffer; its device should be the same as the device of the
            input tensors.
        int_workspace_buffer : torch.Tensor
            The new int workspace buffer; its device should be the same as the device of the
            input tensors.
        vector_sparse_indices_buffer : Optional[torch.Tensor]
            If provided, replaces the internal vector-sparse indices buffer, allowing a
            user-defined size.
        vector_sparse_indptr_buffer : Optional[torch.Tensor]
            If provided, replaces the internal vector-sparse indptr buffer, allowing a
            user-defined size.
        """
        self._float_workspace_buffer = float_workspace_buffer
        self._int_workspace_buffer = int_workspace_buffer
        self._workspace_size = (
            float_workspace_buffer.numel() * float_workspace_buffer.element_size()
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=self._int_workspace_buffer.dtype,
            pin_memory=True,
        )

        # Enable user-defined size
        if vector_sparse_indices_buffer is not None:
            self._vector_sparse_indices_buffer = vector_sparse_indices_buffer
        if vector_sparse_indptr_buffer is not None:
            self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer

    def plan(
        self,
        block_mask_map: torch.Tensor,
        block_row_sz: torch.Tensor,
        block_col_sz: torch.Tensor,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        causal: bool = False,
        pos_encoding_mode: str = "NONE",
        use_fp16_qk_reduction: bool = False,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
        non_blocking: bool = True,
        q_data_type: Union[str, torch.dtype] = "float16",
        kv_data_type: Optional[Union[str, torch.dtype]] = None,
    ) -> None:
        r"""Create auxiliary data structures for block sparse attention.

        Parameters
        ----------
        block_mask_map : torch.Tensor
            The block mask map (boolean), shape ``(num_kv_heads, MB, NB)``, where ``MB`` is the
            number of blocks in the row dimension and ``NB`` is the number of blocks in the
            column dimension.
        block_row_sz : torch.Tensor
            The block row sizes, shape ``(num_kv_heads, MB,)``.
        block_col_sz : torch.Tensor
            The block column sizes, shape ``(num_kv_heads, NB,)``.
        num_qo_heads : int
            The number of heads in the query/output tensor.
        num_kv_heads : int
            The number of heads in the key/value tensor. Note that each group of ``qo_heads``
            shares the same sparse pattern as its corresponding ``kv_head``.
        head_dim : int
            The dimension of each head.
        causal : bool
            Whether to apply a causal mask to the attention matrix.
        pos_encoding_mode : str, optional
            The position encoding applied inside attention kernels, could be
            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
            Default is ``NONE``.
        use_fp16_qk_reduction : bool
            Whether to use f16 for qk reduction (faster at the cost of slight precision
            loss).
        logits_soft_cap : Optional[float]
            The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.); if not
            provided, will be set to ``0``. If greater than 0, the logits will be capped according to
            the formula:
            :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
            where :math:`x` is the input logits.
        sm_scale : Optional[float]
            The scale used in softmax, if not provided, will be set to
            ``1.0 / sqrt(head_dim)``.
        rope_scale : Optional[float]
            The scale used in RoPE interpolation, if not provided, will be set to
            ``1.0``.
        rope_theta : Optional[float]
            The theta used in RoPE, if not provided, will be set to ``1e4``.
        non_blocking : bool
            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
        q_data_type : str, optional
            The data type of the query tensor.
        kv_data_type : Optional[Union[str, torch.dtype]]
            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.


        The :meth:`plan` method should be called before any :meth:`run` or
        :meth:`run_return_lse` calls; auxiliary data structures will be created
        during this call and cached for multiple kernel runs.

        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
        is not equal to ``num_kv_heads``, the function will use
        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
        """
        q_data_type = canonicalize_torch_dtype(q_data_type)
        if kv_data_type is None:
            kv_data_type = q_data_type
        kv_data_type = canonicalize_torch_dtype(kv_data_type)
        self._o_dtype = q_data_type

        if logits_soft_cap is None:
            logits_soft_cap = 0.0

        # num_blocks are constant across kv_heads
        num_blocks_row = block_row_sz.shape[-1]
        num_blocks_col = block_col_sz.shape[-1]

        # q layout: [seq_len, num_kv_heads, gqa_group_size, head_dim]
        # padded into: [seq_len * num_kv_heads, 1, gqa_group_size, head_dim]
        qo_indptr = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device=block_row_sz.device),
                torch.cumsum(block_row_sz.flatten(), dim=0, dtype=torch.int32),
            ],
            dim=0,
        )
        qo_indptr_host = qo_indptr.to("cpu", non_blocking=non_blocking)
        last_block_len = torch.full(
            (num_blocks_row * num_kv_heads,),
            1,
            dtype=torch.int32,
            device=block_mask_map.device,
        )  # We use page_size == 1 for variable length support

        # HND kv layout: [num_kv_heads, num_blocks, block_size, head_dim]
        # padded into: [num_kv_heads * num_blocks, block_size, 1, head_dim]
        # for customized attention mask for each kv_head
        # NOTE(Yilong): This could be perf bottleneck. Consider Triton implementation.
        def _block_mask_map_to_expanded_indices(
            block_mask_map: torch.Tensor,  # [H, R, C] bool / {0,1}
            block_col_sz: torch.Tensor,  # [H, C] int
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Args:
                block_mask_map: bool/int [num_kv_heads, num_blocks_row, num_blocks_col]
                block_col_sz: int32/64 [num_kv_heads, num_blocks_col]
            Returns:
                kv_indptr: [H*R + 1] int32 - CSR indptr
                kv_indices: [nnz] int32 - token indices per (head, row)
            """
            device = block_mask_map.device
            dtype_i = torch.int32

            # 1) Calculate the total length of each row (head, row)
            row_lengths = (block_mask_map * block_col_sz[:, None, :]).sum(-1)  # [H,R]
            kv_indptr = torch.cat(
                [
                    torch.zeros(1, dtype=dtype_i, device=device),
                    torch.cumsum(row_lengths.flatten(), 0),
                ],
                dim=0,
            )

            # 2) Calculate the offset of each column block within its head
            col_offset = (
                torch.cumsum(block_col_sz.to(dtype_i), 1) - block_col_sz
            )  # [H,C]
            head_len = block_col_sz.sum(1, dtype=dtype_i)
            head_offset = torch.cumsum(head_len, 0) - head_len

            # 3) Find all selected (h,r,c)
            h_idx, r_idx, c_idx = block_mask_map.nonzero(as_tuple=True)
            lengths = block_col_sz[h_idx, c_idx].to(dtype_i)  # [N]
            base = head_offset[h_idx] + col_offset[h_idx, c_idx]  # [N]

            # 4) Expand variable-length column blocks into token-level indices
            cum = torch.cumsum(lengths, 0)
            starts = torch.repeat_interleave(cum - lengths, lengths)  # [total]
            offsets_within = torch.arange(cum[-1], device=device) - starts
            kv_indices = torch.repeat_interleave(base, lengths) + offsets_within

            return kv_indptr.to(dtype=dtype_i, device=device), kv_indices.to(
                dtype=dtype_i, device=device
            )

        kv_indptr, kv_indices = _block_mask_map_to_expanded_indices(
            block_mask_map, block_col_sz
        )
        kv_indptr_host = kv_indptr.to("cpu", non_blocking=non_blocking)
        kv_indices_host = kv_indices.to("cpu", non_blocking=non_blocking)

        self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking)
        self._paged_kv_indptr_buf = kv_indptr.to(self.device, non_blocking=non_blocking)
        self._paged_kv_indices_buf = kv_indices.to(
            self.device, non_blocking=non_blocking
        )
        self._paged_kv_last_page_len = last_block_len.to(
            self.device, non_blocking=non_blocking
        )
        torch.cuda.synchronize()  # for non-blocking copy
        self._mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value

        # Sanity check
        assert num_qo_heads % num_kv_heads == 0, (
            "num_qo_heads must be a multiple of num_kv_heads"
        )
        assert num_blocks_row * num_kv_heads + 1 == kv_indptr_host.shape[0]
        assert kv_indptr_host[-1].item() == kv_indices_host.shape[0], (
            f"{kv_indptr_host[-1].item()} != {kv_indices_host.shape[0]}"
        )
        assert num_kv_heads == block_mask_map.shape[0]
        assert num_kv_heads == block_row_sz.shape[0]
        assert num_kv_heads == block_col_sz.shape[0]
        assert num_blocks_row == block_mask_map.shape[1]
        assert num_blocks_col == block_mask_map.shape[2]

        if self._backend == "auto":
            self._backend = determine_attention_backend(
                self.device,
                PosEncodingMode[pos_encoding_mode].value,
                use_fp16_qk_reduction,
                self._mask_mode == MaskMode.CUSTOM.value,  # use_custom_mask
                q_data_type,
                kv_data_type,
            )

        get_module_args = (
            q_data_type,
            kv_data_type,
            self._o_dtype,
            kv_indptr_host.dtype,
            head_dim,  # head_dim_qk
            head_dim,  # head_dim_vo
            PosEncodingMode[pos_encoding_mode].value,
            False,  # use_sliding_window
            logits_soft_cap > 0,  # use_logits_soft_cap
            use_fp16_qk_reduction,
        )
        self._cached_module = get_batch_prefill_module(self._backend, *get_module_args)

        kv_lens_arr_host = kv_indptr_host[1:] - kv_indptr_host[:-1]  # page_size == 1
        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
            kv_lens_arr_host,
        )

        if self._backend == "fa3":
            if self._vector_sparse_indptr_buffer.numel() <= kv_indptr.numel():
                raise ValueError(
                    "_vector_sparse_indptr_buffer is not large enough. Please increase the buffer size."
                )
            self._vector_sparse_indptr_buffer[: len(kv_indptr)].copy_(
                kv_indptr, non_blocking=non_blocking
            )

        self._plan_info = self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_host,
            kv_indptr_host,
            kv_lens_arr_host,
            qo_indptr_host[-1].item(),  # total_num_rows
            num_blocks_row * num_kv_heads,  # batch_size
            num_qo_heads // num_kv_heads,  # num_qo_heads (gqa_group_size)
            1,  # num_kv_heads,
            1,  # page_size
            False,  # is_cuda_graph_enabled,
            head_dim,
            head_dim,
            causal,
        )

        self._pos_encoding_mode = pos_encoding_mode
        self._use_fp16_qk_reduction = use_fp16_qk_reduction
        self._logits_soft_cap = logits_soft_cap
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta
        self._num_kv_heads = num_kv_heads
        self._gqa_group_size = num_qo_heads // num_kv_heads

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pos_encoding_mode: str = "NONE",
        use_fp16_qk_reduction: bool = False,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
    ) -> torch.Tensor:
        r"""Warning: This method is deprecated, please use :meth:`run` instead."""
        self._pos_encoding_mode = pos_encoding_mode
        self._use_fp16_qk_reduction = use_fp16_qk_reduction
        self._logits_soft_cap = logits_soft_cap
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta
        return self.run(q, k, v)

    def run(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        lse: Optional[torch.Tensor] = None,
        return_lse: bool = False,
        enable_pdl: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        r"""Compute block-sparse attention between Q/K/V tensors.

        Parameters
        ----------
        q : torch.Tensor
            The query tensor with shape ``(num_qo_heads, qo_len, head_dim)``.
        k : torch.Tensor
            The key tensor with shape ``(num_kv_heads, kv_len, head_dim)``.
        v : torch.Tensor
            The value tensor with shape ``(num_kv_heads, kv_len, head_dim)``.
        out : Optional[torch.Tensor]
            The output tensor, if not provided, will be allocated internally.
        lse : Optional[torch.Tensor]
            The log-sum-exp of attention logits, if not provided, will be allocated internally.
        return_lse : bool
            Whether to return the log-sum-exp of attention logits.
        enable_pdl : bool
            Whether to enable Programmatic Dependent Launch (PDL). See
            https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
            Only supported for >= sm90, and currently only for FA2 and CUDA core decode.

        Returns
        -------
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            If :attr:`return_lse` is ``False``, the attention output, shape: ``[num_qo_heads, qo_len, head_dim]``.
            If :attr:`return_lse` is ``True``, a tuple of two tensors:

            * The attention output, shape: ``[num_qo_heads, qo_len, head_dim]``.
            * The logsumexp of attention output, shape: ``[num_qo_heads, qo_len]``.
        """
        # NOTE(Zihao): defer import of einops
        import einops

        if enable_pdl is None:
            enable_pdl = device_support_pdl(q.device)

        pos_encoding_mode = self._pos_encoding_mode
        logits_soft_cap = self._logits_soft_cap
        sm_scale = self._sm_scale
        rope_scale = self._rope_scale
        rope_theta = self._rope_theta
        _check_pos_encoding_mode(pos_encoding_mode)
        if logits_soft_cap is None:
            logits_soft_cap = 0.0
        if sm_scale is None:
            sm_scale = 1.0 / math.sqrt(q.size(-1))
        if rope_scale is None:
            rope_scale = 1.0
        if rope_theta is None:
            rope_theta = 1e4

        # reshape to pad num_kv_heads into seq_len
        # input [num_qo_heads, qo_len, head_dim]
        # kernel layout is NHD -> [qo_len * num_kv_heads, gqa_group_size, head_dim]
        q = einops.rearrange(
            q,
            "(num_kv_heads gqa_group_size) qo_len head_dim -> (num_kv_heads qo_len) gqa_group_size head_dim",
            num_kv_heads=self._num_kv_heads,
        ).contiguous()
        # HND -> [kv_len * num_kv_heads (num_pages), 1 (page_size), 1 (new_num_kv_heads), head_dim]
        k = einops.rearrange(
            k,
            "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim",
        ).contiguous()
        v = einops.rearrange(
            v,
            "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim",
        ).contiguous()

        stride_block = k.stride(0)
        stride_n = k.stride(1)

        if return_lse:
            if lse is None:
                lse = torch.empty(
                    (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
                )
            else:
                check_shape_dtype_device(
                    lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
                )

        if out is None:
            out = torch.empty_like(q, dtype=self._o_dtype)
        else:
            check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")

        if self._backend == "fa3":
            if (
                self._vector_sparse_indices_buffer.numel()
                <= self._paged_kv_indices_buf.numel()
            ):
                raise ValueError(
                    "_vector_sparse_indices_buffer is not large enough. Please increase the buffer size."
                )

            sparse_indices = block_sparse_indices_to_vector_sparse_offsets(
                self._paged_kv_indices_buf,
                self._paged_kv_indptr_buf,
                self._vector_sparse_indices_buffer,  # output
                self._vector_sparse_indptr_buffer,
                self._kv_lens_buffer,
                stride_block // stride_n,
                1,  # stride_n // stride_n
                1,  # block_size
            )
            sparse_indptr = self._vector_sparse_indptr_buffer
        else:
            sparse_indices = self._paged_kv_indices_buf
            sparse_indptr = self._paged_kv_indptr_buf

        self._cached_module.paged_run(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._plan_info,
            q,
            k,
            v,
            self._qo_indptr,
            sparse_indptr,
            sparse_indices,
            self._paged_kv_last_page_len,
            out,
            lse,
            self._mask_mode,
            TensorLayout[self._kv_layout].value,
            -1,  # window_left
            enable_pdl,
            # ADDITIONAL_FUNC_PARAMS
            # Not supported yet
            None,  # packed_mask_buf
            None,  # mask_indptr_buf
            None,  # alibi_slopes_buf
            None,
            None,
            None,
            logits_soft_cap,
            sm_scale,
            None,  # scale_q
            None,  # scale_k
            None,  # scale_v
            rope_scale,
            rope_theta,
            0,  # token_pos_in_items_len
            self._workspace_size,
        )

        # [qo_len * num_kv_heads, gqa_group_size, head_dim] -> HND
        out = einops.rearrange(
            out,
            "(num_kv_heads qo_len) gqa_group_size head_dim -> (num_kv_heads gqa_group_size) qo_len head_dim",
            num_kv_heads=self._num_kv_heads,
        ).contiguous()

        if return_lse:
            lse = einops.rearrange(
                lse,
                "(num_kv_heads qo_len) gqa_group_size -> (num_kv_heads gqa_group_size) qo_len",
                num_kv_heads=self._num_kv_heads,
            ).contiguous()

        return (out, lse) if return_lse else out
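

# --- Editor's illustrative usage sketch (not part of the original module) ---
# Assumes a CUDA device is available; mirrors the class docstring of
# `VariableBlockSparseAttentionWrapper`, but with grouped-query attention
# (num_qo_heads > num_kv_heads) to show the per-kv-head patterns and HND input layout.
def _example_variable_block_sparse_gqa() -> torch.Tensor:  # pragma: no cover
    num_qo_heads, num_kv_heads, head_dim, seq_len = 4, 2, 128, 6
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    wrapper = VariableBlockSparseAttentionWrapper(workspace)
    # One independent 3x3 block pattern per kv head; block sizes sum to seq_len.
    block_mask_map = torch.tensor(
        [[[0, 0, 1], [1, 0, 1], [0, 1, 1]]] * num_kv_heads,
        dtype=torch.bool,
        device="cuda:0",
    )
    block_row_sz = torch.tensor([[1, 2, 3]] * num_kv_heads, dtype=torch.int32, device="cuda:0")
    block_col_sz = torch.tensor([[3, 1, 2]] * num_kv_heads, dtype=torch.int32, device="cuda:0")
    wrapper.plan(block_mask_map, block_row_sz, block_col_sz, num_qo_heads, num_kv_heads, head_dim)
    q = torch.randn(num_qo_heads, seq_len, head_dim, dtype=torch.float16, device="cuda:0")
    k = torch.randn(num_kv_heads, seq_len, head_dim, dtype=torch.float16, device="cuda:0")
    v = torch.randn(num_kv_heads, seq_len, head_dim, dtype=torch.float16, device="cuda:0")
    return wrapper.run(q, k, v)  # shape: (num_qo_heads, seq_len, head_dim)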