435 lines
16 KiB
Python
435 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
|
|
|
"""
|
|
Support different attention backends.
There are currently three backends: FlashInfer, Triton and FlashAttention.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Optional, Union
|
|
|
|
import torch
|
|
|
|
from sglang.srt.configs.model_config import AttentionArch
|
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
|
|
from flash_attn_interface import flash_attn_with_kvcache
|
|
|
|
|
|
@dataclass
class FlashAttentionMetadata:
    """Per-batch values precomputed once and reused by every attention layer.

    The backend fills one instance per forward pass (extend or decode) and
    reads it back when calling the attention kernel, so that cumulative
    lengths, maxima and the page table are not recomputed per layer.
    Tensor fields default to ``None`` until the backend populates them.
    """

    # Cumulative query lengths with a leading 0 (int32).
    cu_seqlens_q: Optional[torch.Tensor] = None
    # Cumulative key/sequence lengths with a leading 0 (int32).
    cu_seqlens_k: Optional[torch.Tensor] = None
    # Longest query length in the batch; left at 0 when the backend does not set it.
    max_seq_len_q: int = 0
    # Longest key/sequence length in the batch.
    max_seq_len_k: int = 0
    # Attention window; (-1, -1) is the value the backend uses for "no sliding window".
    window_size: tuple = (-1, -1)
    # Per-request table of KV-cache locations (page ids when page_size > 1).
    page_table: Optional[torch.Tensor] = None
    # Sequence lengths cast to int32.
    cache_seqlens_int32: Optional[torch.Tensor] = None
|
|
|
|
|
|
class FlashAttentionBackend(AttentionBackend):
    """FlashAttention backend implementation.

    Runs both extend and decode through ``flash_attn_with_kvcache`` on a paged
    view of the KV cache. Two layouts are handled: regular multi-head
    attention, and absorbed multi-latent attention when ``self.use_mla`` is
    set. Per-batch values are precomputed into a ``FlashAttentionMetadata``
    stored on ``self.forward_metadata`` before the layers run.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
    ):
        """Capture the configuration needed by the forward passes.

        Args:
            model_runner: Source of the sliding-window setting, model config
                (context length, attention architecture, encoder-decoder
                flag), device, request-to-token pool and page size.
            skip_prefill: Accepted for interface compatibility; not read by
                this implementation.

        Raises:
            AssertionError: If a sliding window is configured for an
                encoder-decoder model.
        """
        super().__init__()

        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        # Populated by one of the init_forward_metadata* methods.
        self.forward_metadata: Optional[FlashAttentionMetadata] = None
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        # Holds the shared CUDA-graph buffers (string keys) and the captured
        # per-batch-size metadata (int keys).
        self.decode_cuda_graph_metadata = {}
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.page_size = model_runner.page_size
        self.use_mla = (
            model_runner.model_config.attention_arch == AttentionArch.MLA
        ) and (not global_server_args_dict["disable_mla"])

    @staticmethod
    def _cumulative_seqlens(seq_lens: torch.Tensor) -> torch.Tensor:
        """Return the int32 running sum of ``seq_lens`` with a leading zero."""
        return torch.nn.functional.pad(
            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
        )

    def _save_kv_cache(
        self,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool,
    ) -> None:
        """Write the new ``k``/``v`` into the KV pool when requested.

        Nothing is written when ``k`` is ``None`` or ``save_kv_cache`` is
        false. ``v`` must accompany a non-``None`` ``k``.
        """
        if k is None:
            return
        assert v is not None
        if not save_kv_cache:
            return

        # Cross-attention layers store into the encoder cache locations.
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )
        if not self.use_mla:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )
        else:
            # The MLA path does not pass the k/v scales.
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer,
                cache_loc,
                k,
                v,
            )

    def _attend(
        self,
        q: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        max_seqlen_q: int,
    ) -> torch.Tensor:
        """Run ``flash_attn_with_kvcache`` using ``self.forward_metadata``.

        Shared by extend and decode, which differ only in ``max_seqlen_q``.
        Returns the output reshaped to
        ``(-1, layer.tp_q_head_num * layer.v_head_dim)``.
        """
        metadata = self.forward_metadata
        page_table = metadata.page_table

        if not self.use_mla:
            # Do multi-head attention.
            # No extra "- 1" on the window here: per the original note,
            # model.get_attention_sliding_window_size() already subtracts one
            # and the kernel treats the window as inclusive on both sides.
            window_size = (
                (layer.sliding_window_size, 0)
                if layer.sliding_window_size is not None
                else (-1, -1)
            )

            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            # View the flat caches as (num_pages, page_size, heads, head_dim).
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
            o = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=True,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )
        else:
            # Do absorbed multi-latent attention.
            # The key buffer's last dimension is split at v_head_dim: the
            # first part is passed as the value cache, the rest as the key
            # cache.
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            # The query is split at the same boundary as the cache.
            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]
            # No window_size is passed on this path (unchanged behavior).
            o = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata to cache repetitive calculations.

        Builds a fresh ``FlashAttentionMetadata`` from ``forward_batch`` and
        stores it on ``self.forward_metadata``.
        """
        metadata = FlashAttentionMetadata()

        # Sequence information
        seqlens_in_batch = forward_batch.seq_lens
        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
        batch_size = len(seqlens_in_batch)
        device = seqlens_in_batch.device
        metadata.cu_seqlens_k = self._cumulative_seqlens(seqlens_in_batch)
        metadata.max_seq_len_k = seqlens_in_batch.max().item()

        # Page table: one row per request, truncated to the longest sequence.
        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
            forward_batch.req_pool_indices, : metadata.max_seq_len_k
        ]

        if self.page_size > 1:
            # Keep one entry per page, at columns [0, page_size, 2 * page_size, ...],
            # and convert the stored locations to page ids.
            self.strided_indices = torch.arange(
                0, metadata.page_table.shape[1], self.page_size, device=self.device
            )
            metadata.page_table = (
                metadata.page_table[:, self.strided_indices] // self.page_size
            )

        if forward_batch.forward_mode == ForwardMode.DECODE:
            # One query token per request. max_seq_len_q is left at its
            # default; forward_decode passes 1 explicitly.
            metadata.cu_seqlens_q = torch.arange(
                0, batch_size + 1, dtype=torch.int32, device=device
            )
        elif any(forward_batch.extend_prefix_lens_cpu):
            # Some request has a cached prefix, so the query lengths are the
            # extend lengths rather than the full sequence lengths.
            metadata.cu_seqlens_q = self._cumulative_seqlens(
                forward_batch.extend_seq_lens
            )
            metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
        else:
            # No cached prefix anywhere: reuse the key-side values.
            metadata.cu_seqlens_q = metadata.cu_seqlens_k
            metadata.max_seq_len_q = metadata.max_seq_len_k

        self.forward_metadata = metadata

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Extend (prefill) pass with FlashAttention using precomputed metadata.

        Optionally stores ``k``/``v`` in the KV pool, then attends over the
        paged cache with the query maximum taken from the metadata.
        """
        self._save_kv_cache(k, v, layer, forward_batch, save_kv_cache)
        return self._attend(
            q, layer, forward_batch, self.forward_metadata.max_seq_len_q
        )

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention using precomputed metadata.

        Same as ``forward_extend`` except that the maximum query length is
        fixed at 1.
        """
        self._save_kv_cache(k, v, layer, forward_batch, save_kv_cache)
        return self._attend(q, layer, forward_batch, 1)

    def init_cuda_graph_state(self, max_bs: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
            max_bs (int): Maximum batch size to support in CUDA graphs

        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
        self.decode_cuda_graph_metadata = {
            # Page table, shape (max_bs, ceil(max_context_len / page_size)).
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
                dtype=torch.int32,
                device=self.device,
            ),
            # Column offsets [0, page_size, 2 * page_size, ...] up to max_context_len.
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Initialize forward metadata for capturing CUDA graph.

        The metadata is stored under key ``bs`` in
        ``self.decode_cuda_graph_metadata`` and also becomes the active
        ``self.forward_metadata``. ``num_tokens``, ``encoder_lens`` and
        ``spec_info`` are not read here.

        Raises:
            ValueError: If ``forward_mode`` is not ``ForwardMode.DECODE``.
        """
        # Reject unsupported modes before doing any tensor work.
        if forward_mode != ForwardMode.DECODE:
            raise ValueError("Do not support Prefill Mode cuda graph")

        metadata = FlashAttentionMetadata()

        # Sequence information
        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
        batch_size = len(seq_lens)
        device = seq_lens.device
        metadata.cu_seqlens_k = self._cumulative_seqlens(seq_lens)
        metadata.max_seq_len_k = seq_lens.max().item()

        # Rows of the preallocated page table selected by req_pool_indices.
        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
            req_pool_indices, :
        ]

        # One query token per request.
        metadata.cu_seqlens_q = torch.arange(
            0, batch_size + 1, dtype=torch.int32, device=device
        )

        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Initialize forward metadata for replaying CUDA graph.

        Looks up the metadata captured for ``bs`` and refreshes it from the
        first ``bs`` entries of ``seq_lens`` and ``req_pool_indices``. The
        page table is updated in place. ``seq_lens_sum``, ``encoder_lens``,
        ``forward_mode`` and ``spec_info`` are not read here.
        """
        metadata = self.decode_cuda_graph_metadata[bs]

        # Prefer the CPU copy for the scalar maximum; fall back to the
        # ``seq_lens`` tensor when no CPU copy is supplied.
        if seq_lens_cpu is not None:
            max_len = seq_lens_cpu[:bs].max().item()
        else:
            max_len = seq_lens[:bs].max().item()
        metadata.max_seq_len_k = max_len

        # For GPU operations
        seq_lens_in_batch = seq_lens[:bs]
        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
        metadata.cu_seqlens_k = self._cumulative_seqlens(seq_lens_in_batch)

        max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
        strided_indices = self.decode_cuda_graph_metadata["strided_indices"][
            :max_seq_pages
        ]
        # Gather only the batch's rows and the strided columns in one
        # broadcast index, instead of slicing every row of req_to_token first.
        page_indices = self.req_to_token[
            req_pool_indices[:bs, None], strided_indices[None, :]
        ]
        page_indices = page_indices // self.page_size
        # Update the captured buffer in place and zero the unused tail.
        metadata.page_table[:, :max_seq_pages].copy_(page_indices)
        metadata.page_table[:, max_seq_pages:].fill_(0)
        self.forward_metadata = metadata

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 0
|