# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/mla.py

"""
Copyright (c) 2023 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import functools
from typing import Literal, Optional, Tuple, Union, overload
import torch
from .jit import JitSpec
from .jit import env as jit_env
from .jit import (
gen_batch_mla_module,
gen_jit_spec,
current_compilation_context,
)
from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend


def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table):
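    """Validate the shapes expected by the CUTLASS MLA kernel: 3-D ``q_nope_pe`` and
    ``ckv_kpe_cache``, 128 query heads, a combined head dimension of 576, matching batch
    sizes for ``q_nope_pe`` and ``page_table``, and a per-request page count that is a
    multiple of ``128 / block_size``."""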
if q_nope_pe.ndim != 3:
raise ValueError(f"Expected q_nope_pe.ndim == 3, got {q_nope_pe.ndim}")
if ckv_kpe_cache.ndim != 3:
raise ValueError(f"Expected ckv_kpe_cache.ndim == 3, got {ckv_kpe_cache.ndim}")
if kv_len.ndim != 1:
raise ValueError(f"Expected kv_len.ndim == 1, got {kv_len.ndim}")
if page_table.ndim != 2:
raise ValueError(f"Expected page_table.ndim == 2, got {page_table.ndim}")
B_q, H, D_q = q_nope_pe.shape
D_ckv = ckv_kpe_cache.shape[2]
if H != 128:
raise ValueError(f"Expected 128 heads for q_nope_pe, got {H}")
if D_q != D_ckv or D_q != 576:
raise ValueError(
f"Expected head dim 576 for q_nope_pe and ckv_kpe_cache, got {D_q} and {D_ckv}"
)
B_block_table, block_num = page_table.shape
block_size = ckv_kpe_cache.shape[1]
    if B_q != B_block_table:
        raise ValueError(
            f"Expected the same batch size for q_nope_pe and page_table, got {B_q} and {B_block_table}"
        )
if block_num % (128 / block_size) != 0:
raise ValueError(
f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
        )


def gen_mla_module() -> JitSpec:
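    """Generate the JIT spec for the CUTLASS MLA kernels (``cutlass_mla.cu`` and
    ``flashinfer_mla_ops.cu``), compiled with the flags for the supported SM major versions."""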
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10, 11]
)
return gen_jit_spec(
"mla",
[
jit_env.FLASHINFER_CSRC_DIR / "cutlass_mla.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mla_ops.cu",
],
extra_cuda_cflags=nvcc_flags,
    )


@functools.cache
def get_mla_module():
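    """Build and load the CUTLASS MLA module; cached so it is compiled and loaded once per process."""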
    return gen_mla_module().build_and_load()


@functools.cache
def get_batch_mla_module(backend, *args):
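    """Build and load the batch MLA module for the given backend and compile-time arguments
    (dtypes, head dims, profiler flag); cached per unique argument tuple."""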
    return gen_batch_mla_module(backend, *args).build_and_load()


class BatchMLAPagedAttentionWrapper:
r"""Wrapper class for MLA (`Multi-head Latent Attention <https://arxiv.org/abs/2405.04434>`_)
    PagedAttention on DeepSeek models. This kernel can be used in decode and incremental prefill,
    and should be used together with the `Matrix Absorption trick
    <https://github.com/madsys-dev/deepseekv2-profile/blob/main/workspace/blog/optimizing-mla.md>`_,
    where :math:`W_{UQ}` is absorbed with :math:`W_{UK}`, and :math:`W_{UV}` is
    absorbed with :math:`W_{O}`.

    For MLA attention without Matrix Absorption (``head_dim_qk=192`` and ``head_dim_vo=128``, which is
    used in the prefill self-attention stage), please use
    :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`.

    More information about the Paged KV-Cache layout in MLA is explained in our tutorial
    :ref:`MLA Page Layout <mla-page-layout>`.

    For more details about the MLA computation, Matrix Absorption and FlashInfer's MLA implementation,
    please refer to our `blog post <http://flashinfer.ai/2025/02/10/flashinfer-deepseek-mla.html>`_.

Example
-------
>>> import torch
>>> import flashinfer
>>> num_local_heads = 128
>>> batch_size = 114
>>> head_dim_ckv = 512
>>> head_dim_kpe = 64
>>> page_size = 1
>>> mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
... torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0),
... backend="fa2"
... )
>>> q_indptr = torch.arange(0, batch_size + 1).to(0).int() # for decode, each query length is 1
>>> kv_lens = torch.full((batch_size,), 999, dtype=torch.int32).to(0)
>>> kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * 999
>>> kv_indices = torch.arange(0, batch_size * 999).to(0).int()
>>> q_nope = torch.randn(
... batch_size * 1, num_local_heads, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> q_pe = torch.zeros(
... batch_size * 1, num_local_heads, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> ckv = torch.randn(
... batch_size * 999, 1, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> kpe = torch.zeros(
... batch_size * 999, 1, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> sm_scale = 1.0 / ((128 + 64) ** 0.5) # use head dimension before matrix absorption
>>> mla_wrapper.plan(
... q_indptr,
... kv_indptr,
... kv_indices,
... kv_lens,
... num_local_heads,
... head_dim_ckv,
... head_dim_kpe,
... page_size,
... False, # causal
... sm_scale,
... q_nope.dtype,
... ckv.dtype,
... )
>>> o = mla_wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)
>>> o.shape
    torch.Size([114, 128, 512])
    """

    def __init__(
self,
float_workspace_buffer: torch.Tensor,
use_cuda_graph: bool = False,
qo_indptr: Optional[torch.Tensor] = None,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_len_arr: Optional[torch.Tensor] = None,
backend: str = "auto",
) -> None:
        r"""Constructor for BatchMLAPagedAttentionWrapper.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The user reserved workspace buffer used to store intermediate attention results in
            the split-k algorithm. The recommended size is 128 MB; the workspace buffer must be
            on the same device as the input tensors.
        use_cuda_graph : bool, optional
            Whether to enable CUDA graph capture for the kernels. If enabled, the
            auxiliary data structures will be stored in the provided buffers. The ``batch_size``
            cannot change during the lifecycle of this wrapper when CUDA graph is enabled.
        qo_indptr : Optional[torch.Tensor]
            The user reserved buffer to store the ``qo_indptr`` array, the size of the buffer
            should be ``[batch_size + 1]``.
            This argument is only effective when ``use_cuda_graph`` is ``True``.
        kv_indptr : Optional[torch.Tensor]
            The user reserved buffer to store the ``kv_indptr`` array, the size of the buffer
            should be ``[batch_size + 1]``.
            This argument is only effective when ``use_cuda_graph`` is ``True``.
        kv_indices : Optional[torch.Tensor]
            The user reserved buffer to store the ``kv_indices`` array.
            This argument is only effective when ``use_cuda_graph`` is ``True``.
        kv_len_arr : Optional[torch.Tensor]
            The user reserved buffer to store the ``kv_len_arr`` array, the size of the buffer
            should be ``[batch_size]``.
            This argument is only effective when ``use_cuda_graph`` is ``True``.
        backend : str
            The implementation backend, one of ``auto``, ``fa2``, ``fa3`` or ``cutlass``.
            Defaults to ``auto``. If set to ``auto``, the backend is chosen automatically based on
            the device architecture and kernel availability. If set to ``cutlass``, the MLA
            kernels are generated by CUTLASS; only ``float_workspace_buffer`` is required and
            the other arguments are ignored.
"""
self._float_workspace_buffer = float_workspace_buffer
self.device = float_workspace_buffer.device
if backend == "cutlass":
self._backend = backend
return
self._int_workspace_buffer = torch.empty(
(8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
)
self._pin_memory_int_workspace_buffer = torch.empty(
self._int_workspace_buffer.shape,
dtype=self._int_workspace_buffer.dtype,
pin_memory=True,
device="cpu",
)
self._use_cuda_graph = use_cuda_graph
self._qo_indptr_buf = qo_indptr
self._kv_indptr_buf = kv_indptr
self._kv_indices_buf = kv_indices
self._kv_len_arr_buf = kv_len_arr
if backend == "auto":
self._backend = determine_mla_backend(self.device)
else:
            self._backend = backend

    def plan(
self,
qo_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
kv_indices: torch.Tensor,
kv_len_arr: torch.Tensor,
num_heads: int,
head_dim_ckv: int,
head_dim_kpe: int,
page_size: int,
causal: bool,
sm_scale: float,
q_data_type: torch.dtype,
kv_data_type: torch.dtype,
use_profiler: bool = False,
) -> None:
        r"""Plan the MLA attention computation.

        Parameters
        ----------
qo_indptr : torch.IntTensor
The indptr of the query/output tensor, shape: ``[batch_size + 1]``.
For decoding attention, the length of each query is 1, and the content
of the tensor should be ``[0, 1, 2, ..., batch_size]``.
kv_indptr : torch.IntTensor
The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
kv_indices : torch.IntTensor
The page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]`` or larger.
        kv_len_arr : torch.IntTensor
            The length of the kv-cache of each request, shape: ``[batch_size]``.
num_heads : int
The number of heads in query/output tensor.
head_dim_ckv : int
The head dimension of compressed-kv.
head_dim_kpe : int
The head dimension for rope k-cache.
page_size : int
The page size of the paged kv-cache.
causal : bool
Whether to use causal attention.
sm_scale : float
The scale factor for softmax operation.
q_data_type : torch.dtype
The data type of the query tensor.
kv_data_type : torch.dtype
The data type of the kv-cache tensor.
        use_profiler : bool, optional
            Whether to enable the intra-kernel profiler. Defaults to ``False``.
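
        Notes
        -----
        ``qo_indptr``, ``kv_indptr``, ``kv_indices`` and ``kv_len_arr`` are expected to be
        ``torch.int32`` tensors; other dtypes are rejected with a ``ValueError``.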
"""
for tensor, name in [
(kv_len_arr, "kv_len_arr"),
(kv_indptr, "kv_indptr"),
(qo_indptr, "qo_indptr"),
(kv_indices, "kv_indices"),
]:
if tensor.dtype != torch.int32:
raise ValueError(
f"Expected {name}.dtype == torch.int32, got {tensor.dtype}"
)
self._cached_module = get_batch_mla_module(
self._backend,
q_data_type,
kv_data_type,
q_data_type,
qo_indptr.dtype,
head_dim_ckv,
head_dim_kpe,
use_profiler,
)
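        # The planning step runs on the host, so stage the indptr / length arrays on CPU.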
qo_indptr_host = qo_indptr.to("cpu")
kv_indptr_host = kv_indptr.to("cpu")
kv_len_arr_host = kv_len_arr.to("cpu")
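        # Under CUDA graph capture, copy into the persistent user-provided buffers in place;
        # otherwise simply move the arrays to the wrapper's device.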
if self._use_cuda_graph:
self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
self._kv_indptr_buf.copy_(kv_indptr, non_blocking=True)
self._kv_indices_buf[: len(kv_indices)].copy_(kv_indices, non_blocking=True)
self._kv_len_arr_buf.copy_(kv_len_arr, non_blocking=True)
else:
self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
self._kv_indptr_buf = kv_indptr.to(self.device, non_blocking=True)
self._kv_indices_buf = kv_indices.to(self.device, non_blocking=True)
self._kv_len_arr_buf = kv_len_arr.to(self.device, non_blocking=True)
self._causal = causal
self._page_size = page_size
self._sm_scale = sm_scale
self._use_profiler = use_profiler
self._plan_info = self._cached_module.plan.default(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
kv_indptr_host,
kv_len_arr_host,
num_heads,
head_dim_ckv, # head_dim_o
causal,
        )

    @overload
def run(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
ckv_cache: torch.Tensor,
kpe_cache: torch.Tensor,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
return_lse: Literal[False] = False,
profiler_buffer: Optional[torch.Tensor] = None,
kv_len: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
    ) -> torch.Tensor: ...

    @overload
def run(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
ckv_cache: torch.Tensor,
kpe_cache: torch.Tensor,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
return_lse: Literal[True] = True,
profiler_buffer: Optional[torch.Tensor] = None,
kv_len: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]: ...

    def run(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
ckv_cache: torch.Tensor,
kpe_cache: torch.Tensor,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
return_lse: bool = False,
profiler_buffer: Optional[torch.Tensor] = None,
kv_len: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        r"""Run the MLA attention computation.

        Parameters
        ----------
q_nope : torch.Tensor
The query tensor without rope, shape: ``[batch_size, num_heads, head_dim_ckv]``.
q_pe : torch.Tensor
The rope part of the query tensor, shape: ``[batch_size, num_heads, head_dim_kpe]``.
ckv_cache : torch.Tensor
The compressed kv-cache tensor (without rope), shape: ``[num_pages, page_size, head_dim_ckv]``.
``head_dim_ckv`` is 512 in DeepSeek v2/v3 models.
kpe_cache : torch.Tensor
The rope part of the kv-cache tensor, shape: ``[num_pages, page_size, head_dim_kpe]``.
``head_dim_kpe`` is 64 in DeepSeek v2/v3 models.
out : Optional[torch.Tensor]
The output tensor, if not provided, will be allocated internally.
lse : Optional[torch.Tensor]
The log-sum-exp of attention logits, if not provided, will be allocated internally.
        return_lse : bool, optional
            Whether to return the log-sum-exp value. Defaults to ``False``.
profiler_buffer : Optional[torch.Tensor]
The buffer to store the profiler data.
        kv_len : Optional[torch.Tensor]
            The length of the kv-cache of each request, shape: ``[batch_size]``. Required when ``backend`` is ``cutlass``.
page_table : Optional[torch.Tensor]
The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``.
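
        Example
        -------
        A minimal sketch of the ``cutlass`` path (illustrative only: the shapes are made up,
        the dtype is assumed to be ``bfloat16``, the page size is assumed to be 128, and a GPU
        supported by the CUTLASS MLA kernels is required). Note that the ``cutlass`` backend is
        dispatched directly from :meth:`run` without a prior :meth:`plan` call:

        >>> import torch
        >>> import flashinfer
        >>> batch_size, page_size, pages_per_req = 2, 128, 4
        >>> wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
        ...     torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda"),
        ...     backend="cutlass",
        ... )
        >>> q_nope = torch.randn(batch_size, 128, 512, dtype=torch.bfloat16, device="cuda")
        >>> q_pe = torch.randn(batch_size, 128, 64, dtype=torch.bfloat16, device="cuda")
        >>> ckv_cache = torch.randn(
        ...     batch_size * pages_per_req, page_size, 512, dtype=torch.bfloat16, device="cuda"
        ... )
        >>> kpe_cache = torch.randn(
        ...     batch_size * pages_per_req, page_size, 64, dtype=torch.bfloat16, device="cuda"
        ... )
        >>> kv_len = torch.full((batch_size,), 487, dtype=torch.int32, device="cuda")
        >>> page_table = torch.arange(
        ...     batch_size * pages_per_req, dtype=torch.int32, device="cuda"
        ... ).reshape(batch_size, pages_per_req)
        >>> o = wrapper.run(q_nope, q_pe, ckv_cache, kpe_cache, kv_len=kv_len, page_table=page_table)
        >>> o.shape
        torch.Size([2, 128, 512])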
"""
if self._backend == "cutlass":
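            # The CUTLASS backend consumes a fused layout (ckv and kpe concatenated along the
            # head dimension) and does not use plan() metadata, so it is dispatched directly here.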
            if return_lse:
                raise ValueError(
                    "return_lse is not supported by the cutlass backend for now."
                )
            if profiler_buffer is not None:
                raise ValueError(
                    "profiler_buffer is not supported by the cutlass backend for now."
                )
self._cached_module = get_mla_module()
if out is None:
out = torch.empty_like(q_nope)
else:
check_shape_dtype_device(
out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
)
q_nope_pe = torch.cat([q_nope, q_pe], dim=-1)
ckv_kpe_cache = torch.cat([ckv_cache, kpe_cache], dim=-1)
_check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table)
lse = torch.empty(0, dtype=torch.float32, device=self.device)
self._cached_module.cutlass_mla_paged_attention.default(
self._float_workspace_buffer,
out,
lse,
q_nope_pe,
ckv_kpe_cache,
kv_len,
page_table,
)
return out
if profiler_buffer is None:
if self._use_profiler:
                raise ValueError(
                    "Profiler is enabled; profiler_buffer must be provided"
                )
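        # FA2/FA3 path: launch the kernel with the split-k metadata produced by plan().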
num_heads = q_nope.shape[1]
page_size = self._page_size
sm_scale = self._sm_scale
causal = self._causal
mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value
device = self.device
if out is None:
out = torch.empty_like(q_nope)
else:
check_shape_dtype_device(
out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
)
if return_lse:
if lse is None:
lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device)
else:
check_shape_dtype_device(
lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse"
)
profiler_args = (profiler_buffer,) if self._use_profiler else ()
self._cached_module.run.default(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._plan_info,
q_nope,
q_pe,
ckv_cache,
kpe_cache,
self._kv_indices_buf,
out,
lse,
mask_mode,
num_heads,
page_size,
sm_scale,
*profiler_args,
)
return (out, lse) if return_lse else out