sglang_v0.5.2/flashinfer_0.3.1/flashinfer/prefill.py

3456 lines
128 KiB
Python
Executable File

"""
Copyright (c) 2023 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
import logging
import math
from types import SimpleNamespace
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
import torch
from .jit import (
gen_batch_prefill_module,
gen_customize_batch_prefill_module,
gen_fmha_cutlass_sm100a_module,
gen_single_prefill_module,
get_batch_prefill_uri,
get_single_prefill_uri,
setup_cubin_loader,
gen_trtllm_gen_fmha_module,
)
from .cudnn import cudnn_batch_prefill_with_kv_cache
from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
from .quantization import packbits, segment_packbits
from .utils import (
FP4Tensor,
MaskMode,
PosEncodingMode,
TensorLayout,
_check_cached_qkv_data_type,
_check_kv_layout,
_check_pos_encoding_mode,
check_shape_dtype_device,
_get_cache_alibi_slopes_buf,
_get_cache_buf,
_unpack_paged_kv_cache,
canonicalize_torch_dtype,
determine_attention_backend,
device_support_pdl,
get_device_sm_count,
is_float8,
is_sm100a_supported,
is_sm110a_supported,
register_custom_op,
register_fake_op,
ceil_div,
round_up,
)
@functools.cache
def get_fmha_module(
    dtype_q: torch.dtype,
    dtype_kv: torch.dtype,
    dtype_o: torch.dtype,
    dtype_idx: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    pos_encoding_mode: int,
    use_sliding_window: bool,
    use_logits_soft_cap: bool,
    device: torch.device,
    use_fp16_qk_reduction: bool = False,
):
    """Build (or fetch from the memo cache) the CUTLASS SM100A FMHA module.

    The result is memoized per argument tuple by ``functools.cache``.
    ``use_fp16_qk_reduction`` is part of the cache key but is not forwarded
    to the module generator.

    Raises
    ------
    ValueError
        If ``device`` supports neither SM100A nor SM110A.
    """
    sm100_family = is_sm100a_supported(device) or is_sm110a_supported(device)
    if not sm100_family:
        raise ValueError("SM100A is not supported on this device")

    module_spec = gen_fmha_cutlass_sm100a_module(
        dtype_q,
        dtype_kv,
        dtype_o,
        dtype_idx,
        head_dim_qk,
        head_dim_vo,
        pos_encoding_mode,
        use_sliding_window,
        use_logits_soft_cap,
    )
    return module_spec.build_and_load()
def make_hashable_cache(func):
    """
    Decorator that memoizes ``func`` with ``functools.cache`` after converting
    unhashable list arguments into hashable tuples.

    Lists are converted recursively, so nested lists and plain tuples that
    contain lists are accepted as well. ``func`` therefore receives tuples
    wherever the caller passed lists. The returned wrapper exposes the
    underlying cache's ``cache_info`` and ``cache_clear``.
    """

    def _to_hashable(value):
        # Replace lists (at any depth) with tuples so the value can be a cache key.
        if isinstance(value, list):
            return tuple(_to_hashable(item) for item in value)
        # Only plain tuples are rebuilt; tuple subclasses (e.g. namedtuples)
        # are passed through unchanged to preserve their type.
        if type(value) is tuple:
            return tuple(_to_hashable(item) for item in value)
        return value

    @functools.cache
    def cached_wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        hashable_args = tuple(_to_hashable(arg) for arg in args)
        hashable_kwargs = {
            key: _to_hashable(value) for key, value in kwargs.items()
        }
        return cached_wrapper(*hashable_args, **hashable_kwargs)

    # Expose cache management on the public wrapper, mirroring functools.cache.
    wrapper.cache_info = cached_wrapper.cache_info
    wrapper.cache_clear = cached_wrapper.cache_clear
    return wrapper
@make_hashable_cache
def get_customize_batch_prefill_module(
    backend: str,
    uri: str,
    dtype_q: torch.dtype,
    dtype_kv: torch.dtype,
    dtype_o: torch.dtype,
    idtype: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    additional_tensor_names: List[str],
    additional_tensor_dtypes: List[str],
    additional_scalar_names: List[str],
    additional_scalar_dtypes: List[str],
    variant_name: str,
    variant_decl: str,
    pos_encoding_mode: int = 0,
    use_sliding_window: bool = False,
    use_logits_soft_cap: bool = False,
    use_fp16_qk_reduction: bool = False,
    fp8_enabled: bool = False,
):
    """Generate, build and load a customized batch-prefill JIT module.

    Memoized through ``make_hashable_cache``, so list-valued arguments are
    turned into tuples before they reach this body and act as cache keys.
    All arguments are forwarded positionally, in declaration order, to
    ``gen_customize_batch_prefill_module``.
    """
    generator_args = (
        backend,
        uri,
        dtype_q,
        dtype_kv,
        dtype_o,
        idtype,
        head_dim_qk,
        head_dim_vo,
        additional_tensor_names,
        additional_tensor_dtypes,
        additional_scalar_names,
        additional_scalar_dtypes,
        variant_name,
        variant_decl,
        pos_encoding_mode,
        use_sliding_window,
        use_logits_soft_cap,
        use_fp16_qk_reduction,
        fp8_enabled,
    )
    module_spec = gen_customize_batch_prefill_module(*generator_args)
    return module_spec.build_and_load()
@functools.cache
def get_trtllm_gen_prefill_module():
    """Build and load the trtllm-gen FMHA module and expose it through a
    ``SimpleNamespace`` with ``paged_run``, ``ragged_run`` and ``plan``.

    Memoized by ``functools.cache``, so the module is built and the cubin
    loader is set up only on the first call.
    """
    mod = gen_trtllm_gen_fmha_module()
    op = mod.build_and_load()
    # Point the cubin loader at the library that was just built.
    setup_cubin_loader(mod.get_library_path())

    def _paged_run(
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        workspace_buffer: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_q_len: int,
        max_kv_len: int,
        bmm1_scale: float,
        bmm2_scale: float,
        batch_size: int,
        cum_seq_lens_q: torch.Tensor,
        cum_seq_lens_kv: torch.Tensor,
        enable_pdl: bool,
        workspace_size: int,
        window_left: int = -1,
        out: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run paged-KV context attention via ``op.trtllm_paged_attention_context``.

        When ``out`` is not given it is allocated with ``torch.empty_like(query)``.
        The (possibly newly allocated) ``out`` tensor is returned.
        """
        # SM count of the query's device; forwarded to the native op.
        sm_count = get_device_sm_count(query.device)
        if out is None:
            out = torch.empty_like(query)
        # NOTE: arguments are positional and must match the native op's order.
        op.trtllm_paged_attention_context(
            out,
            None,  # fp4 output not supported in wrapper api yet.
            query,
            k_cache,
            v_cache,
            workspace_buffer,
            block_tables,
            seq_lens,
            max_q_len,
            max_kv_len,
            bmm1_scale,
            bmm2_scale,
            -1,  # o_sf_scale
            -1,  # o_sf_vec_size
            0,  # o_sf_start_index
            batch_size,
            window_left,
            cum_seq_lens_q,
            cum_seq_lens_kv,
            sm_count,
            enable_pdl,
            workspace_size,
            sinks,
        )
        return out

    def _ragged_run(*args, **kwargs):
        # TODO(Zihao): trtllm-gen backend already supports variable length attention,
        # but not integrated into flashinfer yet.
        raise NotImplementedError(
            "Variable length is not implemented for trtllm-gen backend yet."
        )

    def _plan(*args, **kwargs):
        # No planning step for this backend: plan accepts anything and does nothing.
        pass

    return SimpleNamespace(
        paged_run=_paged_run,
        ragged_run=_ragged_run,
        plan=_plan,
    )
@functools.cache
def get_single_prefill_module(backend, *args):
    """Build and load the single-request prefill module for ``backend``.

    ``args`` is forwarded unchanged to ``get_single_prefill_uri`` and
    ``gen_single_prefill_module``. The returned namespace exposes ``run``.
    ``run`` dispatches to the backend-specific argument layout of the
    compiled ``run`` entry point. It is registered as the custom op
    ``flashinfer::{uri}_run``, together with a fake implementation that has
    the same signature.
    """
    uri = get_single_prefill_uri(backend, *args)
    module = gen_single_prefill_module(backend, *args).build_and_load()
    run_func = module.run.default

    # torch library for single_prefill_with_kv_cache
    @register_custom_op(
        f"flashinfer::{uri}_run", mutates_args=("tmp", "o", "maybe_lse")
    )
    def run_single_prefill(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        tmp: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        maybe_packed_custom_mask: Optional[torch.Tensor],
        maybe_alibi_slopes: Optional[torch.Tensor],
        logits_soft_cap: float,
        sm_scale: float,
        scale_q: Optional[torch.Tensor],
        scale_k: Optional[torch.Tensor],
        scale_v: Optional[torch.Tensor],
        rope_scale: float,
        rope_theta: float,
    ) -> None:
        if backend == "fa3":
            if not is_float8(q):
                run_func(
                    q,
                    k,
                    v,
                    tmp,
                    o,
                    maybe_lse,
                    mask_mode,
                    layout,
                    window_left,
                    logits_soft_cap,
                    sm_scale,
                )
            else:
                # FP8 enabled
                run_func(
                    q,
                    k,
                    v,
                    tmp,
                    o,
                    maybe_lse,
                    mask_mode,
                    layout,
                    window_left,
                    scale_q,
                    scale_k,
                    scale_v,
                    sm_scale,
                )
        else:
            run_func(
                q,
                k,
                v,
                tmp,
                o,
                maybe_lse,
                mask_mode,
                layout,
                window_left,
                maybe_packed_custom_mask,
                maybe_alibi_slopes,
                logits_soft_cap,
                sm_scale,
                1.0 / rope_scale,  # rope_rcp_scale
                1.0 / rope_theta,  # rope_rcp_theta
            )
        return o

    # The fake implementation must accept exactly the same parameters as
    # run_single_prefill (including scale_q/scale_k/scale_v); a fake with a
    # shorter signature cannot be invoked with the op's arguments.
    @register_fake_op(f"flashinfer::{uri}_run")
    def _fake_run_single_prefill(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        tmp: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        maybe_packed_custom_mask: Optional[torch.Tensor],
        maybe_alibi_slopes: Optional[torch.Tensor],
        logits_soft_cap: float,
        sm_scale: float,
        scale_q: Optional[torch.Tensor],
        scale_k: Optional[torch.Tensor],
        scale_v: Optional[torch.Tensor],
        rope_scale: float,
        rope_theta: float,
    ) -> None:
        pass

    # Register the module
    return SimpleNamespace(run=run_single_prefill)
