258 lines
8.9 KiB
Python
258 lines
8.9 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
|
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
|
|
|
|
class DoubleSparseAttnBackend(AttentionBackend):
    """Attention backend that dispatches to the double-sparsity Triton kernels.

    Extend batches always go through ``extend_attention_fwd``. Decode batches
    use the dense ``flash_decode_attention_fwd`` kernel when the shortest
    sequence has fewer than ``heavy_token_num`` tokens or the longest one is
    below ``sparse_decode_thresold``; otherwise the sparse kernel
    ``flash_decode_sparse_attention_fwd`` is used.
    """

    def __init__(self, model_runner: ModelRunner):
        """Read model/server configuration and bind the Triton kernels.

        Args:
            model_runner: Provides ``model_config`` (head count, hidden size),
                ``server_args`` (double-sparsity settings) and
                ``sorted_channels`` (indexed by layer id at run time).
        """
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
            extend_attention_fwd,
            flash_decode_attention_fwd,
            flash_decode_sparse_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = flash_decode_attention_fwd
        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = model_runner.model_config.num_attention_heads
        self.head_dim = model_runner.model_config.hidden_size // self.num_head
        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num

        self.sorted_channels = model_runner.sorted_channels
        # NOTE: the attribute name is misspelled ("thresold"); it is kept as-is
        # because it is a public attribute that other code may read.
        self.sparse_decode_thresold = (
            model_runner.server_args.ds_sparse_decode_threshold
        )

        # Scratch buffers for the sparse decode kernel; (re)allocated for every
        # decode batch in init_forward_metadata.
        self.att_out_approx: torch.Tensor | None = None
        self.mid_out: torch.Tensor | None = None
        self.mid_o_logexpsum: torch.Tensor | None = None

        # TODO: Change the hard-coded block_seq_num
        self.BLOCK_SEQ = 128

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        # Set by init_forward_metadata as the 6-tuple
        # (start_loc, attn_logits, max_seq_len, min_seq_len,
        #  max_extend_len, ds_req_to_token).
        self.forward_metadata = None

    def _gather_label_channels(
        self, x: torch.Tensor, layer: RadixAttention
    ) -> torch.Tensor:
        """Gather this layer's ``sorted_channels`` entries along dim 2 of ``x``.

        The per-layer index tensor is broadcast over ``x.shape[0]`` with
        ``unsqueeze(0).expand(...)``, so ``x`` must be 3-D.
        """
        channel_idx = self.sorted_channels[layer.layer_id]
        return torch.gather(
            x,
            2,
            channel_idx.unsqueeze(0).expand(x.shape[0], -1, -1),
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend.

        For decode batches this allocates the logits buffer and the sparse
        kernel scratch buffers and records min/max sequence lengths. For all
        other modes only ``max_extend_len`` is computed. The result is stored
        in ``self.forward_metadata``.
        """
        seq_lens = forward_batch.seq_lens

        if forward_batch.forward_mode.is_decode():
            # Allocate every buffer on the device the batch lives on, matching
            # start_loc (zeros_like already follows seq_lens' device).
            device = seq_lens.device

            # Exclusive prefix sum of the sequence lengths.
            start_loc = torch.zeros_like(seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device=device,
            )

            max_seq_len = torch.max(seq_lens).item()
            min_seq_len = torch.min(seq_lens).item()
            max_extend_len = None
            # NOTE: Align sequence order with req_to_token order
            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices
            ]

            bsz = seq_lens.shape[0]

            self.att_out_approx = torch.empty(
                [self.num_head, bsz, max_seq_len],
                dtype=self.reduce_dtype,
                device=device,
            )

            # Number of BLOCK_SEQ-sized blocks needed for heavy_token_num
            # tokens (ceiling division).
            block_seq_num = (
                self.heavy_token_num + self.BLOCK_SEQ - 1
            ) // self.BLOCK_SEQ

            self.mid_out = torch.empty(
                [bsz, self.num_head, block_seq_num, self.head_dim],
                dtype=torch.float32,
                device=device,
            )
            self.mid_o_logexpsum = torch.empty(
                [bsz, self.num_head, block_seq_num],
                dtype=torch.float32,
                device=device,
            )
        else:
            start_loc = attn_logits = max_seq_len = min_seq_len = None
            prefix_lens = forward_batch.extend_prefix_lens
            max_extend_len = torch.max(seq_lens - prefix_lens).item()
            ds_req_to_token = None

        self.forward_metadata = (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        )

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run extend attention for one layer and return the output tensor.

        When ``save_kv_cache`` is true, ``k``, ``v`` and the label channels
        gathered from ``k`` are written to the KV pool at
        ``forward_batch.out_cache_loc`` before the kernel is launched.
        """
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            # The label is only consumed by set_kv_buffer, so skip the gather
            # entirely when the cache is not being written.
            k_label = self._gather_label_channels(k, layer)
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v, k_label
            )

        # Only max_extend_len is needed on the extend path.
        _, _, _, _, max_extend_len, _ = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run decode attention for one layer and return the output tensor.

        Chooses between the dense and the sparse decode kernel based on the
        min/max sequence lengths recorded by ``init_forward_metadata``.
        """
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        # TODO: Add min seqlen
        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            _,
            ds_req_to_token,
        ) = self.forward_metadata

        if save_kv_cache:
            # The label is only consumed by set_kv_buffer, so skip the gather
            # entirely when the cache is not being written.
            k_label = self._gather_label_channels(k, layer)
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v, k_label
            )

        q_heads = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
        # The output buffer is sized with v_head_dim, so both kernels must see
        # it viewed with v_head_dim (not qk_head_dim).
        o_heads = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
        key_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        value_buffer = forward_batch.token_to_kv_pool.get_value_buffer(
            layer.layer_id
        )

        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
        #            and set a minimum value for sparse_decode
        if (
            min_seq_len < self.heavy_token_num
            or max_seq_len < self.sparse_decode_thresold
        ):
            self.decode_attention_fwd(
                q_heads,
                key_buffer,
                value_buffer,
                o_heads,
                forward_batch.req_to_token_pool.req_to_token,
                forward_batch.req_pool_indices,
                start_loc,
                forward_batch.seq_lens,
                attn_logits,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
            )
        else:
            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
            q_label = self._gather_label_channels(q_heads, layer)
            self.decode_sparse_attention_fwd(
                q_heads,
                key_buffer,
                value_buffer,
                o_heads,
                q_label,
                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
                ds_req_to_token,
                forward_batch.seq_lens,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
                self.heavy_token_num,
                self.att_out_approx,
                self.mid_out,
                self.mid_o_logexpsum,
                self.BLOCK_SEQ,
            )

        return o
|