sglang0.4.5.post1/python/sglang/srt/layers/attention/flashinfer_mla_backend.py

843 lines
30 KiB
Python

from __future__ import annotations
"""
Support attention backend for flashinfer MLA.
The flashinfer_mla_disable_ragged flag controls whether the ragged prefill wrapper is used and defaults to false.
When it's set to false, the backend uses the BatchRagged wrapper for prefilling without a cached prefix,
and the BatchMLAPaged wrapper for prefilling with a cached prefix and for decoding.
When it's set to true, all wrappers are BatchMLAPaged wrappers.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import triton
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
if is_flashinfer_available():
from flashinfer import (
BatchMLAPagedAttentionWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
@dataclass
class DecodeMetadata:
    """Per-batch metadata consumed by the decode path of the MLA backend."""

    # Wrapper whose run() is invoked by FlashInferMLAAttnBackend.forward_decode.
    decode_wrapper: BatchMLAPagedAttentionWrapper
@dataclass
class PrefillMetadata:
    """Per-batch metadata consumed by the extend/prefill path of the MLA backend."""

    # Paged MLA wrapper that forward_extend runs when use_ragged is False.
    prefill_wrapper: BatchMLAPagedAttentionWrapper
    # True when forward_extend should run the ragged prefill wrapper instead.
    use_ragged: bool
# Workspace buffer shared by every flashinfer wrapper created in this module.
# It starts as None and is allocated by the first FlashInferMLAAttnBackend
# that gets constructed.
global_workspace_buffer: Optional[torch.Tensor] = None
class FlashInferMLAAttnBackend(AttentionBackend):
    """Attention backend that drives the flashinfer MLA kernels.

    Prefill without a cached prefix may go through the ragged prefill wrapper
    (unless ``flashinfer_mla_disable_ragged`` is set); prefill with a prefix,
    decode, draft-extend and target-verify all go through a
    ``BatchMLAPagedAttentionWrapper``.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        """Allocate the shared buffers, wrappers and indices updaters.

        Args:
            model_runner: Supplies the model config, the device and the
                request-to-token pool.
            skip_prefill: When True, the prefill-only state (``qo_indptr``,
                the paged prefill/verify wrappers and the prefill indices
                updater) is not created.
            kv_indptr_buf: Optional pre-allocated ``kv_indptr`` buffer to use
                instead of allocating one here.
            q_indptr_decode_buf: Optional pre-allocated decode ``q_indptr``
                buffer to use instead of allocating one here.
        """
        super().__init__()

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.skip_prefill = skip_prefill

        global_config.enable_flashinfer_mla = True

        # Allocate buffers. The workspace is module-global so that all
        # backends (and all their wrappers) share a single allocation.
        global global_workspace_buffer
        if global_workspace_buffer is None:
            global_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )
        self.workspace_buffer = global_workspace_buffer

        max_bs = model_runner.req_to_token_pool.size
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

        # Decode processes one query token per request, so q_indptr is simply
        # 0, 1, 2, ..., max_bs.
        if q_indptr_decode_buf is None:
            self.q_indptr_decode = torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=model_runner.device
            )
        else:
            self.q_indptr_decode = q_indptr_decode_buf

        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD"
        )

        if not self.skip_prefill:
            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )

            # FlashinferMLA backend uses mla wrapper for target verify
            self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )

        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto"
        )

        # Create indices updaters. They read the buffers/wrappers created
        # above from ``self``, so they must be constructed last.
        if not skip_prefill:
            self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                model_runner, self
            )

        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
            model_runner, self
        )

        # Other metadata
        self.forward_metadata: Optional[Union[PrefillMetadata, DecodeMetadata]] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}  # For verify

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Plan the wrapper matching the batch's forward mode (eager path).

        Sets ``self.forward_metadata`` to a ``DecodeMetadata`` for decode/idle
        batches and to a ``PrefillMetadata`` for every other mode.
        """
        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                decode_wrapper=self.decode_wrapper,
                init_metadata_replay=False,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrapper)
        elif forward_batch.forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
        elif forward_batch.forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_verify,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
        else:
            prefix_lens = forward_batch.extend_prefix_lens
            # Ragged prefill is only chosen when no request in the batch has
            # a cached prefix and the server flag does not disable it.
            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
            use_ragged = (
                not global_server_args_dict["flashinfer_mla_disable_ragged"]
                and extend_no_prefix
            )
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=use_ragged,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_paged, use_ragged
            )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        """Allocate the persistent buffers used for CUDA graph capture/replay.

        Args:
            max_bs: Largest batch size that will be captured.
            kv_indices_buf: Optional pre-allocated KV indices buffer; when
                omitted a ``max_bs * max_context_len`` int32 buffer is created.
        """
        if kv_indices_buf is None:
            # Allocate on the backend's device, like cuda_graph_kv_lens below.
            cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len,),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            cuda_graph_kv_indices = kv_indices_buf

        self.cuda_graph_kv_indices = cuda_graph_kv_indices
        self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
        self.cuda_graph_kv_indptr = self.kv_indptr.clone()
        self.cuda_graph_kv_lens = torch.ones(
            (max_bs,), dtype=torch.int32, device=self.device
        )

        # For fast decode plan in graph replaying
        self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
        self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
        self.fast_decode_kwargs = {
            "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
            "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
            "kv_indices": self.cuda_graph_kv_indices,
        }

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Create and plan a CUDA-graph wrapper for batch size ``bs``.

        The wrapper is cached per batch size so that replay can look it up.

        Raises:
            ValueError: If ``forward_mode`` is neither decode/idle nor
                target-verify.
        """
        if forward_mode.is_decode_or_idle():
            decode_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                backend="auto",
            )

            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_decode.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                decode_wrapper=decode_wrapper,
                init_metadata_replay=False,
                spec_info=spec_info,
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrapper
            self.forward_metadata = DecodeMetadata(decode_wrapper)
            # Swap in the fast plan only after the regular plan() ran once
            # above; subsequent replays go through fast_mla_decode_plan.
            decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
        elif forward_mode.is_target_verify():
            verify_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:bs],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=verify_wrapper,
                use_ragged=False,
                spec_info=spec_info,
            )
            self.prefill_cuda_graph_metadata[bs] = verify_wrapper
            self.forward_metadata = PrefillMetadata(verify_wrapper, False)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Re-plan the wrapper captured for batch size ``bs`` before a replay.

        Raises:
            ValueError: If ``forward_mode`` is neither decode/idle nor
                target-verify.
        """
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            kv_len_arr_cpu = seq_lens_cpu[:bs]
            # Refresh the host-side indptr in place; the fast plan consumes
            # the CPU copies rather than the device tensors.
            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
                kv_len_arr_cpu, dim=0
            )
            self.fast_decode_kwargs.update(
                {
                    "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
                    "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
                    "kv_len_arr_cpu": kv_len_arr_cpu,
                }
            )

            self.indices_updater_decode.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                decode_wrapper=self.decode_cuda_graph_metadata[bs],
                init_metadata_replay=True,
                spec_info=spec_info,
                **self.fast_decode_kwargs,
            )
        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
                use_ragged=False,
                spec_info=spec_info,
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

    def get_cuda_graph_seq_len_fill_value(self):
        """Return the sequence length used to fill padded CUDA-graph slots."""
        return 0

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run extend/prefill attention for one layer.

        Uses the ragged wrapper on ``k``/``v`` directly when the planned
        metadata says so, otherwise the paged MLA wrapper on the KV cache.

        Returns:
            The attention output viewed as
            ``(-1, layer.tp_q_head_num * layer.v_head_dim)``.
        """
        cache_loc = forward_batch.out_cache_loc
        logits_soft_cap = layer.logit_cap
        prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)

        # Save kv cache
        if save_kv_cache and k is not None:
            assert v is not None
            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        if self.forward_metadata.use_ragged:
            # ragged prefill
            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
                qall,
                k.view(-1, layer.tp_k_head_num, layer.head_dim),
                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
        else:
            # mla paged prefill: read the cache only here, after the write
            # above, mirroring forward_decode.
            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            o = prefill_wrapper_paged.run(
                qall[:, :, : layer.v_head_dim],
                qall[:, :, layer.v_head_dim :],
                k_buf[:, :, : layer.v_head_dim],
                k_buf[:, :, layer.v_head_dim :],
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run decode attention for one layer through the paged MLA wrapper.

        Returns:
            The attention output viewed as
            ``(-1, layer.tp_q_head_num * layer.v_head_dim)``.
        """
        decode_wrapper = self.forward_metadata.decode_wrapper
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )

        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
        # The first v_head_dim channels and the remaining channels of q and
        # of the cached keys are passed to the wrapper as separate arguments.
        o = decode_wrapper.run(
            reshaped_q[:, :, : layer.v_head_dim],
            reshaped_q[:, :, layer.v_head_dim :],
            reshaped_k[:, :, : layer.v_head_dim],
            reshaped_k[:, :, layer.v_head_dim :],
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
class FlashInferMLAIndicesUpdaterDecode:
    """Builds the KV index metadata and plans the MLA wrapper for decode."""

    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        """Cache model constants and the buffers owned by ``attn_backend``."""
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.q_indptr = attn_backend.q_indptr_decode

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrapper: BatchMLAPagedAttentionWrapper,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        """Plan ``decode_wrapper`` for the given batch.

        Args:
            req_pool_indices: Request slots of the batch.
            seq_lens: KV length of every request.
            seq_lens_sum: Sum of ``seq_lens``; sizes the KV indices buffer
                when one has to be allocated.
            decode_wrapper: Wrapper to plan. ``None`` falls back to the
                owning backend's ``decode_wrapper``.
            init_metadata_replay: True during CUDA graph replay; the plan
                then consumes the CPU tensors in ``fast_decode_kwargs``.
            spec_info: Speculative-decoding input that, when given, supplies
                ``kv_indptr`` and ``kv_indices`` directly.
            **fast_decode_kwargs: ``qo_indptr_cpu``, ``kv_indptr_cpu``,
                ``kv_len_arr_cpu`` and ``kv_indices`` for the replay path.
        """
        if decode_wrapper is None:
            # The updater itself never owns a wrapper; the backend does.
            decode_wrapper = self.attn_backend.decode_wrapper
        self.call_begin_forward(
            decode_wrapper,
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            self.q_indptr,
            self.kv_indptr,
            init_metadata_replay,
            spec_info,
            **fast_decode_kwargs,
        )

    def call_begin_forward(
        self,
        wrapper: BatchMLAPagedAttentionWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        q_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        """Compute ``kv_indptr``/``kv_indices`` and call ``wrapper.plan``.

        Without ``spec_info`` the first ``bs + 1`` entries of ``kv_indptr``
        are overwritten in place and the KV indices are gathered from the
        request-to-token table; with ``spec_info`` both come from it.
        """
        bs = len(req_pool_indices)
        q_indptr = q_indptr[: bs + 1]
        kv_lens = paged_kernel_lens.to(torch.int32)
        sm_scale = self.scaling

        if spec_info is None:
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            if init_metadata_replay:
                # Replay must write into the buffer the graph was captured with.
                kv_indices = fast_decode_kwargs["kv_indices"]
            else:
                kv_indices = torch.empty(
                    paged_kernel_lens_sum,
                    dtype=torch.int32,
                    device=req_pool_indices.device,
                )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.shape[1],
            )
        else:
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

        if init_metadata_replay:
            # The replay-time plan takes host-side indptr/length tensors.
            plan_qo_indptr = fast_decode_kwargs["qo_indptr_cpu"]
            plan_kv_indptr = fast_decode_kwargs["kv_indptr_cpu"]
            plan_kv_lens = fast_decode_kwargs["kv_len_arr_cpu"]
        else:
            plan_qo_indptr = q_indptr
            plan_kv_indptr = kv_indptr
            plan_kv_lens = kv_lens

        wrapper.plan(
            plan_qo_indptr,
            plan_kv_indptr,
            kv_indices,
            plan_kv_lens,
            self.num_local_heads,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
            1,  # page_size
            False,  # causal
            sm_scale,
            self.data_type,
            self.data_type,
        )
class FlashInferMLAIndicesUpdaterPrefill:
    """Builds the index metadata and plans the wrappers for extend/prefill."""

    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        """Cache model constants and the buffers owned by ``attn_backend``."""
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: Optional[torch.Tensor],
        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        """Plan the ragged or the paged wrapper for the given batch.

        Args:
            req_pool_indices: Request slots of the batch.
            seq_lens: Total sequence length of every request.
            seq_lens_sum: Sum of ``seq_lens``.
            prefix_lens: Cached-prefix length of every request. Callers in
                this file pass ``None`` only together with ``spec_info``;
                it is required when ``spec_info`` is ``None`` or when
                ``use_ragged`` is True.
            prefill_wrapper_paged: Paged MLA wrapper to plan when
                ``use_ragged`` is False.
            use_ragged: Plan the ragged prefill wrapper instead of the
                paged one.
            spec_info: Speculative-decoding input that, when given, produces
                the attention index arguments itself.
        """
        if use_ragged:
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
            paged_kernel_lens = seq_lens
            paged_kernel_lens_sum = seq_lens_sum

        self.call_begin_forward(
            self.prefill_wrapper_ragged,
            prefill_wrapper_paged,
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            self.kv_indptr,
            self.qo_indptr,
            use_ragged,
            spec_info,
        )

    def call_begin_forward(
        self,
        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
        wrapper_paged: BatchMLAPagedAttentionWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: Optional[torch.Tensor],
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        """Compute the indptr/index tensors and plan the selected wrapper.

        Without ``spec_info`` the first ``bs + 1`` entries of ``kv_indptr``
        and ``qo_indptr`` are overwritten in place; with ``spec_info`` all
        index tensors come from ``spec_info.generate_attn_arg_prefill``.
        """
        bs = len(seq_lens)
        sm_scale = self.scaling

        if spec_info is None:
            assert len(seq_lens) == len(req_pool_indices)
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                paged_kernel_lens_sum,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.shape[1],
            )
            # Each request contributes (seq_len - prefix_len) query tokens.
            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            assert isinstance(spec_info, (EagleDraftInput, EagleVerifyInput))
            # TODO: Support topk > 1 with custom mask
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )

        if use_ragged:
            # ragged prefill: attention runs over the k/v tensors of the
            # extend tokens themselves, so qo_indptr is deliberately reused
            # as the KV segment boundaries.
            wrapper_ragged.begin_forward(
                qo_indptr=qo_indptr,
                kv_indptr=qo_indptr,
                num_qo_heads=self.num_local_heads,
                num_kv_heads=self.num_local_heads,
                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                head_dim_vo=self.v_head_dim,
                q_data_type=self.q_data_type,
            )
        else:
            # mla paged prefill
            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
            wrapper_paged.plan(
                qo_indptr,
                kv_indptr,
                kv_indices,
                kv_len_arr,
                self.num_local_heads,
                self.kv_lora_rank,
                self.qk_rope_head_dim,
                1,  # page_size
                True,  # causal
                sm_scale,
                self.q_data_type,
                self.data_type,
            )
class FlashInferMLAMultiStepDraftBackend:
    """
    Wrap multiple flashinfer mla attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        """Create one decode-only MLA backend per speculative step.

        Args:
            model_runner: Supplies the device and the request-to-token pool.
            topk: Number of draft candidates per step; only 1 is supported.
            speculative_num_steps: Number of consecutive draft steps.

        Raises:
            ValueError: If ``topk`` is greater than 1.
        """
        # Validate first so an unsupported configuration fails before any
        # import or allocation happens.
        if topk > 1:
            raise ValueError(
                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
            )

        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
        self.device = model_runner.device

        max_bs = model_runner.req_to_token_pool.size * self.topk
        # One kv_indptr row per step; row i is handed to backend i below.
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.q_indptr_decode = torch.arange(
            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
        )

        self.attn_backends = [
            FlashInferMLAAttnBackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[i],
                q_indptr_decode_buf=self.q_indptr_decode,
            )
            for i in range(self.speculative_num_steps)
        ]

        self.max_context_len = self.attn_backends[0].max_context_len

        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

    def common_template(
        self,
        forward_batch: ForwardBatch,
        kv_indices_buffer: torch.Tensor,
        call_fn: Callable,
    ):
        """Fill the per-step KV indices, then invoke ``call_fn`` per step.

        ``call_fn(i, forward_batch)`` is called for the first
        ``speculative_num_steps - 1`` steps, after ``forward_batch.spec_info``
        has been pointed at step ``i``'s ``kv_indptr``/``kv_indices`` slices.
        """
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            num_seqs,
            self.topk,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            triton.next_power_of_2(num_seqs),
            triton.next_power_of_2(self.speculative_num_steps),
            triton.next_power_of_2(bs),
        )

        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            # Step i attends to the original tokens plus (i + 1) drafted
            # tokens per batch entry.
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Plan every step's backend for an eager (non-graph) forward."""
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        def call_fn(i, forward_batch):
            assert forward_batch.spec_info is not None
            assert isinstance(forward_batch.spec_info, EagleDraftInput)
            # Clone so the tensors planned for this step are not views into
            # the shared per-step buffers.
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int):
        """Allocate the shared KV indices buffer and give each backend a row."""
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Capture-time planning: forward each step to its backend as DECODE."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        """Replay-time planning: forward each step to its backend as DECODE."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                # Placeholder: the decode updater does not read the sum when
                # spec_info supplies the KV indices.
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def fast_mla_decode_plan(
    self,
    qo_indptr_cpu: torch.Tensor,
    kv_indptr_cpu: torch.Tensor,
    kv_indices: torch.Tensor,
    kv_len_arr_cpu: torch.Tensor,
    num_heads: int,
    head_dim_ckv: int,
    head_dim_kpe: int,
    page_size: int,
    causal: bool,
    sm_scale: float,
    q_data_type: torch.dtype,
    kv_data_type: torch.dtype,
) -> None:
    """A faster version of BatchMLAPagedAttentionWrapper::plan,
    for skipping the stream synchronization in original plan function during
    cuda graph replaying.

    This is a free function whose first parameter is the wrapper instance; it
    is bound with ``functools.partial`` and assigned to ``wrapper.plan`` after
    CUDA graph capture. The signature mirrors the positional arguments the
    indices updaters pass to ``plan``. ``kv_indices``, ``head_dim_kpe``,
    ``q_data_type`` and ``kv_data_type`` are accepted for that reason but are
    not used here.

    Args:
        self: The ``BatchMLAPagedAttentionWrapper`` being planned.
        qo_indptr_cpu: Query indptr, forwarded to the cached module's plan.
        kv_indptr_cpu: KV indptr, forwarded to the cached module's plan.
        kv_indices: Unused.
        kv_len_arr_cpu: Per-request KV lengths, forwarded to the plan.
        num_heads: Forwarded to the plan.
        head_dim_ckv: Forwarded to the plan.
        head_dim_kpe: Unused.
        page_size: Stored on the wrapper as ``_page_size``.
        causal: Stored on the wrapper as ``_causal`` and forwarded.
        sm_scale: Stored on the wrapper as ``_sm_scale``.
        q_data_type: Unused.
        kv_data_type: Unused.
    """
    self._causal = causal
    self._page_size = page_size
    self._sm_scale = sm_scale

    # NOTE(review): relies on the wrapper's private attributes
    # (_cached_module and the workspace buffers) — confirm against the
    # installed flashinfer version when upgrading.
    with self.device as device:
        stream = torch.cuda.current_stream(device).cuda_stream
        self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_cpu,
            kv_indptr_cpu,
            kv_len_arr_cpu,
            num_heads,
            head_dim_ckv,
            causal,
            stream,
        )