@functools.cache
def get_batch_prefill_module(backend, *args):
    """Build and load the batch-prefill module for ``backend`` and wrap its
    ``ragged_run`` / ``paged_run`` entry points as custom ops.

    For ``backend == "trtllm-gen"`` the functions come from
    ``get_trtllm_gen_prefill_module()``. Otherwise a JIT module is generated
    from ``args`` and its ``.default`` overloads are used. ``plan`` is
    returned unwrapped: it is not part of the model logic, so it is not
    registered as a custom op.
    """
    if backend == "trtllm-gen":
        uri = "trtllm_gen_context"
        module = get_trtllm_gen_prefill_module()
        plan_func = module.plan
        ragged_run_func = module.ragged_run
        paged_run_func = module.paged_run
    else:
        uri = get_batch_prefill_uri(backend, *args)
        module = gen_batch_prefill_module(backend, *args).build_and_load()
        plan_func = module.plan.default
        ragged_run_func = module.ragged_run.default
        paged_run_func = module.paged_run.default

    # torch library for ragged_run
    @register_custom_op(
        f"flashinfer::{uri}_ragged_run",
        mutates_args=(
            "float_workspace_buffer",
            "int_workspace_buffer",
            "o",
            "maybe_lse",
        ),
    )
    def ragged_run(
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        plan_info_vec: List[int],
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qo_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        enable_pdl: bool,
        maybe_custom_mask: Optional[torch.Tensor],
        maybe_mask_indptr: Optional[torch.Tensor],
        maybe_alibi_slopes: Optional[torch.Tensor],
        maybe_prefix_len_ptr: Optional[torch.Tensor],
        maybe_token_pos_in_items_ptr: Optional[torch.Tensor],
        maybe_max_item_len_ptr: Optional[torch.Tensor],
        logits_soft_cap: float,
        sm_scale: float,
        rope_scale: float,
        rope_theta: float,
        token_pos_in_items_len: int,
    ) -> None:
        if backend == "fa2":
            ragged_run_func(
                float_workspace_buffer,
                int_workspace_buffer,
                plan_info_vec,
                q,
                k,
                v,
                qo_indptr,
                kv_indptr,
                o,
                maybe_lse,
                mask_mode,
                layout,
                window_left,
                enable_pdl,
                maybe_custom_mask,
                maybe_mask_indptr,
                maybe_alibi_slopes,
                maybe_prefix_len_ptr,
                maybe_token_pos_in_items_ptr,
                maybe_max_item_len_ptr,
                logits_soft_cap,
                sm_scale,
                1.0 / rope_scale,  # rope_rcp_scale
                1.0 / rope_theta,  # rope_rcp_theta
                token_pos_in_items_len,
            )
        else:
            ragged_run_func(
                float_workspace_buffer,
                int_workspace_buffer,
                plan_info_vec,
                q,
                k,
                v,
                qo_indptr,
                kv_indptr,
                o,
                maybe_lse,
                mask_mode,
                layout,
                window_left,
                enable_pdl,
                maybe_prefix_len_ptr,
                maybe_token_pos_in_items_ptr,
                maybe_max_item_len_ptr,
                logits_soft_cap,
                sm_scale,
                token_pos_in_items_len,
            )
        return o

    @register_fake_op(f"flashinfer::{uri}_ragged_run")
    def _fake_ragged_run(
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        plan_info_vec: List[int],
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qo_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        enable_pdl: bool,
        maybe_custom_mask: Optional[torch.Tensor],
        maybe_mask_indptr: Optional[torch.Tensor],
        maybe_alibi_slopes: Optional[torch.Tensor],
        maybe_prefix_len_ptr: Optional[torch.Tensor],
        maybe_token_pos_in_items_ptr: Optional[torch.Tensor],
        maybe_max_item_len_ptr: Optional[torch.Tensor],
        logits_soft_cap: float,
        sm_scale: float,
        rope_scale: float,
        rope_theta: float,
        token_pos_in_items_len: int,
    ) -> None:
        pass

    # torch library for paged_run
    @register_custom_op(
        f"flashinfer::{uri}_paged_run",
        mutates_args=(
            "float_workspace_buffer",
            "int_workspace_buffer",
            "paged_k_cache",
            "paged_v_cache",
            "o",
            "maybe_lse",
        ),
    )
    def paged_run(
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        plan_info_vec: List[int],
        q: torch.Tensor,
        paged_k_cache: torch.Tensor,
        paged_v_cache: torch.Tensor,
        qo_indptr: torch.Tensor,
        paged_kv_indptr: torch.Tensor,
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        enable_pdl: bool,
        maybe_custom_mask: Optional[torch.Tensor],
        maybe_mask_indptr: Optional[torch.Tensor],
        maybe_alibi_slopes: Optional[torch.Tensor],
        maybe_prefix_len_ptr: Optional[torch.Tensor],
        maybe_token_pos_in_items_ptr: Optional[torch.Tensor],
        maybe_max_item_len_ptr: Optional[torch.Tensor],
        logits_soft_cap: float,
        sm_scale: float,
        scale_q: Optional[torch.Tensor],
        scale_k: Optional[torch.Tensor],
        scale_v: Optional[torch.Tensor],
        rope_scale: float,
        rope_theta: float,
        token_pos_in_items_len: int,
        workspace_size: int,
        num_qo_heads: Optional[int] = None,
        num_kv_heads: Optional[int] = None,
        block_tables: Optional[torch.Tensor] = None,
        kv_lens_buffer: Optional[torch.Tensor] = None,
        page_size: Optional[int] = None,
        max_q_len: Optional[int] = None,
        max_kv_len: Optional[int] = None,
        batch_size: Optional[int] = None,
        cum_seq_lens_q: Optional[torch.Tensor] = None,
        cum_seq_lens_kv: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        if backend == "trtllm-gen":
            assert maybe_lse is None
            assert num_qo_heads is not None
            assert num_kv_heads is not None
            assert block_tables is not None
            assert kv_lens_buffer is not None
            assert page_size is not None
            assert max_kv_len is not None
            assert batch_size is not None
            assert cum_seq_lens_q is not None
            assert cum_seq_lens_kv is not None
            assert enable_pdl is not None
            assert workspace_size > 0, "workspace_size must be greater than 0"
            o = paged_run_func(
                q.contiguous(),  # NOTE(Siyuan): without contiguous, the result is incorrect
                paged_k_cache,
                paged_v_cache,
                int_workspace_buffer,
                block_tables,
                kv_lens_buffer,
                max_q_len,
                max_kv_len,
                sm_scale,
                1.0,  # NOTE(Siyuan): update this to expose bmm2 scale
                batch_size,
                cum_seq_lens_q,
                cum_seq_lens_kv,
                enable_pdl,
                workspace_size,
                window_left,
                out=o,
                sinks=sinks,
            )
        elif backend == "fa2":
            assert not is_float8(q)
            paged_run_func(
                float_workspace_buffer,
                int_workspace_buffer,
                plan_info_vec,
                q,
                paged_k_cache,
                paged_v_cache,
                qo_indptr,
                paged_kv_indptr,
                paged_kv_indices,
                paged_kv_last_page_len,
                o,
                maybe_lse,
                mask_mode,
                layout,
                window_left,
                enable_pdl,
                maybe_custom_mask,
                maybe_mask_indptr,
                maybe_alibi_slopes,
                maybe_prefix_len_ptr,
                maybe_token_pos_in_items_ptr,
                maybe_max_item_len_ptr,
                logits_soft_cap,
                sm_scale,
                1.0 / rope_scale,  # rope_rcp_scale
                1.0 / rope_theta,  # rope_rcp_theta
                token_pos_in_items_len,
            )
        else:
            if not is_float8(q):
                paged_run_func(
                    float_workspace_buffer,
                    int_workspace_buffer,
                    plan_info_vec,
                    q,
                    paged_k_cache,
                    paged_v_cache,
                    qo_indptr,
                    paged_kv_indptr,
                    paged_kv_indices,
                    paged_kv_last_page_len,
                    o,
                    maybe_lse,
                    mask_mode,
                    layout,
                    window_left,
                    enable_pdl,
                    maybe_prefix_len_ptr,
                    maybe_token_pos_in_items_ptr,
                    maybe_max_item_len_ptr,
                    logits_soft_cap,
                    sm_scale,
                    token_pos_in_items_len,
                )
            else:
                paged_run_func(
                    float_workspace_buffer,
                    int_workspace_buffer,
                    plan_info_vec,
                    q,
                    paged_k_cache,
                    paged_v_cache,
                    qo_indptr,
                    paged_kv_indptr,
                    paged_kv_indices,
                    paged_kv_last_page_len,
                    o,
                    maybe_lse,
                    mask_mode,
                    layout,
                    window_left,
                    enable_pdl,
                    scale_q,
                    scale_k,
                    scale_v,
                    sm_scale,
                )
        return o

    # The fake implementation mirrors paged_run's full signature (including
    # scale_q/scale_k/scale_v and sinks) so it can be called with the same
    # arguments as the real op.
    @register_fake_op(f"flashinfer::{uri}_paged_run")
    def _fake_paged_run(
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        plan_info_vec: List[int],
        q: torch.Tensor,
        paged_k_cache: torch.Tensor,
        paged_v_cache: torch.Tensor,
        qo_indptr: torch.Tensor,
        paged_kv_indptr: torch.Tensor,
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        enable_pdl: bool,
        maybe_custom_mask: Optional[torch.Tensor],
        maybe_mask_indptr: Optional[torch.Tensor],
        maybe_alibi_slopes: Optional[torch.Tensor],
        maybe_prefix_len_ptr: Optional[torch.Tensor],
        maybe_token_pos_in_items_ptr: Optional[torch.Tensor],
        maybe_max_item_len_ptr: Optional[torch.Tensor],
        logits_soft_cap: float,
        sm_scale: float,
        scale_q: Optional[torch.Tensor],
        scale_k: Optional[torch.Tensor],
        scale_v: Optional[torch.Tensor],
        rope_scale: float,
        rope_theta: float,
        token_pos_in_items_len: int,
        workspace_size: int,
        num_qo_heads: Optional[int] = None,
        num_kv_heads: Optional[int] = None,
        block_tables: Optional[torch.Tensor] = None,
        kv_lens_buffer: Optional[torch.Tensor] = None,
        page_size: Optional[int] = None,
        max_q_len: Optional[int] = None,
        max_kv_len: Optional[int] = None,
        batch_size: Optional[int] = None,
        cum_seq_lens_q: Optional[torch.Tensor] = None,
        cum_seq_lens_kv: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        pass

    # Register the module.
    #
    # Note that plan is not part of model logic. It should not be included in
    # Cuda Graph or torch.compile. So, we don't provide a torch library for plan.
    return SimpleNamespace(
        plan=plan_func,
        ragged_run=ragged_run,
        paged_run=paged_run,
    )
@functools.cache
def get_batch_prefill_jit_module(module_name: str, jit_module: Any):
    """Wrap an already-loaded batch-prefill JIT module as custom ops.

    ``ragged_run`` and ``paged_run`` are registered under
    ``flashinfer::{module_name}_ragged_run`` / ``_paged_run`` (each with a
    no-op fake). Any extra positional arguments are forwarded verbatim to
    the module's ``.default`` entry points. ``plan`` is returned as-is, since
    planning is not captured in CUDA graphs or ``torch.compile``.
    """
    plan_entry = jit_module.plan.default
    ragged_entry = jit_module.ragged_run.default
    paged_entry = jit_module.paged_run.default
    op_prefix = f"flashinfer::{module_name}"

    # -- ragged KV layout ---------------------------------------------------
    @register_custom_op(
        f"{op_prefix}_ragged_run",
        mutates_args=(
            "float_workspace_buffer",
            "int_workspace_buffer",
            "o",
            "maybe_lse",
        ),
    )
    def ragged_run(
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        plan_info_vec: List[int],
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qo_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        *args,
    ) -> None:
        leading = (
            float_workspace_buffer,
            int_workspace_buffer,
            plan_info_vec,
            q,
            k,
            v,
            qo_indptr,
            kv_indptr,
            o,
            maybe_lse,
            mask_mode,
            layout,
            window_left,
        )
        ragged_entry(*leading, *args)

    @register_fake_op(f"{op_prefix}_ragged_run")
    def _fake_ragged_run(
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        plan_info_vec: List[int],
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qo_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        *args,
    ) -> None:
        pass

    # -- paged KV layout ----------------------------------------------------
    @register_custom_op(
        f"{op_prefix}_paged_run",
        mutates_args=(
            "float_workspace_buffer",
            "int_workspace_buffer",
            "paged_k_cache",
            "paged_v_cache",
            "o",
            "maybe_lse",
        ),
    )
    def paged_run(
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        plan_info_vec: List[int],
        q: torch.Tensor,
        paged_k_cache: torch.Tensor,
        paged_v_cache: torch.Tensor,
        qo_indptr: torch.Tensor,
        paged_kv_indptr: torch.Tensor,
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        *args,
    ) -> None:
        leading = (
            float_workspace_buffer,
            int_workspace_buffer,
            plan_info_vec,
            q,
            paged_k_cache,
            paged_v_cache,
            qo_indptr,
            paged_kv_indptr,
            paged_kv_indices,
            paged_kv_last_page_len,
            o,
            maybe_lse,
            mask_mode,
            layout,
            window_left,
        )
        paged_entry(*leading, *args)

    @register_fake_op(f"{op_prefix}_paged_run")
    def _fake_paged_run(
        float_workspace_buffer: torch.Tensor,
        int_workspace_buffer: torch.Tensor,
        plan_info_vec: List[int],
        q: torch.Tensor,
        paged_k_cache: torch.Tensor,
        paged_v_cache: torch.Tensor,
        qo_indptr: torch.Tensor,
        paged_kv_indptr: torch.Tensor,
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len: torch.Tensor,
        o: torch.Tensor,
        maybe_lse: Optional[torch.Tensor],
        mask_mode: int,
        layout: int,
        window_left: int,
        *args,
    ) -> None:
        pass

    # plan stays a plain callable: it is not part of the captured model logic.
    return SimpleNamespace(
        plan=plan_entry,
        ragged_run=ragged_run,
        paged_run=paged_run,
    )
def single_prefill_with_kv_cache_with_jit_module(
    jit_module: Any,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *args,
    kv_layout: str = "NHD",
    mask_mode: int = MaskMode.NON_CAUSAL.value,
    window_left: int = -1,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Run single-request prefill through a user-supplied JIT module.

    The output has shape ``q.shape[:-1] + v.shape[-1:]`` and the dtype of
    ``q``. When ``return_lse`` is true, a float32 tensor of shape
    ``(q.size(0), q.size(1))`` is also allocated and returned. Extra
    positional ``args`` are appended verbatim to the
    ``jit_module.run.default`` call.
    """
    dev = q.device
    scratch = _get_cache_buf(
        "single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, device=dev
    )
    out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=dev)
    lse = (
        torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=dev)
        if return_lse
        else None
    )
    jit_module.run.default(
        q,
        k,
        v,
        scratch,
        out,
        lse,
        mask_mode,
        TensorLayout[kv_layout].value,
        window_left,
        *args,
    )
    if return_lse:
        return out, lse
    return out
# Static-typing overload: with ``return_lse=False`` (the default) only the
# attention output tensor is returned.
@overload
def single_prefill_with_kv_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale_q: Optional[torch.Tensor] = None,
    scale_k: Optional[torch.Tensor] = None,
    scale_v: Optional[torch.Tensor] = None,
    o_dtype: Optional[torch.dtype] = None,
    custom_mask: Optional[torch.Tensor] = None,
    packed_custom_mask: Optional[torch.Tensor] = None,
    causal: bool = False,
    kv_layout: str = "NHD",
    pos_encoding_mode: str = "NONE",
    use_fp16_qk_reduction: bool = False,
    sm_scale: Optional[float] = None,
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    backend: str = "auto",
    return_lse: Literal[False] = False,
) -> torch.Tensor: ...
# Static-typing overload: with ``return_lse=True`` a tuple of
# (attention output, log-sum-exp) is returned.
@overload
def single_prefill_with_kv_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale_q: Optional[torch.Tensor] = None,
    scale_k: Optional[torch.Tensor] = None,
    scale_v: Optional[torch.Tensor] = None,
    o_dtype: Optional[torch.dtype] = None,
    custom_mask: Optional[torch.Tensor] = None,
    packed_custom_mask: Optional[torch.Tensor] = None,
    causal: bool = False,
    kv_layout: str = "NHD",
    pos_encoding_mode: str = "NONE",
    use_fp16_qk_reduction: bool = False,
    sm_scale: Optional[float] = None,
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    backend: str = "auto",
    return_lse: Literal[True] = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
def single_prefill_with_kv_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale_q: Optional[torch.Tensor] = None,
    scale_k: Optional[torch.Tensor] = None,
    scale_v: Optional[torch.Tensor] = None,
    o_dtype: Optional[torch.dtype] = None,
    custom_mask: Optional[torch.Tensor] = None,
    packed_custom_mask: Optional[torch.Tensor] = None,
    causal: bool = False,
    kv_layout: str = "NHD",
    pos_encoding_mode: str = "NONE",
    use_fp16_qk_reduction: bool = False,
    sm_scale: Optional[float] = None,
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    backend: str = "auto",
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Prefill/Append attention with KV cache for single request, return the attention
    output.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor, shape: ``[qo_len, num_qo_heads, head_dim_qk]``.
    k : torch.Tensor
        The key tensor, shape: ``[kv_len, num_kv_heads, head_dim_qk]`` if :attr:`kv_layout`
        is ``NHD``, or ``[num_kv_heads, kv_len, head_dim_qk]`` if :attr:`kv_layout` is
        ``HND``.
    v : torch.Tensor
        The value tensor, shape: ``[kv_len, num_kv_heads, head_dim_vo]`` if :attr:`kv_layout`
        is ``NHD``, ``[num_kv_heads, kv_len, head_dim_vo]`` if :attr:`kv_layout` is
        ``HND``.
    scale_q : Optional[torch.Tensor]
        The scale tensor for query, per-head quantization with shape: ``[num_qo_heads]``.
        Used with FP8 Quantization. If not provided, will be set to ``1.0``.
    scale_k : Optional[torch.Tensor]
        The scale tensor for key, per-head quantization with shape: ``[num_kv_heads]``.
        Used with FP8 Quantization. If not provided, will be set to ``1.0``.
    scale_v : Optional[torch.Tensor]
        The scale tensor for value, per-head quantization with shape: ``[num_kv_heads]``.
        Used with FP8 Quantization. If not provided, will be set to ``1.0``.
    o_dtype : Optional[torch.dtype]
        The output tensor data type, if not provided, will be set to the same as the q.
        This is necessary as output dtype cannot be automatically inferred in quant.
    custom_mask : Optional[torch.Tensor]
        The custom boolean mask tensor, shape: ``[qo_len, kv_len]``.
        The elements in the mask tensor should be either ``True`` or ``False``,
        where ``False`` means the corresponding element in the attention matrix will be
        masked out.

        When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the
        function will pack the custom mask tensor into a 1D packed mask tensor, which introduces
        additional overhead.
    packed_custom_mask : Optional[torch.Tensor]
        The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored.
        The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`.
    causal : bool
        Whether to apply causal mask to the attention matrix.
        This is only effective when :attr:`custom_mask` is not provided.
    kv_layout : str
        The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
    pos_encoding_mode : str
        The position encoding applied inside attention kernels, could be
        ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
        Default is ``NONE``.
    use_fp16_qk_reduction : bool
        Whether to use f16 for qk reduction (faster at the cost of slight precision
        loss).
    window_left : int
        The left (inclusive) window size for the attention window, when set to ``-1``, the window
        size will be set to the full length of the sequence. Defaults to ``-1``.
    logits_soft_cap : Optional[float]
        The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not
        provided, will be set to ``0``. If greater than 0, the logits will be capped according to
        formula:
        :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
        where :math:`x` is the input logits.
    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim_qk)``.
    rope_scale : Optional[float]
        The scale used in RoPE interpolation, if not provided, will be set to 1.0.
    rope_theta : Optional[float]
        The theta used in RoPE, if not provided, will be set to 1e4.
    backend : str
        The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
        If set to ``auto``, the function will automatically choose the backend based on the
        device architecture and kernel availability.
    return_lse : bool
        Whether to return the log sum exp value of the attention logits.

    Returns
    -------
    Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_len, num_qo_heads, head_dim_vo]``.
        If :attr:`return_lse` is ``True``, a tuple of two tensors:

        * The attention output, shape: ``[qo_len, num_qo_heads, head_dim_vo]``.
        * The log sum exp value, shape: ``[qo_len, num_qo_heads]``.

    Examples
    --------

    >>> import torch
    >>> import flashinfer
    >>> qo_len = 128
    >>> kv_len = 4096
    >>> num_qo_heads = 32
    >>> num_kv_heads = 4
    >>> head_dim = 128
    >>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0")
    >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
    >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
    >>> o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True,
    ...     use_fp16_qk_reduction=True)
    >>> o.shape
    torch.Size([128, 32, 128])
    >>> mask = torch.tril(
    ...     torch.full((qo_len, kv_len), True, device="cuda:0"),
    ...     diagonal=(kv_len - qo_len),
    ... )
    >>> mask
    tensor([[ True,  True,  True,  ..., False, False, False],
            [ True,  True,  True,  ..., False, False, False],
            [ True,  True,  True,  ..., False, False, False],
            ...,
            [ True,  True,  True,  ...,  True, False, False],
            [ True,  True,  True,  ...,  True,  True, False],
            [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')
    >>> o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
    >>> torch.allclose(o, o_custom, rtol=1e-3, atol=1e-3)
    True

    Note
    ----
    The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
    not equal to ``num_kv_heads``, the function will use
    `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
    """
    _check_pos_encoding_mode(pos_encoding_mode)
    _check_kv_layout(kv_layout)
    tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.size(-1))
    if rope_scale is None:
        rope_scale = 1.0
    if rope_theta is None:
        rope_theta = 1e4
    if custom_mask is not None and packed_custom_mask is None:
        # create packed custom mask from custom mask
        packed_custom_mask = packbits(
            custom_mask.contiguous().view(-1), bitorder="little"
        )

    if packed_custom_mask is not None:
        mask_mode = MaskMode.CUSTOM.value
    else:
        if causal:
            mask_mode = MaskMode.CAUSAL.value
        else:
            mask_mode = MaskMode.NON_CAUSAL.value

    lse = None
    if return_lse:
        lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device)

    if is_float8(q):
        # FP8 quant enabled, do sanity check:
        #   1. unsupported feature
        #   2. dtype check
        assert window_left == -1
        assert q.dtype == k.dtype == v.dtype
        assert q.shape[-1] == k.shape[-1] == v.shape[-1]
        # Default per-head scales are 1.0. The KV head axis is dim 1 for NHD
        # ([kv_len, num_kv_heads, D]) but dim 0 for HND ([num_kv_heads, kv_len, D]);
        # using shape[1] unconditionally would size the K/V scales by kv_len under HND.
        kv_head_axis = 1 if kv_layout == "NHD" else 0
        if scale_q is None:
            scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device)
        if scale_k is None:
            scale_k = torch.ones(
                k.shape[kv_head_axis], dtype=torch.float32, device=q.device
            )
        if scale_v is None:
            scale_v = torch.ones(
                v.shape[kv_head_axis], dtype=torch.float32, device=q.device
            )

    if backend == "auto":
        backend = determine_attention_backend(
            q.device,
            PosEncodingMode[pos_encoding_mode].value,
            use_fp16_qk_reduction,
            packed_custom_mask is not None,  # use_custom_mask
            q.dtype,
            k.dtype,
        )

    # o_dtype should be provided for FP8 attention
    if o_dtype is None:
        o_dtype = q.dtype
    out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.device)

    module = get_single_prefill_module(
        backend,
        q.dtype,
        k.dtype,
        out.dtype,
        q.shape[-1],  # head_dim_qk
        v.shape[-1],  # head_dim_vo
        PosEncodingMode[pos_encoding_mode].value,
        window_left >= 0,  # use_sliding_window
        logits_soft_cap > 0,  # use_logits_soft_cap
        use_fp16_qk_reduction,
    )

    module.run(
        q,
        k,
        v,
        tmp,
        out,
        lse,
        mask_mode,
        TensorLayout[kv_layout].value,
        window_left,
        packed_custom_mask,
        _get_cache_alibi_slopes_buf(q.shape[1], q.device),
        logits_soft_cap,
        sm_scale,
        scale_q,
        scale_k,
        scale_v,
        rope_scale,
        rope_theta,
    )

    return (out, lse) if return_lse else out
