sglang_v0.5.2/flashinfer_0.3.1/flashinfer/attention.py

189 lines
5.8 KiB
Python

"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
import math
from typing import Optional, Tuple, Union
import torch
from .jit import gen_batch_attention_module
from .utils import (
MaskMode,
PosEncodingMode,
TensorLayout,
_check_kv_layout,
_unpack_paged_kv_cache,
)
@functools.cache
def get_holistic_attention_module(*args):
    """Build and load the JIT batch-attention module for ``args``.

    The positional ``args`` are forwarded unchanged to
    ``gen_batch_attention_module``. Results are memoized with
    ``functools.cache``, so every argument must be hashable and a repeated
    call with the same argument tuple returns the already-loaded module.
    """
    module_spec = gen_batch_attention_module(*args)
    return module_spec.build_and_load()
class BatchAttention:
    """Plan/run wrapper around the JIT-compiled batch attention module.

    Usage is two-phase: call :meth:`plan` once with the batch metadata, then
    call :meth:`run` with the query tensor and the paged KV cache. The
    instance owns the workspace buffers shared by both phases.
    """

    def __init__(
        self,
        kv_layout: str = "NHD",
        device: str = "cuda",
    ):
        """Validate the KV layout and allocate the workspace buffers.

        Args:
            kv_layout: Layout name of the KV cache; checked by
                ``_check_kv_layout`` and later used to index ``TensorLayout``.
            device: Device on which the two device-side workspaces are
                allocated.
        """
        _check_kv_layout(kv_layout)
        self._kv_layout = kv_layout

        workspace_device = torch.device(device)
        # 384 MiB byte buffer passed to the module as the "float" workspace.
        self.float_workspace_buffer = torch.empty(
            384 * 1024 * 1024,
            dtype=torch.uint8,
            device=workspace_device,
        )
        # 8 MiB byte buffer passed to the module as the "int" workspace.
        self.int_workspace_buffer = torch.empty(
            8 * 1024 * 1024,
            dtype=torch.uint8,
            device=workspace_device,
        )
        # Pinned host buffer of the same size as the int workspace; it is
        # handed to the module's plan() alongside the device buffers.
        self.page_locked_int_workspace_buffer = torch.empty(
            8 * 1024 * 1024,
            dtype=torch.uint8,
            device=torch.device("cpu"),
            pin_memory=True,
        )

    def plan(
        self,
        qo_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        kv_indices: torch.Tensor,
        kv_len_arr: torch.Tensor,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim_qk: int,
        head_dim_vo: int,
        page_size: int,
        causal: bool = False,
        sm_scale: Optional[float] = None,
        logits_soft_cap: Optional[float] = None,
        q_data_type: torch.dtype = torch.bfloat16,
        kv_data_type: torch.dtype = torch.bfloat16,
        use_profiler: bool = False,
    ) -> None:
        """Select the JIT module and compute the plan for a batch.

        Args:
            qo_indptr: Query/output index-pointer tensor; copied to the host
                and passed to the module's planner.
            kv_indptr: KV index-pointer tensor; its dtype is part of the JIT
                module key and a host copy is passed to the planner.
            kv_indices: KV page indices; stored by reference (no copy) and
                passed to the kernel in :meth:`run`.
            kv_len_arr: Tensor whose first dimension is used as the batch
                size (presumably one KV length per request -- confirm
                against the caller).
            num_qo_heads: Number of query/output heads.
            num_kv_heads: Number of key/value heads.
            head_dim_qk: Head dimension of Q/K (part of the JIT module key).
            head_dim_vo: Head dimension of V/O.
            page_size: Page size of the paged KV cache.
            causal: Selects ``MaskMode.CAUSAL`` when true, otherwise
                ``MaskMode.NON_CAUSAL``.
            sm_scale: Softmax scale. When ``None``, :meth:`run` uses
                ``1 / sqrt(q.shape[2])``.
            logits_soft_cap: Soft-cap value; ``None`` is treated as ``0.0``.
                Only whether it is positive affects module selection.
            q_data_type: Query dtype (part of the JIT module key).
            kv_data_type: KV dtype (part of the JIT module key).
            use_profiler: Selects the profiler-enabled module variant; when
                true, :meth:`run` requires a ``profiler_buffer``.
        """
        if logits_soft_cap is None:
            logits_soft_cap = 0.0
        self._logits_soft_cap = logits_soft_cap

        # Key for the cached JIT module lookup.
        get_module_args = (
            q_data_type,
            kv_data_type,
            q_data_type,  # NOTE(review): presumably the output dtype -- confirm
            kv_indptr.dtype,
            head_dim_qk,
            head_dim_vo,
            PosEncodingMode["NONE"].value,
            logits_soft_cap > 0.0,
            use_profiler,  # different compiler path
        )
        self.module = get_holistic_attention_module(*get_module_args)

        # Issue the host copies asynchronously, then synchronize so the host
        # tensors are fully populated before the planner reads them.
        host = torch.device("cpu")
        qo_indptr_host = qo_indptr.to(host, non_blocking=True)
        kv_indptr_host = kv_indptr.to(host, non_blocking=True)
        kv_len_arr_host = kv_len_arr.to(host, non_blocking=True)
        torch.cuda.synchronize()

        batch_size = kv_len_arr.shape[0]

        # State consumed later by run().
        self._page_size = page_size
        self._sm_scale = sm_scale
        self._mask_mode = (
            MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value
        )
        self._num_qo_heads = num_qo_heads
        self._num_kv_heads = num_kv_heads
        self._use_profiler = use_profiler

        # No additional buffer is allocated for the CUDA graph tensor;
        # kv_indices is allocated outside FlashInfer and kept by reference.
        self._kv_indices = kv_indices

        self._plan_info = self.module.plan(
            self.float_workspace_buffer,
            self.int_workspace_buffer,
            self.page_locked_int_workspace_buffer,
            qo_indptr_host,
            kv_indptr_host,
            kv_len_arr_host,
            batch_size,
            num_qo_heads,
            num_kv_heads,
            head_dim_vo,
            causal,
        )

    def run(
        self,
        q: torch.Tensor,
        kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        out: Optional[torch.Tensor] = None,
        lse: Optional[torch.Tensor] = None,
        logits_soft_cap: float = 0.0,
        profiler_buffer: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Launch the attention kernel using the plan from :meth:`plan`.

        Args:
            q: Query tensor. Indexed as ``q.shape[0..2]``, so it must be at
                least 3-D (presumably ``(tokens, num_qo_heads, head_dim_qk)``
                -- confirm against the caller).
            kv_cache: Paged KV cache, either a single tensor or a ``(k, v)``
                tuple; split by ``_unpack_paged_kv_cache``.
            out: Optional output tensor; allocated with
                ``torch.empty_like(q)`` when omitted.
            lse: Optional log-sum-exp tensor; allocated as float32 with shape
                ``(q.shape[0], q.shape[1])`` when omitted.
            logits_soft_cap: Soft-cap value forwarded to the kernel. A
                positive value is only allowed if a positive value was also
                given to :meth:`plan`.
            profiler_buffer: Buffer forwarded to the kernel when the
                profiler was enabled in :meth:`plan`.

        Returns:
            The ``(out, lse)`` tensors.

        Raises:
            ValueError: If the profiler is enabled but ``profiler_buffer`` is
                missing, or if ``logits_soft_cap`` is positive here but was
                not positive in :meth:`plan`.
        """
        if profiler_buffer is None and self._use_profiler:
            raise ValueError(
                "Profiler is enabled, profiler_buffer must be provided"
            )
        if logits_soft_cap > 0.0 and self._logits_soft_cap <= 0.0:
            raise ValueError(
                "logits_soft_cap used in kernel run but not provided in plan(). This will cause template deduction error."
            )

        k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, self._kv_layout)

        if out is None:
            out = torch.empty_like(q)
        if lse is None:
            # One float32 entry per (q.shape[0], q.shape[1]) pair.
            lse = torch.empty(
                q.shape[0], q.shape[1], device=q.device, dtype=torch.float32
            )

        # Resolve the default scale into a local so that run() does not
        # overwrite the value recorded by plan().
        sm_scale = self._sm_scale
        if sm_scale is None:
            head_dim_qk = q.shape[2]
            sm_scale = 1.0 / math.sqrt(head_dim_qk)

        # The profiler buffer is only part of the kernel signature when the
        # profiler-enabled module variant was selected in plan().
        profiler_args = (profiler_buffer,) if self._use_profiler else ()

        self.module.run(
            self.float_workspace_buffer,
            self.int_workspace_buffer,
            self._plan_info,
            q,
            k_cache,
            v_cache,
            self._kv_indices,
            out,
            lse,
            self._mask_mode,
            TensorLayout[self._kv_layout].value,
            self._num_qo_heads,
            self._num_kv_heads,
            self._page_size,
            sm_scale,
            logits_soft_cap,
            # ADDITIONAL_FUNC_PARAMS
            # PROFILER_FUNC_PARAMS
            *profiler_args,
        )
        return out, lse