# Convenience alias: ``single_prefill_with_kv_cache`` with ``return_lse=True``
# pre-bound, so calls return ``(output, lse)``.
single_prefill_with_kv_cache_return_lse = functools.partial(
    single_prefill_with_kv_cache, return_lse=True
)
def _compute_page_mask_indptr(
qo_indptr: torch.Tensor,
paged_kv_indptr: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
page_size: int,
) -> torch.Tensor:
if len(qo_indptr) != len(paged_kv_indptr):
raise ValueError(
"The length of qo_indptr and paged_kv_indptr should be the same."
)
mask_indptr = torch.empty_like(qo_indptr)
mask_indptr[0] = 0
mask_indptr[1:] = torch.cumsum(
(qo_indptr[1:] - qo_indptr[:-1])
* (
(paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) * page_size
+ paged_kv_last_page_len
),
0,
)
return mask_indptr
class BatchPrefillWithPagedKVCacheWrapper:
r"""Wrapper class for prefill/append attention with paged kv-cache for batch of
requests.
Check :ref:`our tutorial <kv-layout>` for page table layout.
Example
-------
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
... workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
... [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> paged_kv_indptr = torch.tensor(
... [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= paged_kv_last_page_len <= page_size
>>> paged_kv_last_page_len = torch.tensor(
... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> kv_cache_at_layer = torch.randn(
... num_layers, max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... )
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
... qo_indptr,
... paged_kv_indptr,
... paged_kv_indices,
... paged_kv_last_page_len,
... num_qo_heads,
... num_kv_heads,
... head_dim,
... page_size,
... causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
... q = q_at_layer[i]
... kv_cache = kv_cache_at_layer[i]
... # compute batch prefill attention, reuse auxiliary data structures
... o = prefill_wrapper.run(q, kv_cache)
... outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) + paged_kv_last_page_len).cpu().tolist()
>>> for i in range(batch_size):
... mask_i = torch.tril(
... torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
... diagonal=(kv_len[i] - qo_len[i]),
... )
... mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
... qo_indptr,
... paged_kv_indptr,
... paged_kv_indices,
... paged_kv_last_page_len,
... num_qo_heads,
... num_kv_heads,
... head_dim,
... page_size,
... custom_mask=mask,
... )
>>> for i in range(num_layers):
... q = q_at_layer[i]
... kv_cache = kv_cache_at_layer[i]
... # compute batch prefill attention, reuse auxiliary data structures
... o_custom = prefill_wrapper.run(q, kv_cache)
... assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...
Note
----
To accelerate computation, FlashInfer's batch prefill/append attention operators
create some auxiliary data structures, these data structures can be reused across
multiple prefill/append attention calls (e.g. different Transformer layers). This
wrapper class manages the lifecycle of these data structures.
"""
def __init__(
    self,
    float_workspace_buffer: torch.Tensor,
    kv_layout: str = "NHD",
    use_cuda_graph: bool = False,
    qo_indptr_buf: Optional[torch.Tensor] = None,
    paged_kv_indptr_buf: Optional[torch.Tensor] = None,
    paged_kv_indices_buf: Optional[torch.Tensor] = None,
    paged_kv_last_page_len_buf: Optional[torch.Tensor] = None,
    custom_mask_buf: Optional[torch.Tensor] = None,
    mask_indptr_buf: Optional[torch.Tensor] = None,
    backend: str = "auto",
    jit_args: Optional[List[Any]] = None,
    jit_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    r"""Constructor of :class:`BatchPrefillWithPagedKVCacheWrapper`.

    Parameters
    ----------
    float_workspace_buffer : torch.Tensor
        User-reserved workspace for intermediate split-k attention results.
        128MB is the recommended size. It must live on the same device as the
        input tensors; that device becomes the wrapper's device.
    kv_layout : str
        Layout of the k/v tensors, either ``NHD`` or ``HND``.
    use_cuda_graph : bool
        Enable CUDA-graph capture for the prefill kernels. When enabled, the
        auxiliary data structures live in the user-provided buffers below and
        ``batch_size`` must stay fixed for the wrapper's lifetime.
    qo_indptr_buf : Optional[torch.Tensor]
        Buffer for ``qo_indptr``, length ``batch_size + 1``. Required (and
        only used) in CUDA-graph mode; its length fixes the batch size.
    paged_kv_indptr_buf : Optional[torch.Tensor]
        Buffer for ``paged_kv_indptr``, length ``batch_size + 1``. Required
        (and only used) in CUDA-graph mode.
    paged_kv_indices_buf : Optional[torch.Tensor]
        Buffer for ``paged_kv_indices``. It must be large enough for the
        largest index array planned during the wrapper's lifetime. Required
        (and only used) in CUDA-graph mode.
    paged_kv_last_page_len_buf : Optional[torch.Tensor]
        Buffer for ``paged_kv_last_page_len``, length ``batch_size``. Required
        (and only used) in CUDA-graph mode.
    custom_mask_buf : Optional[torch.Tensor]
        Buffer for the packed custom mask. It must be large enough for the
        largest packed mask used during the wrapper's lifetime. Only used in
        CUDA-graph mode when a custom mask is used. In that mode, providing it
        makes :meth:`run` select the custom-mask path.
    mask_indptr_buf : Optional[torch.Tensor]
        Buffer for ``mask_indptr``, length ``batch_size + 1``. Only used in
        CUDA-graph mode together with a custom mask.
    backend : str
        Implementation backend: ``auto``, ``fa2``, ``fa3``, ``cudnn`` or
        ``trtllm-gen``. ``auto`` is resolved in :meth:`plan` from the device
        architecture and kernel availability. ``cudnn`` requires ``NHD``.
    jit_args : Optional[List[Any]]
        If given, build a customized JIT module from these arguments instead
        of using the default attention implementation.
    jit_kwargs : Optional[Dict[str, Any]]
        Keyword arguments for building the customized JIT module.
    """
    _check_kv_layout(kv_layout)

    if jit_args is None:
        self._jit_module = None
    else:
        custom_kwargs = {} if jit_kwargs is None else jit_kwargs
        self._jit_module = get_batch_prefill_jit_module(
            jit_args[0],
            get_customize_batch_prefill_module(backend, *jit_args, **custom_kwargs),
        )

    self._kv_layout = kv_layout
    if backend == "cudnn":
        assert kv_layout == "NHD", "CUDNN backend only supports NHD layout"

    self._float_workspace_buffer = float_workspace_buffer
    self._workspace_size = (
        float_workspace_buffer.numel() * float_workspace_buffer.element_size()
    )
    self.device = float_workspace_buffer.device

    # Scratch buffers for the page -> token ("vector sparse") index conversion.
    # They are sized for at most 16M accumulated kv tokens and a batch of 32768.
    self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
    if backend in ("fa3", "auto", "trtllm-gen"):
        max_accumulated_kv_len = 16 * 1024 * 1024
        max_batch_entries = 32768
        self._vector_sparse_indices_buffer = torch.empty(
            (max_accumulated_kv_len,), dtype=torch.int32, device=self.device
        )
        self._vector_sparse_indptr_buffer = torch.empty(
            (max_batch_entries,), dtype=torch.int32, device=self.device
        )
    self._kv_lens_buffer = torch.empty((32768,), dtype=torch.int32, device=self.device)

    self._int_workspace_buffer = torch.empty(
        (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
    )
    # Pinned host mirror of the int workspace, with the same shape and dtype.
    self._pin_memory_int_workspace_buffer = torch.empty(
        self._int_workspace_buffer.shape,
        dtype=self._int_workspace_buffer.dtype,
        device="cpu",
        pin_memory=True,
    )

    self._use_cuda_graph = use_cuda_graph
    if not use_cuda_graph:
        self._fixed_batch_size = 0
    else:
        mandatory_buffers = (
            (qo_indptr_buf, "qo_indptr_buf should be a torch.Tensor in CUDA graph mode"),
            (
                paged_kv_indptr_buf,
                "paged_kv_indptr_buf should be a torch.Tensor in CUDA graph mode",
            ),
            (
                paged_kv_indices_buf,
                "paged_kv_indices_buf should be a torch.Tensor in CUDA graph mode",
            ),
            (
                paged_kv_last_page_len_buf,
                "paged_kv_last_page_len_buf should be a torch.Tensor in CUDA graph mode",
            ),
        )
        for candidate, complaint in mandatory_buffers:
            if not torch.is_tensor(candidate):
                raise ValueError(complaint)
        # The qo_indptr buffer length pins the batch size for the wrapper's lifetime.
        self._fixed_batch_size = len(qo_indptr_buf) - 1
        if len(paged_kv_indptr_buf) != self._fixed_batch_size + 1:
            raise ValueError(
                "The length of paged_kv_indptr_buf should be batch_size + 1."
            )
        if len(paged_kv_last_page_len_buf) != self._fixed_batch_size:
            raise ValueError(
                "The length of paged_kv_last_page_len_buf should be batch_size."
            )
        # custom_mask_buf / mask_indptr_buf are optional, so they are not validated here.

    self._qo_indptr_buf = qo_indptr_buf
    self._paged_kv_indptr_buf = paged_kv_indptr_buf
    self._paged_kv_indices_buf = paged_kv_indices_buf
    self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf
    self._custom_mask_buf = custom_mask_buf
    self._mask_indptr_buf = mask_indptr_buf

    # Populated by plan().
    self._max_total_num_rows = None
    self._backend = backend
    self._plan_info = None
    self._cached_module = None
    self._seq_lens_kv = None
    self._seq_lens_q = None
    self._block_tables = None
@property
def is_cuda_graph_enabled(self) -> bool:
    """Whether the wrapper was constructed with ``use_cuda_graph=True``."""
    graph_mode = self._use_cuda_graph
    return graph_mode
def reset_workspace_buffer(
    self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
) -> None:
    r"""Reset the workspace buffer.

    Parameters
    ----------
    float_workspace_buffer : torch.Tensor
        The new float workspace buffer, the device of the new float workspace buffer should
        be the same as the device of the input tensors.
    int_workspace_buffer : torch.Tensor
        The new int workspace buffer, the device of the new int workspace buffer should
        be the same as the device of the input tensors.

    Note
    ----
    The cached byte size of the float workspace is recomputed as well, because
    :meth:`run` forwards it to the attention module together with the buffer.
    """
    self._float_workspace_buffer = float_workspace_buffer
    # Keep the cached size in sync with the new buffer. Otherwise the module
    # would be told the size of the previous workspace.
    self._workspace_size = (
        float_workspace_buffer.numel() * float_workspace_buffer.element_size()
    )
    self._int_workspace_buffer = int_workspace_buffer
    # Pinned host mirror of the int workspace (same shape and dtype).
    self._pin_memory_int_workspace_buffer = torch.empty(
        self._int_workspace_buffer.shape,
        dtype=self._int_workspace_buffer.dtype,
        device="cpu",
        pin_memory=True,
    )
def plan(
    self,
    qo_indptr: torch.Tensor,
    paged_kv_indptr: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim_qk: int,
    page_size: int,
    head_dim_vo: Optional[int] = None,
    custom_mask: Optional[torch.Tensor] = None,
    packed_custom_mask: Optional[torch.Tensor] = None,
    causal: bool = False,
    pos_encoding_mode: str = "NONE",
    use_fp16_qk_reduction: bool = False,
    sm_scale: Optional[float] = None,
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    q_data_type: Union[str, torch.dtype] = "float16",
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
    non_blocking: bool = True,
    prefix_len_ptr: Optional[torch.Tensor] = None,
    token_pos_in_items_ptr: Optional[torch.Tensor] = None,
    token_pos_in_items_len: int = 0,
    max_item_len_ptr: Optional[torch.Tensor] = None,
    seq_lens: Optional[torch.Tensor] = None,
    seq_lens_q: Optional[torch.Tensor] = None,
    block_tables: Optional[torch.Tensor] = None,
    max_token_per_sequence: Optional[int] = None,
    max_sequence_kv: Optional[int] = None,
) -> None:
    r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.

    Parameters
    ----------
    qo_indptr : torch.Tensor
        The indptr of the query/output tensor, shape: ``[batch_size + 1]``.
    paged_kv_indptr : torch.Tensor
        The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
    paged_kv_indices : torch.Tensor
        The page indices of the paged kv-cache, shape: ``[paged_kv_indptr[-1]]``.
    paged_kv_last_page_len : torch.Tensor
        The number of entries in the last page of each request in the paged
        kv-cache, shape: ``[batch_size]``.
    num_qo_heads : int
        The number of query/output heads.
    num_kv_heads : int
        The number of key/value heads.
    head_dim_qk : int
        The dimension of the query/key heads.
    page_size : int
        The size of each page in the paged kv-cache.
    head_dim_vo : Optional[int]
        The dimension of the value/output heads, if not provided, will be set to
        ``head_dim_qk``.
    custom_mask : Optional[torch.Tensor]
        The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``.
        The elements in the mask tensor should be either ``True`` or ``False``,
        where ``False`` means the corresponding element in the attention matrix will be
        masked out.

        Please refer to the :ref:`mask layout <mask-layout>` for more details about flattened
        layout of mask tensor.

        When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the
        function will pack the custom mask tensor into a 1D packed mask tensor, which introduces
        additional overhead.
    packed_custom_mask : Optional[torch.Tensor]
        The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored.
        The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`.
    causal : bool
        Whether to apply causal mask to the attention matrix.
        This is only effective when :attr:`custom_mask` is not provided in
        :meth:`plan`.
    pos_encoding_mode : str
        The position encoding applied inside attention kernels, could be
        ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
        Default is ``NONE``.
    use_fp16_qk_reduction : bool
        Whether to use f16 for qk reduction (faster at the cost of slight precision
        loss).
    window_left : int
        The left (inclusive) window size for the attention window, when set to ``-1``, the window
        size will be set to the full length of the sequence. Defaults to ``-1``.
    logits_soft_cap : Optional[float]
        The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not
        provided, will be set to ``0``. If greater than 0, the logits will be capped according to
        formula:
        :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
        where :math:`x` is the input logits.
    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to
        ``1.0 / sqrt(head_dim)``.
    rope_scale : Optional[float]
        The scale used in RoPE interpolation, if not provided, will be set to
        ``1.0``.
    rope_theta : Optional[float]
        The theta used in RoPE, if not provided, will be set to ``1e4``.
    q_data_type : Union[str, torch.dtype]
        The data type of the query tensor, defaults torch.float16.
    kv_data_type : Optional[Union[str, torch.dtype]]
        The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
    non_blocking : bool
        Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
    prefix_len_ptr : Optional[torch.Tensor]
        prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size.
    token_pos_in_items_ptr : Optional[torch.Tensor]
        A uint16 1D tensor (it will be converted to uint16 in flashinfer) indicating the token position of each item and started from 0 (delimiter)
        for each item. E.g., if we have 3 items of length 3, 2, 4 respectively for this member. This vector will be looking like
        `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]` with 4 delimiters indexed as 0. For batch size > 1,
        we will concat them as 1D with zero paddings to make sure each has the same length, the padding length is defined by
        `token_pos_in_items_len` - length of the raw `token_pos_in_items_ptr` for each prompt.
    token_pos_in_items_len : int
        zero padding length for `token_pos_in_items_ptr` to better handle the bsz > 1 case. Still using the above 3,2,4 example.
        If we set `token_pos_in_items_len` to be 20, it will be `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]`
        with 7 padded zeros. (note there're 8 zeros in the end where the first one is the delimiter token 0 in the end of the prompt)
    max_item_len_ptr : Optional[torch.Tensor]
        a uint16 vector contains the max token length of all items for each prompt
    seq_lens: Optional[torch.Tensor]
        A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
        If not provided, the kv lengths are derived from ``paged_kv_indptr``,
        ``paged_kv_last_page_len`` and ``page_size``.
    seq_lens_q: Optional[torch.Tensor]
        A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
        If not provided, will be set to the same value as ``seq_lens``.
    block_tables: Optional[torch.Tensor]
        A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``.
        For the ``trtllm-gen`` backend it is built from ``paged_kv_indices`` when omitted.
    max_token_per_sequence: Optional[int],
        Required for cudnn backend. The scalar maximum query length over all
        sequences. If omitted, it is derived from ``qo_indptr`` (a host copy).
    max_sequence_kv: Optional[int],
        Required for cudnn backend. The scalar maximum kv length over all
        sequences. If omitted, it is derived from the kv lengths (a host copy).

    Note
    ----
    The :meth:`plan` method should be called before any :meth:`run` or
    :meth:`run_return_lse` calls, auxiliary data structures will be created
    during this call and cached for multiple kernel runs.

    The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
    is not equal to ``num_kv_heads``, the function will use
    `grouped query attention <https://arxiv.org/abs/2305.13245>`_.

    The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
    """
    q_data_type = canonicalize_torch_dtype(q_data_type)
    if kv_data_type is None:
        kv_data_type = q_data_type
    kv_data_type = canonicalize_torch_dtype(kv_data_type)

    if logits_soft_cap is None:
        logits_soft_cap = 0.0
    if head_dim_vo is None:
        head_dim_vo = head_dim_qk

    batch_size = len(qo_indptr) - 1
    self._batch_size = batch_size
    self._num_qo_heads = num_qo_heads
    self._num_kv_heads = num_kv_heads

    if custom_mask is not None or packed_custom_mask is not None:
        mask_indptr = _compute_page_mask_indptr(
            qo_indptr,
            paged_kv_indptr,
            paged_kv_last_page_len,
            page_size,
        )
    if packed_custom_mask is None and custom_mask is not None:
        # create packed custom mask from custom mask
        packed_custom_mask, mask_indptr = segment_packbits(
            custom_mask.contiguous().view(-1),
            mask_indptr,
            bitorder="little",
        )

    self._prefix_len_ptr = prefix_len_ptr
    self._token_pos_in_items_ptr = token_pos_in_items_ptr
    self._token_pos_in_items_len = token_pos_in_items_len
    self._max_item_len_ptr = max_item_len_ptr

    # Host copies of the index arrays feed ``module.plan``, the fa3/trtllm-gen
    # indptr rewrite, the trtllm-gen block-table build and the CUDA-graph row
    # check. Only the cuDNN path (without a JIT module) can run purely from the
    # scalar maxima, so it is the only case where the device->host copies are
    # skipped when those maxima are supplied.
    needs_host_metadata = self._backend != "cudnn" or self._jit_module is not None

    qo_indptr_host = None
    total_num_rows = None
    if (
        max_token_per_sequence is None
        or needs_host_metadata
        or self.is_cuda_graph_enabled
    ):
        qo_indptr_host = qo_indptr.to("cpu")
        total_num_rows = int(qo_indptr_host[-1])
    if max_token_per_sequence is not None:
        self._max_q_len = max_token_per_sequence
    else:
        # Longest query segment, not the total number of query rows.
        self._max_q_len = (
            int(torch.diff(qo_indptr_host).max()) if batch_size > 0 else 0
        )

    paged_kv_indptr_host = None
    kv_lens_arr_host = None
    if max_sequence_kv is None or needs_host_metadata:
        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
        paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
        if seq_lens is None:
            kv_lens_arr_host = get_seq_lens(
                paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
            )
        else:
            kv_lens_arr_host = seq_lens.cpu().flatten()
        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
            kv_lens_arr_host, non_blocking=non_blocking
        )
    if max_sequence_kv is not None:
        self._max_kv_len = max_sequence_kv
    else:
        self._max_kv_len = max(kv_lens_arr_host).item()

    if self.is_cuda_graph_enabled:
        if self._max_total_num_rows is None:
            self._max_total_num_rows = total_num_rows
        elif total_num_rows > self._max_total_num_rows:
            raise ValueError(
                "The total number of rows in qo_indptr {} in cuda graph mode cannot "
                "exceed the number of rows set during initialization {}.".format(
                    total_num_rows, self._max_total_num_rows
                )
            )

        if batch_size != self._fixed_batch_size:
            raise ValueError(
                "The batch size should be fixed during the lifecycle of the wrapper in "
                "cuda graph mode, the runtime batch size {} mismatches the batch size {} "
                " set during initialization.".format(
                    batch_size, self._fixed_batch_size
                )
            )
        if len(paged_kv_indices) > len(self._paged_kv_indices_buf):
            raise ValueError(
                "The length of paged_kv_indices exceeds the allocated buffer size."
            )

        # In CUDA-graph mode the metadata is copied into the user-provided
        # buffers so captured kernels keep reading from the same addresses.
        self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking)
        self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=non_blocking)
        self._paged_kv_last_page_len_buf.copy_(
            paged_kv_last_page_len, non_blocking=non_blocking
        )
        self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
            paged_kv_indices,
            non_blocking=(paged_kv_indices.device == self.device) and non_blocking,
        )

        if packed_custom_mask is not None:
            if not torch.is_tensor(self._custom_mask_buf):
                raise ValueError(
                    "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                )
            if not torch.is_tensor(self._mask_indptr_buf):
                raise ValueError(
                    "mask_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                )
            self._custom_mask_buf[: len(packed_custom_mask)].copy_(
                packed_custom_mask,
                non_blocking=(packed_custom_mask.device == self.device)
                and non_blocking,
            )
            # NOTE(Zihao): mask_indptr has the same length as qo_indptr
            self._mask_indptr_buf.copy_(mask_indptr, non_blocking=non_blocking)
    else:
        self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking)
        self._paged_kv_indptr_buf = paged_kv_indptr.to(
            self.device, non_blocking=non_blocking
        )
        self._paged_kv_indices_buf = paged_kv_indices.to(
            self.device, non_blocking=non_blocking
        )
        self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
            self.device, non_blocking=non_blocking
        )
        if packed_custom_mask is not None:
            self._custom_mask_buf = packed_custom_mask.to(
                self.device, non_blocking=non_blocking
            )
            self._mask_indptr_buf = mask_indptr.to(
                self.device, non_blocking=non_blocking
            )
        else:
            self._custom_mask_buf = None
            self._mask_indptr_buf = None

    self._cached_q_data_type = q_data_type
    self._cached_kv_data_type = kv_data_type

    if self._jit_module is not None:
        self._cached_module = self._jit_module
    else:
        if self._backend == "auto":
            self._backend = determine_attention_backend(
                self.device,
                PosEncodingMode[pos_encoding_mode].value,
                use_fp16_qk_reduction,
                self._custom_mask_buf is not None,  # use_custom_mask
                q_data_type,
                kv_data_type,
            )
        if self._backend != "cudnn":
            get_module_args = (
                q_data_type,
                kv_data_type,
                q_data_type,
                paged_kv_indptr.dtype,
                head_dim_qk,
                head_dim_vo,
                PosEncodingMode[pos_encoding_mode].value,
                window_left >= 0,  # use_sliding_window
                logits_soft_cap > 0,  # use_logits_soft_cap
                use_fp16_qk_reduction,
            )
            self._cached_module = get_batch_prefill_module(
                self._backend, *get_module_args
            )

    if self._backend == "fa3" or self._backend == "trtllm-gen":
        if page_size != 1:
            # Token-level indptr: exclusive prefix sum of the kv lengths.
            vector_sparse_indptr_host = torch.cat(
                [
                    torch.tensor(
                        [0], dtype=torch.int32, device=kv_lens_arr_host.device
                    ),
                    torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32),
                ],
                dim=0,
            )
            self._vector_sparse_indptr_buffer[
                : len(vector_sparse_indptr_host)
            ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
            paged_kv_indptr_host = vector_sparse_indptr_host

    self._block_tables = block_tables
    if self._backend == "trtllm-gen":
        assert self._kv_layout == "HND"
        assert logits_soft_cap == 0.0
        if self._block_tables is None:
            # Build a dense [batch_size, max_blocks] table from the flat page list.
            blocks_per_seq = [
                (seq_len + page_size - 1) // page_size
                for seq_len in kv_lens_arr_host
            ]
            max_num_blocks_per_seq = max(blocks_per_seq)
            self._block_tables = torch.zeros(
                (batch_size, max_num_blocks_per_seq),
                dtype=torch.int,
                device=self.device,
            )
            block_id = paged_kv_indptr_host[0]
            for i in range(batch_size):
                num_blocks_needed = blocks_per_seq[i]
                self._block_tables[i, :num_blocks_needed] = paged_kv_indices[
                    block_id : block_id + num_blocks_needed
                ]
                block_id += num_blocks_needed

    if self._cached_module is not None:
        self._plan_info = self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_host,
            paged_kv_indptr_host,
            kv_lens_arr_host,
            self._max_total_num_rows or total_num_rows,
            batch_size,
            num_qo_heads,
            num_kv_heads,
            page_size,
            self.is_cuda_graph_enabled,
            head_dim_qk,
            head_dim_vo,
            causal,
        )

    self._causal = causal
    self._pos_encoding_mode = pos_encoding_mode
    self._use_fp16_qk_reduction = use_fp16_qk_reduction
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta
    self._seq_lens_kv = seq_lens
    self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens
# Older name for :meth:`plan`; both names refer to the same function object.
begin_forward = plan
def forward(
    self,
    q: torch.Tensor,
    paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    causal: bool = False,
    pos_encoding_mode: str = "NONE",
    use_fp16_qk_reduction: bool = False,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
) -> torch.Tensor:
    r"""Warning: This function is deprecated, please use :meth:`run` instead."""
    # Every option given here replaces the value cached by plan() before the
    # call is delegated to run(); only k_scale/v_scale are forwarded directly.
    cached_overrides = {
        "_causal": causal,
        "_pos_encoding_mode": pos_encoding_mode,
        "_use_fp16_qk_reduction": use_fp16_qk_reduction,
        "_window_left": window_left,
        "_logits_soft_cap": logits_soft_cap,
        "_sm_scale": sm_scale,
        "_rope_scale": rope_scale,
        "_rope_theta": rope_theta,
    }
    for field_name, field_value in cached_overrides.items():
        setattr(self, field_name, field_value)
    return self.run(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale)
@overload
def run(
    self,
    q: torch.Tensor,
    paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    *args,
    q_scale: Optional[float] = None,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    return_lse: Literal[False] = False,
    enable_pdl: Optional[bool] = None,
    window_left: Optional[int] = None,
    sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor: ...

@overload
def run(
    self,
    q: torch.Tensor,
    paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    *args,
    q_scale: Optional[float] = None,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    return_lse: Literal[True] = True,
    enable_pdl: Optional[bool] = None,
    window_left: Optional[int] = None,
    sinks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def run(
    self,
    q: torch.Tensor,
    paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    *args,
    q_scale: Optional[float] = None,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    return_lse: bool = False,
    enable_pdl: Optional[bool] = None,
    window_left: Optional[int] = None,
    sinks: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Compute batch prefill/append attention between query and paged kv-cache.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``
    paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        The paged KV-Cache stored as a tuple of tensors or a single tensor:

        * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape:
          ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``,
          and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``.

        * a single 5-D tensor with shape:
          ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
          :attr:`kv_layout` is ``NHD``, and
          ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
          :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
          ``paged_kv_cache[:, 1]`` is the value-cache.

    *args
        Additional arguments for custom kernels.
    q_scale : Optional[float]
        The calibration scale of query for fp8 input. When provided, it is
        multiplied into the softmax scale.
    k_scale : Optional[float]
        The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
        When provided, it is multiplied into the softmax scale.
    v_scale : Optional[float]
        The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
        When provided, the output is multiplied by it after the kernel runs.
    out : Optional[torch.Tensor]
        The output tensor, if not provided, will be allocated internally.
    lse : Optional[torch.Tensor]
        The log-sum-exp of attention logits, if not provided, will be allocated internally.
    return_lse : bool
        Whether to return the logsumexp of attention output
    enable_pdl : Optional[bool]
        Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
        Only supported for >= sm90, and currently only for FA2 and CUDA core decode.
        If ``None``, it is decided by :func:`device_support_pdl` for ``q.device``.
    window_left : Optional[int]
        Left window size for this call. Defaults to the value given to :meth:`plan`.
        All backends except ``trtllm-gen`` require it to equal the planned value.
    sinks : Optional[torch.Tensor]
        Forwarded unchanged to the attention module's ``paged_run``
        (default implementation only). Presumably per-head attention-sink
        values — confirm against the kernel.

    Returns
    -------
    Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
        If :attr:`return_lse` is ``True``, a tuple of two tensors:

        * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
        * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``.
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(q.device)
    k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
    _check_cached_qkv_data_type(
        q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
    )
    stride_block = k_cache.stride(0)
    if self._kv_layout == "NHD":
        page_size = k_cache.shape[1]
        stride_n = k_cache.stride(1)
    else:
        page_size = k_cache.shape[2]
        stride_n = k_cache.stride(2)
    window_left = self._window_left if window_left is None else window_left
    if self._backend != "trtllm-gen":
        # NOTE(Siyuan): since window_left is appeared in the plan function, we need to make sure it is the same as the one in the plan function.
        # Remove this check if the backend supports dynamic window_left.
        assert window_left == self._window_left
    logits_soft_cap = self._logits_soft_cap
    sm_scale = self._sm_scale
    rope_scale = self._rope_scale
    rope_theta = self._rope_theta
    if logits_soft_cap is None:
        logits_soft_cap = 0.0
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.size(-1))
    # Fold the fp8 query/key calibration scales into the softmax scale.
    if q_scale is not None:
        sm_scale *= q_scale
    if k_scale is not None:
        sm_scale *= k_scale
    if rope_scale is None:
        rope_scale = 1.0
    if rope_theta is None:
        rope_theta = 1e4
    if return_lse:
        if lse is None:
            lse = torch.empty(
                (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
            )
        else:
            check_shape_dtype_device(
                lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
            )

    if out is None:
        out = torch.empty(
            q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device
        )
    else:
        check_shape_dtype_device(
            out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out"
        )

    if self._custom_mask_buf is not None:
        mask_mode = MaskMode.CUSTOM.value
    else:
        if self._causal:
            mask_mode = MaskMode.CAUSAL.value
        else:
            mask_mode = MaskMode.NON_CAUSAL.value

    if self._prefix_len_ptr is not None:
        mask_mode = MaskMode.MULTIITEMSCORING.value

    if self._backend == "fa3":
        # NOTE(Zihao): we divide both stride_block and stride_n by stride_n
        # because we will multiply stride_n back in the kernel
        sparse_indices = block_sparse_indices_to_vector_sparse_offsets(
            self._paged_kv_indices_buf,
            self._paged_kv_indptr_buf,
            self._vector_sparse_indices_buffer,  # output
            self._vector_sparse_indptr_buffer,
            self._kv_lens_buffer,
            stride_block // stride_n,
            1,  # stride_n // stride_n
            page_size,
        )
        sparse_indptr = self._vector_sparse_indptr_buffer
    else:
        sparse_indices = self._paged_kv_indices_buf
        sparse_indptr = self._paged_kv_indptr_buf

    if self._backend == "cudnn":
        if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
            self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)

        if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
            self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1)

        cudnn_batch_prefill_with_kv_cache(
            q,
            k_cache,  # Need to be changed
            v_cache,  # Need to be changed
            # Use the resolved scale (default 1/sqrt(head_dim), q/k scales
            # folded in) so cuDNN matches the other backends.
            sm_scale,
            self._float_workspace_buffer,
            actual_seq_lens_q=self._seq_lens_q,
            actual_seq_lens_kv=self._seq_lens_kv,
            max_token_per_sequence=self._max_q_len,
            max_sequence_kv=self._max_kv_len,
            block_tables=self._block_tables,
            causal=self._causal,
            return_lse=return_lse,
            batch_offsets_q=self._qo_indptr_buf,
            batch_offsets_o=self._qo_indptr_buf,
            out=out,
            lse=lse,
        )
    else:
        if self._backend != "trtllm-gen":
            assert self._plan_info is not None, "plan info is not initialized"
        run_args = [
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._plan_info,
            q,
            k_cache,
            v_cache,
            self._qo_indptr_buf,
            sparse_indptr,
            sparse_indices,
            self._paged_kv_last_page_len_buf,
            out,
            lse,
            mask_mode,
            TensorLayout[self._kv_layout].value,
            window_left,
            enable_pdl,
        ]
        if self._jit_module is not None:
            run_args.extend(list(args))
        else:
            run_args += [
                self._custom_mask_buf,
                self._mask_indptr_buf,
                _get_cache_alibi_slopes_buf(q.shape[1], q.device),
                self._prefix_len_ptr,
                self._token_pos_in_items_ptr,
                self._max_item_len_ptr,
                logits_soft_cap,
                sm_scale,
                None,  # scale_q, not supported yet
                None,  # scale_k
                None,  # scale_v
                rope_scale,
                rope_theta,
                self._token_pos_in_items_len,
                self._workspace_size,
                self._num_qo_heads,
                self._num_kv_heads,
                self._block_tables,
                self._kv_lens_buffer,
                page_size,
                self._max_q_len,
                self._max_kv_len,
                self._batch_size,
                self._qo_indptr_buf,
                self._vector_sparse_indptr_buffer,
                sinks,
            ]

        assert self._cached_module is not None, "cached module is not initialized"
        self._cached_module.paged_run(*run_args)
    if v_scale is not None:
        # TODO(Zihao): fused into kernel
        if is_float8(out):
            # Scale in fp32 and write back in place so a caller-provided
            # ``out`` buffer also receives the scaled result.
            out.copy_((out.to(torch.float32) * v_scale).to(out.dtype))
        else:
            out *= v_scale
    return (out, lse) if return_lse else out
# :meth:`run` with ``return_lse`` bound to ``True``, so it returns ``(out, lse)``.
run_return_lse = functools.partialmethod(run, return_lse=True)
def forward_return_lse(
    self,
    q: torch.Tensor,
    paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    causal: bool = False,
    pos_encoding_mode: str = "NONE",
    use_fp16_qk_reduction: bool = False,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Warning: This function is deprecated, please use :meth:`run_return_lse` instead."""
    # Overwrite the options cached by plan() with the arguments of this call,
    # then delegate to run_return_lse() with only the k/v scales forwarded.
    cached_overrides = {
        "_causal": causal,
        "_pos_encoding_mode": pos_encoding_mode,
        "_use_fp16_qk_reduction": use_fp16_qk_reduction,
        "_window_left": window_left,
        "_logits_soft_cap": logits_soft_cap,
        "_sm_scale": sm_scale,
        "_rope_scale": rope_scale,
        "_rope_theta": rope_theta,
    }
    for field_name, field_value in cached_overrides.items():
        setattr(self, field_name, field_value)
    return self.run_return_lse(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale)
def end_forward(self) -> None:
    r"""Warning: this function is deprecated and has no effect."""
    return None
def _compute_mask_indptr(
qo_indptr: torch.Tensor, kv_indptr: torch.Tensor
) -> torch.Tensor:
if len(qo_indptr) != len(kv_indptr):
raise ValueError("The length of qo_indptr and kv_indptr should be the same.")
mask_indptr = torch.empty_like(qo_indptr)
mask_indptr[0] = 0
mask_indptr[1:] = torch.cumsum(
(qo_indptr[1:] - qo_indptr[:-1]) * (kv_indptr[1:] - kv_indptr[:-1]),
0,
)
return mask_indptr
class BatchPrefillWithRaggedKVCacheWrapper:
    r"""Wrapper class for prefill/append attention with ragged (tensor) kv-cache for
    batch of requests.

    Check :ref:`our tutorial <kv-layout>` for ragged kv-cache layout.

    Example
    -------
    >>> import torch
    >>> import flashinfer
    >>> num_layers = 32
    >>> num_qo_heads = 64
    >>> num_kv_heads = 16
    >>> head_dim = 128
    >>> # allocate 128MB workspace buffer
    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    >>> prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    ...     workspace_buffer, "NHD"
    ... )
    >>> batch_size = 7
    >>> nnz_kv = 100
    >>> nnz_qo = 100
    >>> qo_indptr = torch.tensor(
    ...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
    ... )
    >>> kv_indptr = qo_indptr.clone()
    >>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
    >>> k_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
    >>> v_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
    >>> # create auxiliary data structures for batch prefill attention
    >>> prefill_wrapper.plan(
    ...     qo_indptr,
    ...     kv_indptr,
    ...     num_qo_heads,
    ...     num_kv_heads,
    ...     head_dim,
    ...     causal=True,
    ... )
    >>> outputs = []
    >>> for i in range(num_layers):
    ...     q = q_at_layer[i]
    ...     k = k_at_layer[i]
    ...     v = v_at_layer[i]
    ...     # compute batch prefill attention, reuse auxiliary data structures
    ...     o = prefill_wrapper.run(q, k, v)
    ...     outputs.append(o)
    ...
    >>> outputs[0].shape
    torch.Size([100, 64, 128])
    >>>
    >>> # below is another example of creating custom mask for batch prefill attention
    >>> mask_arr = []
    >>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
    >>> kv_len = (kv_indptr[1:] - kv_indptr[:-1]).cpu().tolist()
    >>> for i in range(batch_size):
    ...     mask_i = torch.tril(
    ...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
    ...         diagonal=(kv_len[i] - qo_len[i]),
    ...     )
    ...     mask_arr.append(mask_i.flatten())
    ...
    >>> mask = torch.cat(mask_arr, dim=0)
    >>> prefill_wrapper.plan(
    ...     qo_indptr,
    ...     kv_indptr,
    ...     num_qo_heads,
    ...     num_kv_heads,
    ...     head_dim,
    ...     custom_mask=mask
    ... )
    >>> outputs_custom_mask = []
    >>> for i in range(num_layers):
    ...     q = q_at_layer[i]
    ...     k = k_at_layer[i]
    ...     v = v_at_layer[i]
    ...     # compute batch prefill attention, reuse auxiliary data structures
    ...     o_custom = prefill_wrapper.run(q, k, v)
    ...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
    ...     outputs_custom_mask.append(o_custom)
    ...
    >>> outputs_custom_mask[0].shape
    torch.Size([100, 64, 128])

    Note
    ----
    To accelerate computation, FlashInfer's batch prefill/append attention operators
    create some auxiliary data structures, these data structures can be reused across
    multiple prefill/append attention calls (e.g. different Transformer layers). This
    wrapper class manages the lifecycle of these data structures.
    """

    def __init__(
        self,
        float_workspace_buffer: torch.Tensor,
        kv_layout: str = "NHD",
        use_cuda_graph: bool = False,
        qo_indptr_buf: Optional[torch.Tensor] = None,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        custom_mask_buf: Optional[torch.Tensor] = None,
        mask_indptr_buf: Optional[torch.Tensor] = None,
        backend: str = "auto",
        jit_args: Optional[List[Any]] = None,
        jit_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        r"""Constructor of :class:`BatchPrefillWithRaggedKVCacheWrapper`.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The user reserved float workspace buffer used to store intermediate attention results
            in the split-k algorithm. The recommended size is 128MB, the device of the workspace
            buffer should be the same as the device of the input tensors.

        kv_layout : str
            The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.

        use_cuda_graph : bool
            Whether to enable CUDA graph capture for the prefill kernels, if enabled, the
            auxiliary data structures will be stored as the provided buffers.

        qo_indptr_buf : Optional[torch.Tensor]
            The user reserved GPU buffer to store the ``qo_indptr`` array, the size of the buffer
            should be ``[batch_size + 1]``.
            This argument is only effective when ``use_cuda_graph`` is ``True``.

        kv_indptr_buf : Optional[torch.Tensor]
            The user reserved GPU buffer to store the ``kv_indptr`` array, the size of the buffer
            should be ``[batch_size + 1]``.
            This argument is only effective when ``use_cuda_graph`` is ``True``.

        custom_mask_buf : Optional[torch.Tensor]
            The user reserved GPU buffer to store the custom mask tensor, should be large
            enough to store the maximum possible size of the packed custom mask tensor during the
            lifetime of the wrapper. This argument is only effective when ``use_cuda_graph``
            is ``True`` and custom mask will be used in attention computation.

        mask_indptr_buf : Optional[torch.Tensor]
            The user reserved GPU buffer to store the ``mask_indptr`` array, the size of the buffer
            should be ``[batch_size + 1]`` (one entry more than the number of requests).
            This argument is only effective when ``use_cuda_graph`` is ``True`` and custom mask
            will be used in attention computation.

        backend : str
            The implementation backend, could be ``auto``/``fa2``/``fa3`` or ``trtllm-gen``.
            Defaults to ``auto``.
            If set to ``auto``, the wrapper will automatically choose the backend based on the
            device architecture and kernel availability.

        jit_args : Optional[List[Any]]
            If provided, the wrapper will use the provided arguments to create the JIT module,
            otherwise, the wrapper will use default attention implementation.

        jit_kwargs : Optional[Dict[str, Any]]
            The keyword arguments to create the JIT module, defaults to None.

        Raises
        ------
        ValueError
            In CUDA graph mode, if ``qo_indptr_buf``/``kv_indptr_buf`` are not tensors or
            their lengths differ.
        """
        _check_kv_layout(kv_layout)

        if jit_args is not None:
            if jit_kwargs is None:
                jit_kwargs = {}
            self._jit_module = get_batch_prefill_jit_module(
                jit_args[0],
                get_customize_batch_prefill_module(backend, *jit_args, **jit_kwargs),
            )
        else:
            self._jit_module = None

        self._kv_layout = kv_layout
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        # 8MB int workspace on the device plus a pinned host buffer of the same
        # shape; both are handed to the planner in plan().
        self._int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=torch.uint8,
            pin_memory=True,
            device="cpu",
        )
        self._use_cuda_graph = use_cuda_graph
        if use_cuda_graph:
            if not torch.is_tensor(qo_indptr_buf):
                raise ValueError(
                    "qo_indptr_buf should be a torch.Tensor in cuda graph mode"
                )
            if not torch.is_tensor(kv_indptr_buf):
                raise ValueError(
                    "kv_indptr_buf should be a torch.Tensor in cuda graph mode"
                )
            self._fixed_batch_size = len(qo_indptr_buf) - 1
            if len(kv_indptr_buf) != self._fixed_batch_size + 1:
                # Report the actual lengths of both buffers (not the batch size).
                raise ValueError(
                    "The length of kv_indptr_buf ({}) should be the same as qo_indptr_buf ({}).".format(
                        len(kv_indptr_buf), len(qo_indptr_buf)
                    )
                )
            # NOTE(Zihao): do not check custom_mask_buf and mask_indptr_buf here,
            # as they may not be used.

        self._qo_indptr_buf = qo_indptr_buf
        self._kv_indptr_buf = kv_indptr_buf
        self._custom_mask_buf = custom_mask_buf
        self._mask_indptr_buf = mask_indptr_buf
        self._max_total_num_rows = None
        self._backend = backend
        self._cached_module = None

    @property
    def is_cuda_graph_enabled(self) -> bool:
        """Whether the wrapper was constructed with ``use_cuda_graph=True``."""
        return self._use_cuda_graph

    def reset_workspace_buffer(
        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
    ) -> None:
        r"""Reset the workspace buffer.

        Parameters
        ----------
        float_workspace_buffer : torch.Tensor
            The new float workspace buffer, the device of the new float workspace buffer should
            be the same as the device of the input tensors.

        int_workspace_buffer : torch.Tensor
            The new int workspace buffer, the device of the new int workspace buffer should
            be the same as the device of the input tensors.
        """
        self._float_workspace_buffer = float_workspace_buffer
        self._int_workspace_buffer = int_workspace_buffer
        # Re-allocate the pinned host mirror so it matches the new int buffer.
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=self._int_workspace_buffer.dtype,
            device="cpu",
            pin_memory=True,
        )

    def plan(
        self,
        qo_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim_qk: int,
        head_dim_vo: Optional[int] = None,
        custom_mask: Optional[torch.Tensor] = None,
        packed_custom_mask: Optional[torch.Tensor] = None,
        causal: bool = False,
        pos_encoding_mode: str = "NONE",
        use_fp16_qk_reduction: bool = False,
        window_left: int = -1,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
        q_data_type: Union[str, torch.dtype] = "float16",
        kv_data_type: Optional[Union[str, torch.dtype]] = None,
        non_blocking: bool = True,
        prefix_len_ptr: Optional[torch.Tensor] = None,
        token_pos_in_items_ptr: Optional[torch.Tensor] = None,
        token_pos_in_items_len: int = 0,
        max_item_len_ptr: Optional[torch.Tensor] = None,
    ) -> None:
        r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification.

        Parameters
        ----------
        qo_indptr : torch.Tensor
            The indptr of the query/output tensor, shape: ``[batch_size + 1]``.
        kv_indptr : torch.Tensor
            The indptr of the key/value tensor, shape: ``[batch_size + 1]``.
        num_qo_heads : int
            The number of query/output heads.
        num_kv_heads : int
            The number of key/value heads.
        head_dim_qk : int
            The dimension of the heads on query/key tensor.
        head_dim_vo : Optional[int]
            The dimension of the heads on value/output tensor.
            If not provided, will be set to ``head_dim_qk``.
        custom_mask : Optional[torch.Tensor]
            The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``.
            The elements in the mask tensor should be either ``True`` or ``False``,
            where ``False`` means the corresponding element in the attention matrix will be
            masked out.

            Please refer to the :ref:`mask layout <mask-layout>` for more details about flattened
            layout of mask tensor.

            When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the
            function will pack the custom mask tensor into a 1D packed mask tensor, which introduces
            additional overhead.
        packed_custom_mask : Optional[torch.Tensor]
            The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored.
            The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`.

            If provided, the custom mask will be added to the attention matrix before softmax
            and after scaling. The mask tensor should be in the same device as the input tensors.
        causal : bool
            Whether to apply causal mask to the attention matrix.
            This argument is ignored if a custom mask is provided in :meth:`plan`.
        pos_encoding_mode : str
            The position encoding applied inside attention kernels, could be
            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
            Default is ``NONE``.
        use_fp16_qk_reduction : bool
            Whether to use f16 for qk reduction (faster at the cost of slight precision
            loss).
        window_left : int
            The left (inclusive) window size for the attention window, when set to ``-1``, the window
            size will be set to the full length of the sequence. Defaults to ``-1``.
        logits_soft_cap : Optional[float]
            The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not
            provided, will be set to ``0``. If greater than 0, the logits will be capped according to
            formula:
            :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
            where :math:`x` is the input logits.
        sm_scale : Optional[float]
            The scale used in softmax, if not provided, will be set to
            ``1.0 / sqrt(head_dim_qk)``.
        rope_scale : Optional[float]
            The scale used in RoPE interpolation, if not provided, will be set to
            ``1.0``.
        rope_theta : Optional[float]
            The theta used in RoPE, if not provided, will be set to ``1e4``.
        q_data_type : Union[str, torch.dtype]
            The data type of the query tensor, defaults to torch.float16.
        kv_data_type : Optional[Union[str, torch.dtype]]
            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
        non_blocking : bool
            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
        prefix_len_ptr : Optional[torch.Tensor]
            prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size.
        token_pos_in_items_ptr : Optional[torch.Tensor]
            A uint16 1D tensor (it will be converted to uint16 in flashinfer) indicating the token position of each item and started from 0 (delimiter)
            for each item. E.g., if we have 3 items of length 3, 2, 4 respectively for this member. This vector will be looking like
            `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]` with 4 delimiters indexed as 0. For batch size > 1,
            we will concat them as 1D with zero paddings to make sure each has the same length, the padding length is defined by
            `token_pos_in_items_len` - length of the raw `token_pos_in_items_ptr` for each prompt.
        token_pos_in_items_len : int
            zero padding length for `token_pos_in_items_ptr` to better handle the bsz > 1 case. Still using the above 3,2,4 example.
            If we set `token_pos_in_items_len` to be 20, it will be `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]`
            with 7 padded zeros. (note there're 8 zeros in the end where the first one is the delimiter token 0 in the end of the prompt)
        max_item_len_ptr : Optional[torch.Tensor]
            a uint16 vector contains the max token length of all items for each prompt

        Raises
        ------
        ValueError
            If ``kv_indptr`` and ``qo_indptr`` differ in length, or if a CUDA-graph
            constraint (fixed batch size, max total rows, preallocated mask buffers)
            is violated.

        Note
        ----
        The :meth:`plan` method should be called before any :meth:`run` or
        :meth:`run_return_lse` calls, auxiliary data structures will be created
        during this plan call and cached for multiple kernel runs.

        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
        is not equal to ``num_kv_heads``, the function will use
        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.

        The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
        """
        q_data_type = canonicalize_torch_dtype(q_data_type)
        if kv_data_type is None:
            kv_data_type = q_data_type
        kv_data_type = canonicalize_torch_dtype(kv_data_type)
        if head_dim_vo is None:
            head_dim_vo = head_dim_qk

        if logits_soft_cap is None:
            logits_soft_cap = 0.0

        batch_size = len(qo_indptr) - 1
        if len(kv_indptr) != batch_size + 1:
            raise ValueError(
                "The kv_indptr length should be equal to qo_indptr length."
            )
        if custom_mask is not None or packed_custom_mask is not None:
            mask_indptr = _compute_mask_indptr(qo_indptr, kv_indptr)
        if packed_custom_mask is None and custom_mask is not None:
            # create packed custom mask from custom mask
            packed_custom_mask, mask_indptr = segment_packbits(
                custom_mask.contiguous().view(-1),
                mask_indptr,
                bitorder="little",
            )

        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
        qo_indptr_host = qo_indptr.to("cpu")
        kv_indptr_host = kv_indptr.to("cpu")

        total_num_rows = qo_indptr_host[-1]

        if self.is_cuda_graph_enabled:
            # The first plan() call fixes the row budget for later calls.
            if self._max_total_num_rows is None:
                self._max_total_num_rows = total_num_rows
            elif total_num_rows > self._max_total_num_rows:
                raise ValueError(
                    "The total number of rows in qo_indptr {} in cuda graph mode cannot "
                    "exceed the number of rows set during initialization {}.".format(
                        total_num_rows, self._max_total_num_rows
                    )
                )

            if batch_size != self._fixed_batch_size:
                raise ValueError(
                    "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
                    " mismatches the batch size set during initialization {}.".format(
                        batch_size, self._fixed_batch_size
                    )
                )
            # Copy into the preallocated buffers so captured graphs see the new data.
            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking)
            self._kv_indptr_buf.copy_(kv_indptr, non_blocking=non_blocking)
            if packed_custom_mask is not None:
                if not torch.is_tensor(self._custom_mask_buf):
                    raise ValueError(
                        "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                    )
                if not torch.is_tensor(self._mask_indptr_buf):
                    raise ValueError(
                        "mask_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in the attention computation."
                    )
                self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask
                self._mask_indptr_buf.copy_(mask_indptr, non_blocking=non_blocking)
        else:
            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking)
            self._kv_indptr_buf = kv_indptr.to(self.device, non_blocking=non_blocking)
            if packed_custom_mask is not None:
                self._custom_mask_buf = packed_custom_mask.to(
                    self.device, non_blocking=non_blocking
                )
                self._mask_indptr_buf = mask_indptr.to(
                    self.device, non_blocking=non_blocking
                )
            else:
                # Drop any mask left over from an earlier plan() call (or a
                # custom_mask_buf passed to __init__); otherwise run() would keep
                # choosing MaskMode.CUSTOM with a stale mask.
                self._custom_mask_buf = None
                self._mask_indptr_buf = None

        self._cached_q_data_type = q_data_type
        self._cached_kv_data_type = kv_data_type
        kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1]

        self._prefix_len_ptr = prefix_len_ptr
        self._token_pos_in_items_ptr = token_pos_in_items_ptr
        self._token_pos_in_items_len = token_pos_in_items_len
        self._max_item_len_ptr = max_item_len_ptr

        if self._jit_module is not None:
            self._cached_module = self._jit_module
        else:
            if self._backend == "auto":
                self._backend = determine_attention_backend(
                    self.device,
                    PosEncodingMode[pos_encoding_mode].value,
                    use_fp16_qk_reduction,
                    self._custom_mask_buf is not None,  # use_custom_mask
                    q_data_type,
                    kv_data_type,
                )

            get_module_args = (
                q_data_type,
                kv_data_type,
                q_data_type,
                kv_indptr.dtype,
                head_dim_qk,
                head_dim_vo,
                PosEncodingMode[pos_encoding_mode].value,
                window_left >= 0,  # use_sliding_window
                logits_soft_cap > 0,  # use_logits_soft_cap
                use_fp16_qk_reduction,
            )
            if self._backend == "cutlass":
                # insert qo_indptr.device to 9th position (0-indexed) of get_module_args
                new_get_module_args = (
                    get_module_args[:9] + (qo_indptr.device,) + get_module_args[9:]
                )
                self._cached_module = get_fmha_module(*new_get_module_args)
            else:
                self._cached_module = get_batch_prefill_module(
                    self._backend, *get_module_args
                )

        if self._backend == "cutlass":
            self._plan_info = fmha_varlen_plan(
                self._cached_module, qo_indptr, kv_indptr, num_qo_heads, causal
            )
            # Use the host copy to avoid an extra device synchronization.
            self._max_qo_len = torch.max(
                qo_indptr_host[1:] - qo_indptr_host[:-1]
            ).item()
        else:
            assert self._cached_module is not None, "cached module is not initialized"
            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._pin_memory_int_workspace_buffer,
                qo_indptr_host,
                kv_indptr_host,
                kv_len_arr,
                self._max_total_num_rows or total_num_rows,
                batch_size,
                num_qo_heads,
                num_kv_heads,
                1,  # page_size
                self.is_cuda_graph_enabled,
                head_dim_qk,
                head_dim_vo,
                causal,
            )

        self._causal = causal
        self._pos_encoding_mode = pos_encoding_mode
        self._use_fp16_qk_reduction = use_fp16_qk_reduction
        self._window_left = window_left
        self._logits_soft_cap = logits_soft_cap
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta

    begin_forward = plan

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        causal: bool = False,
        pos_encoding_mode: str = "NONE",
        use_fp16_qk_reduction: bool = False,
        window_left: int = -1,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
    ) -> torch.Tensor:
        r"""Warning: This function is deprecated, please use :meth:`run` instead."""
        self._causal = causal
        self._pos_encoding_mode = pos_encoding_mode
        self._use_fp16_qk_reduction = use_fp16_qk_reduction
        self._window_left = window_left
        self._logits_soft_cap = logits_soft_cap
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta
        return self.run(q, k, v)

    @overload
    def run(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *args,
        out: Optional[torch.Tensor] = None,
        lse: Optional[torch.Tensor] = None,
        return_lse: Literal[False] = False,
        enable_pdl: Optional[bool] = None,
    ) -> torch.Tensor: ...

    @overload
    def run(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *args,
        out: Optional[torch.Tensor] = None,
        lse: Optional[torch.Tensor] = None,
        return_lse: Literal[True] = True,
        enable_pdl: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]: ...

    def run(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *args,
        out: Optional[torch.Tensor] = None,
        lse: Optional[torch.Tensor] = None,
        return_lse: bool = False,
        enable_pdl: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        r"""Compute batch prefill/append attention between query and kv-cache stored as
        ragged tensor.

        Parameters
        ----------
        q : torch.Tensor
            The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_qk]``
        k : torch.Tensor
            The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim_qk]``
        v : torch.Tensor
            The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim_vo]``
        *args
            Additional arguments for the custom kernel.
        out : Optional[torch.Tensor]
            The output tensor, if not provided, will be allocated internally.
        lse : Optional[torch.Tensor]
            The log-sum-exp of attention logits, if not provided, will be allocated internally.
        return_lse : bool
            Whether to return the logsumexp of attention output
        enable_pdl : Optional[bool]
            Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
            Only supported for >= sm90, and currently only for FA2 and CUDA core decode.
            If ``None``, it is decided by ``device_support_pdl(q.device)``.

        Returns
        -------
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_vo]``.
            If :attr:`return_lse` is ``True``, a tuple of two tensors:

            * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_vo]``.
            * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``.
        """
        if enable_pdl is None:
            enable_pdl = device_support_pdl(q.device)
        _check_cached_qkv_data_type(
            q, k, self._cached_q_data_type, self._cached_kv_data_type
        )

        # Options recorded by plan() (or by the deprecated forward* helpers).
        window_left = self._window_left
        logits_soft_cap = self._logits_soft_cap
        sm_scale = self._sm_scale
        rope_scale = self._rope_scale
        rope_theta = self._rope_theta
        if logits_soft_cap is None:
            logits_soft_cap = 0.0
        if sm_scale is None:
            sm_scale = 1.0 / math.sqrt(q.size(-1))
        if rope_scale is None:
            rope_scale = 1.0
        if rope_theta is None:
            rope_theta = 1e4
        if return_lse:
            if lse is None:
                lse = torch.empty(
                    (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
                )
            else:
                check_shape_dtype_device(
                    lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
                )
        if out is None:
            out = torch.empty(
                q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device
            )
        else:
            check_shape_dtype_device(
                out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out"
            )
        if self._backend == "cutlass":
            out, lse = fmha_varlen(
                q,
                k,
                v,
                self._qo_indptr_buf,
                self._kv_indptr_buf,
                plan_info=self._plan_info,
                causal=self._causal,
                sm_scale=sm_scale,
                max_qo_len=self._max_qo_len,
                out=out,
                lse=lse,
            )
            return (out, lse) if return_lse else out

        if is_float8(q):
            logging.warning(
                "Our current prefill kernel implementation needs f16 input, the f8 inputs "
                " are casted to f16, which could result in performance degradation."
            )
            # NOTE(review): ``out`` was allocated above with the original (fp8)
            # q dtype, before this cast; confirm that is what the kernel expects.
            q = q.to(torch.float16)
            k = k.to(torch.float16)
            v = v.to(torch.float16)

        if self._custom_mask_buf is not None:
            mask_mode = MaskMode.CUSTOM.value
        else:
            if self._causal:
                mask_mode = MaskMode.CAUSAL.value
            else:
                mask_mode = MaskMode.NON_CAUSAL.value

        run_args = [
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._plan_info,
            q,
            k,
            v,
            self._qo_indptr_buf,
            self._kv_indptr_buf,
            out,
            lse,
            mask_mode,
            TensorLayout[self._kv_layout].value,
            window_left,
            enable_pdl,
        ]
        if self._jit_module is not None:
            run_args.extend(list(args))
        else:
            run_args += [
                self._custom_mask_buf,
                self._mask_indptr_buf,
                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
                self._prefix_len_ptr,
                self._token_pos_in_items_ptr,
                self._max_item_len_ptr,
                logits_soft_cap,
                sm_scale,
                rope_scale,
                rope_theta,
                self._token_pos_in_items_len,
            ]

        assert self._cached_module is not None, "cached module is not initialized"
        self._cached_module.ragged_run(*run_args)
        return (out, lse) if return_lse else out

    # ``run`` with ``return_lse`` pre-bound to ``True``.
    run_return_lse = functools.partialmethod(run, return_lse=True)

    def forward_return_lse(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        causal: bool = False,
        pos_encoding_mode: str = "NONE",
        use_fp16_qk_reduction: bool = False,
        window_left: int = -1,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Warning: This function is deprecated, please use :meth:`run_return_lse` instead."""
        self._causal = causal
        self._pos_encoding_mode = pos_encoding_mode
        self._use_fp16_qk_reduction = use_fp16_qk_reduction
        self._window_left = window_left
        self._logits_soft_cap = logits_soft_cap
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta
        return self.run_return_lse(q, k, v)

    def end_forward(self) -> None:
        r"""Warning: this function is deprecated and has no effect."""
        pass
def fmha_varlen_plan(
    module,
    qo_segment_offsets: torch.Tensor,
    kv_segment_offsets: torch.Tensor,
    num_qo_heads: int,
    causal: bool,
):
    """Allocate the cutlass varlen scheduling buffers and let ``module.plan`` fill them.

    One work slot is reserved per SM (plus one for the indptr end) on the device of
    ``qo_segment_offsets``; the per-tile index buffers have a fixed capacity of
    131072 int32 entries each. The query tile size passed to the planner is 256.

    Returns
    -------
    tuple
        ``(work_indptr, qo_tile_indices, head_indices, batch_indices)``, the
        ``plan_info`` expected by :func:`fmha_varlen`.
    """
    target_device = qo_segment_offsets.device
    sm_count = torch.cuda.get_device_properties(target_device).multi_processor_count
    tile_capacity = 131072
    qo_tile_size = 256

    work_indptr = torch.empty(sm_count + 1, device=target_device, dtype=torch.int32)
    # Allocated in order: qo tiles, heads, batches.
    qo_tile_indices, head_indices, batch_indices = (
        torch.empty(tile_capacity, device=target_device, dtype=torch.int32)
        for _ in range(3)
    )

    module.plan(
        qo_segment_offsets,
        kv_segment_offsets,
        work_indptr,
        qo_tile_indices,
        head_indices,
        batch_indices,
        qo_tile_size,
        num_qo_heads,
        sm_count,
        causal,
    )
    return work_indptr, qo_tile_indices, head_indices, batch_indices
@overload
def fmha_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qo_segment_offsets: torch.Tensor,
    kv_segment_offsets: torch.Tensor,
    plan_info: Optional[List[torch.Tensor]] = None,
    max_qo_len: Optional[int] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: Literal[False] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ...


@overload
def fmha_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qo_segment_offsets: torch.Tensor,
    kv_segment_offsets: torch.Tensor,
    plan_info: Optional[List[torch.Tensor]] = None,
    max_qo_len: Optional[int] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: Literal[True] = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ...


def fmha_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qo_segment_offsets: torch.Tensor,
    kv_segment_offsets: torch.Tensor,
    plan_info: Optional[List[torch.Tensor]] = None,
    max_qo_len: Optional[int] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Variable-length attention through the cutlass FMHA module.

    Parameters
    ----------
    q, k, v : torch.Tensor
        Ragged query/key/value tensors; they are unpacked as 3-D
        ``(tokens, heads, head_dim)`` tensors below.
    qo_segment_offsets, kv_segment_offsets : torch.Tensor
        Per-request start offsets (indptr) into ``q`` and ``k``/``v``.
    plan_info : Optional
        Result of :func:`fmha_varlen_plan`; computed here when not given.
    max_qo_len : Optional[int]
        Longest query segment; computed from ``qo_segment_offsets`` when not given.
    out, lse : Optional[torch.Tensor]
        Optional preallocated output / log-sum-exp tensors. ``lse`` is only
        allocated here when ``return_lse`` is true.
    causal : bool
        Use the causal mask (mask mode ``1``) instead of no mask (``0``).
    sm_scale : Optional[float]
        Softmax scale, defaults to ``1 / sqrt(head_dim_qk)``.
    return_lse : bool
        Whether to allocate ``lse`` when it is not provided.

    Returns
    -------
    Tuple[torch.Tensor, Optional[torch.Tensor]]
        Always the tuple ``(out, lse)`` regardless of ``return_lse``; ``lse`` is
        ``None`` when it was neither provided nor requested. Callers such as
        ``BatchPrefillWithRaggedKVCacheWrapper.run`` rely on this tuple form.
    """
    workspace_buffer = _get_cache_buf(
        "fmha_varlen_cutlass_workspace", 32 * 1024 * 1024, q.device
    )
    module = get_fmha_module(
        q.dtype,
        k.dtype,
        v.dtype,
        torch.int32,
        q.shape[2],
        v.shape[2],
        PosEncodingMode.NONE.value,
        False,  # use_sliding_window
        False,  # use_logits_soft_cap
        q.device,
    )

    nnz_qo, num_qo_heads, head_dim_qk = q.shape
    _, num_kv_heads, head_dim_vo = v.shape

    mask_mode_code = 1 if causal else 0
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim_qk)

    qo_total_len = nnz_qo
    if max_qo_len is None:
        max_qo_len = torch.max(qo_segment_offsets[1:] - qo_segment_offsets[:-1]).item()

    if plan_info is None:
        plan_info = fmha_varlen_plan(
            module, qo_segment_offsets, kv_segment_offsets, num_qo_heads, causal
        )

    (
        work_indptr,
        qo_tile_indices,
        head_indices,
        batch_indices,
    ) = plan_info

    if out is None:
        # Allocate extra leading rows and return a view past them; the padding
        # (at least 128 rows) is kept in front of the visible output.
        front_pad = max(max_qo_len, 128)
        out = torch.empty(
            qo_total_len + front_pad,
            num_qo_heads,
            head_dim_vo,
            device=q.device,
            dtype=q.dtype,
        )[front_pad:]

    if lse is None and return_lse:
        lse = torch.empty(
            qo_total_len, num_qo_heads, device=q.device, dtype=torch.float32
        )

    module.run(
        workspace_buffer,
        q,
        k,
        v,
        qo_segment_offsets,
        kv_segment_offsets,
        work_indptr,
        qo_tile_indices,
        head_indices,
        batch_indices,
        out,
        lse,
        mask_mode_code,
        sm_scale,
        num_qo_heads,
        num_kv_heads,
        head_dim_qk,
        head_dim_vo,
        max_qo_len,
    )

    return out, lse
@functools.cache
def get_trtllm_gen_fmha_module():
    """Build and load the trtllm-gen FMHA JIT module, memoized by ``functools.cache``.

    The cubin loader is configured with the built library's path before the
    loaded module is returned.
    """
    jit_spec = gen_trtllm_gen_fmha_module()
    loaded_module = jit_spec.build_and_load()
    library_path = jit_spec.get_library_path()
    setup_cubin_loader(library_path)
    return loaded_module
def trtllm_ragged_attention_deepseek(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    workspace_buffer: torch.Tensor,
    seq_lens: torch.Tensor,
    max_q_len: int,
    max_kv_len: int,
    bmm1_scale: float,
    bmm2_scale: float,
    o_sf_scale: float,
    batch_size: int,
    window_left: int,
    cum_seq_lens_q: torch.Tensor,
    cum_seq_lens_kv: torch.Tensor,
    enable_pdl: Optional[bool],
    is_causal: bool,
    return_lse: bool,
    attention_sinks: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Ragged attention through the trtllm-gen FMHA kernel, restricted to DeepSeek-R1
    head dimensions (192 for query/key, 128 for value).

    Parameters
    ----------
    query : torch.Tensor
        query tensor with shape [num_tokens, num_heads, head_dim]
    key : torch.Tensor
        key tensor with shape [num_tokens, num_kv_heads, head_dim]
        (presumably num_kv_heads may differ from query's head count -- confirm)
    value : torch.Tensor
        value tensor with shape [num_tokens, num_kv_heads, head_dim_v]
    workspace_buffer : torch.Tensor
        workspace buffer; its full byte size is passed to the kernel
    seq_lens : torch.Tensor
        sequence lengths
    max_q_len : int
        max query length
    max_kv_len : int
        max key/value length
    bmm1_scale : float
        scale for bmm1, scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5)
    bmm2_scale : float
        scale for bmm2, scale_v
    o_sf_scale : float
        scale for output
    batch_size : int
        batch size
    window_left : int
        window left
    cum_seq_lens_q : torch.Tensor
        cumulative sequence lengths for query
    cum_seq_lens_kv : torch.Tensor
        cumulative sequence lengths for key/value
    enable_pdl : Optional[bool]
        enable pdl; if ``None``, decided by ``device_support_pdl(query.device)``
    is_causal : bool
        is causal
    return_lse : bool
        whether to also return the log-sum-exp tensor; when true and ``lse`` is
        not provided, a float32 ``lse`` is allocated
    attention_sinks : Optional[torch.Tensor]
        attention sinks
    out : Optional[torch.Tensor]
        output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
    lse : Optional[torch.Tensor]
        lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]]

    Returns
    -------
    out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        output torch.Tensor or Tuple[torch.Tensor, torch.Tensor].
        If return_lse is True, the output will be a tuple of two tensors, the first is the output tensor, the second is the lse tensor.
        If return_lse is False, the output will be a single tensor.
    """
    assert query.shape[2] == 192 and key.shape[2] == 192 and value.shape[2] == 128, (
        "currently only support deepseek r1 192 query and 128 value"
    )

    if enable_pdl is None:
        enable_pdl = device_support_pdl(query.device)

    run_func = get_trtllm_gen_fmha_module().trtllm_ragged_attention
    sm_count = get_device_sm_count(query.device)

    num_tokens, num_heads = query.shape[0], query.shape[1]
    if out is None:
        out = torch.empty(
            num_tokens,
            num_heads,
            value.shape[2],
            device=query.device,
            dtype=query.dtype,
        )
    if return_lse and lse is None:
        lse = torch.empty(
            num_tokens,
            num_heads,
            device=query.device,
            dtype=torch.float32,
        )

    # Size of the workspace in bytes, regardless of its element dtype.
    workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
    run_func(
        out,
        query,
        key,
        value,
        workspace_buffer,
        seq_lens,
        max_q_len,
        max_kv_len,
        bmm1_scale,
        bmm2_scale,
        o_sf_scale,
        batch_size,
        window_left,
        cum_seq_lens_q,
        cum_seq_lens_kv,
        sm_count,
        enable_pdl,
        is_causal,
        workspace_size,
        attention_sinks,
        lse,
    )
    return (out, lse) if return_lse else out
def trtllm_batch_context_with_kv_cache(
    query: torch.Tensor,
    kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    workspace_buffer: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    max_q_len: int,
    max_kv_len: int,
    bmm1_scale: float,
    bmm2_scale: float,
    batch_size: int,
    cum_seq_lens_q: torch.Tensor,
    cum_seq_lens_kv: torch.Tensor,
    window_left: int = -1,
    out: Optional[Union[torch.Tensor, FP4Tensor]] = None,
    out_dtype: Optional[Union[torch.dtype, str]] = None,
    o_sf_scale: Optional[float] = None,
    o_sf_vec_size: Optional[int] = None,
    enable_pdl: Optional[bool] = None,
    sinks: Optional[List[torch.Tensor]] = None,
) -> Union[torch.Tensor, FP4Tensor]:
    """Run batched context attention over a paged KV cache with the TRT-LLM-gen FMHA module.

    Parameters
    ----------
    query : torch.Tensor
        query tensor with shape [num_tokens, num_heads, head_dim]
    kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim]
        If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim]
    workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
        workspace
    block_tables : torch.Tensor
        page_table of kv cache, [batch_size, num_pages]
    seq_lens : torch.Tensor
        A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
    max_q_len : int
        max sequence length for query
    max_kv_len : int
        max sequence length for kv_cache
    bmm1_scale : float
        fused scale for bmm1 input.
    bmm2_scale : float
        fused scale for bmm2 input.
    batch_size : int
        batch size
    cum_seq_lens_q : torch.Tensor
        cumulative sequence length for query. shape: ``[batch_size + 1]``
    cum_seq_lens_kv : torch.Tensor
        cumulative sequence length for kv_cache. shape: ``[batch_size + 1]``
    window_left : int = -1
        The left (inclusive) window size for the attention window, when set to ``-1``, the window
        size will be set to the full length of the sequence. Defaults to ``-1``.
    out : Optional[Union[torch.Tensor, FP4Tensor]] = None
        output tensor, if not provided, will be allocated with ``out_dtype``, if ``out_dtype`` is not provided, will use the type of ``query``.
        For nvfp4 output it must be an ``FP4Tensor`` or ``None``.
    out_dtype : Optional[Union[torch.dtype, str]] = None
        output dtype, if not provided, will use the type of ``out``. For nvfp4, use string ``nvfp4``.
        Non-nvfp4 outputs must be ``query.dtype``, ``torch.float16`` or ``torch.bfloat16``.
        nvfp4 output requires ``query`` to be ``torch.float8_e4m3fn`` and ``o_sf_scale`` to be set.
    o_sf_scale : Optional[float] = None
        scale for nvfp4 output tensor scale factor. Required for nvfp4 output, must be ``None`` otherwise.
    o_sf_vec_size : Optional[int] = None
        vector size for nvfp4 output tensor scale factor. Only ``16`` (the default for nvfp4) is supported;
        must be ``None`` for non-nvfp4 output.
    enable_pdl : Optional[bool] = None
        Whether to enable PDL for the kernel launch. If ``None``, it is determined by
        ``device_support_pdl(query.device)``.
    sinks : Optional[List[torch.Tensor]] = None
        additional value per head in the denominator of the softmax.

    Returns
    -------
    out: Union[torch.Tensor, FP4Tensor]
        output torch.Tensor or FP4Tensor. For nvfp4 output an ``FP4Tensor`` wrapping the packed
        uint8 data and the float8_e4m3fn scale factors is returned; otherwise a tensor with the
        shape of ``query``.

    Raises
    ------
    ValueError
        If ``out_dtype`` is neither ``"nvfp4"``, a ``torch.dtype`` nor ``None``; if the requested
        non-nvfp4 dtype is unsupported; if nvfp4 output is requested with an ``out`` that is
        neither an ``FP4Tensor`` nor ``None``; or if the scale-factor start index does not fit
        inside the scale-factor tensor.
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(query.device)

    if isinstance(kv_cache, tuple):
        k_cache, v_cache = kv_cache
    else:
        if kv_cache.shape[1] == 1:
            k_cache, v_cache = kv_cache, kv_cache
        else:
            assert kv_cache.shape[1] == 2, (
                "When kv_cache is a single tensor, the second dimension must be 1 or 2"
            )
            # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
            # it doesn't change underlying storage
            k_cache, v_cache = kv_cache.unbind(dim=1)

    run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context
    sm_count = get_device_sm_count(query.device)

    if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)):
        assert query.dtype == torch.float8_e4m3fn, (
            "query must be fp8 when out_dtype is nvfp4."
        )
        assert o_sf_scale is not None
        assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported"
        o_sf_vec_size = o_sf_vec_size or 16
        # Two 4-bit values are packed into each uint8 element of the last dim.
        fp4_out_shape = query.shape[:-1] + (ceil_div(query.shape[-1], 2),)
        if isinstance(out, FP4Tensor):
            # Caller-provided scale buffer: keep its row count, check the column count.
            fp4_out_scale_shape = (
                out.scale.shape[0],
                round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4),
            )
            out_scale_factor = out.scale
            o_sf_start_index = out.scale_start_index
            out = out.data
            # out_dtype may be None
            out_dtype = out_dtype or "nvfp4"
        elif out is None:
            # One scale per o_sf_vec_size elements of a token row; rows padded to a
            # multiple of 128 and columns to a multiple of 4 (presumably the layout the
            # downstream fp4 consumer expects — confirm against the kernel).
            fp4_out_scale_shape = (
                round_up(query.shape[0], 128),
                round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4),
            )
            out_scale_factor = torch.empty(
                fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
            )
            o_sf_start_index = 0
            out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
        else:
            # Report the type only: formatting the object itself would dump (and, for a
            # CUDA tensor, copy to host) its whole contents.
            raise ValueError(
                f"Invalid out: expected an FP4Tensor or None for nvfp4 output, "
                f"got {type(out).__name__}"
            )

        assert out_dtype == "nvfp4"
        assert isinstance(out, torch.Tensor)

        # Use uint8 as the container dtype to compliant with next fp4 gemm.
        check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
        check_shape_dtype_device(
            out_scale_factor,
            fp4_out_scale_shape,
            torch.float8_e4m3fn,
            query.device,
            "out_scale_factor",
        )

        # Check o_sf_start_index is valid
        if (
            o_sf_start_index < 0
            or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0]
        ):
            raise ValueError(
                f"o_sf_start_index is out of the valid range of out_scale_factor. "
                f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, "
                f"out_scale_factor.shape[0]={out_scale_factor.shape[0]}"
            )

    elif isinstance(out_dtype, torch.dtype) or out_dtype is None:
        assert o_sf_scale is None
        assert o_sf_vec_size is None
        out_scale_factor = None
        o_sf_start_index = 0
        if out_dtype is None:
            out_dtype = out.dtype if out is not None else query.dtype
        # Validate the dtype before allocating so an unsupported request does not
        # first allocate a query-sized buffer.
        if out_dtype not in (query.dtype, torch.float16, torch.bfloat16):
            raise ValueError(f"Unsupported out_dtype: {out_dtype}")
        if out is None:
            out = torch.empty_like(query, dtype=out_dtype)
        check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out")
    else:
        raise ValueError(f"Invalid out_dtype: {out_dtype}")

    workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
    run_func(
        out,
        out_scale_factor,
        query,
        k_cache,
        v_cache,
        workspace_buffer,
        block_tables,
        seq_lens,
        max_q_len,
        max_kv_len,
        bmm1_scale,
        bmm2_scale,
        # -1 is passed as the "not provided" sentinel; compare against None explicitly so
        # a caller-supplied 0.0 / 0 is not mistaken for "unset".
        o_sf_scale if o_sf_scale is not None else -1.0,
        o_sf_vec_size if o_sf_vec_size is not None else -1,
        o_sf_start_index,
        batch_size,
        window_left,
        cum_seq_lens_q,
        cum_seq_lens_kv,
        sm_count,
        enable_pdl,
        workspace_size,
        sinks,
    )
    return (
        out
        if out_dtype != "nvfp4"
        else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
    )