from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner


class DoubleSparseAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid initializing the CUDA context too early
        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
            extend_attention_fwd,
            flash_decode_attention_fwd,
            flash_decode_sparse_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = flash_decode_attention_fwd
        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = model_runner.model_config.num_attention_heads

        self.head_dim = model_runner.model_config.hidden_size // self.num_head
        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num

        self.sorted_channels = model_runner.sorted_channels
        self.sparse_decode_threshold = (
            model_runner.server_args.ds_sparse_decode_threshold
        )
        self.att_out_approx: torch.Tensor = None
        self.mid_out: torch.Tensor = None
        self.mid_o_logexpsum: torch.Tensor = None

        # TODO: Change the hard-coded block_seq_num
        self.BLOCK_SEQ = 128

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        self.forward_metadata = None

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for the Triton attention backend."""

        if forward_batch.forward_mode.is_decode():
            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device="cuda",
            )

            max_seq_len = torch.max(forward_batch.seq_lens).item()
            min_seq_len = torch.min(forward_batch.seq_lens).item()
            max_extend_len = None
            # NOTE: Align sequence order with req_to_token order
            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices
            ]

            bsz = forward_batch.seq_lens.shape[0]

            # Buffers for the sparse decode path: approximate attention scores over the
            # label channels, plus per-block partial outputs and log-sum-exp terms used
            # by the flash-decoding reduction.
            att_out_approx = torch.empty(
                [self.num_head, bsz, max_seq_len],
                dtype=self.reduce_dtype,
                device="cuda",
            )

            block_seq_num = (
                self.heavy_token_num + self.BLOCK_SEQ - 1
            ) // self.BLOCK_SEQ

            mid_out = torch.empty(
                [bsz, self.num_head, block_seq_num, self.head_dim],
                dtype=torch.float32,
                device="cuda",
            )
            mid_o_logexpsum = torch.empty(
                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
            )
            self.att_out_approx = att_out_approx
            self.mid_out = mid_out
            self.mid_o_logexpsum = mid_o_logexpsum

        else:
            start_loc = attn_logits = max_seq_len = min_seq_len = None
            prefix_lens = forward_batch.extend_prefix_lens
            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
            ds_req_to_token = None

        self.forward_metadata = (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        )

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)
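
        # Double Sparsity keeps, per layer, the most informative key channels
        # (`sorted_channels[layer_id]`). Gathering those channels from the new keys
        # produces `k_label`, a reduced-width copy that is stored in the label buffer
        # and later used to cheaply approximate attention scores when selecting
        # heavy tokens in the sparse decode path.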
        k_label = torch.gather(
            k,
            2,
            self.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v, k_label
            )

        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        # TODO: Add min seqlen
        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata

        k_label = torch.gather(
            k,
            2,
            self.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v, k_label
            )

        # NOTE(Andy): the sparse path shouldn't be used when max_len_in_batch < heavy_token_num;
        # sparse_decode_threshold additionally sets a minimum batch length for sparse decoding.
        if (
            min_seq_len < self.heavy_token_num
            or max_seq_len < self.sparse_decode_threshold
        ):
            # Dense fallback: plain flash decoding over the full KV cache.
            self.decode_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                forward_batch.req_to_token_pool.req_to_token,
                forward_batch.req_pool_indices,
                start_loc,
                forward_batch.seq_lens,
                attn_logits,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
            )
        else:
            # TODO(Andy): indexing with torch.gather, torch.index_select, or a customized kernel
            q_label = torch.gather(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                2,
                self.sorted_channels[layer.layer_id]
                .unsqueeze(0)
                .expand(q.shape[0], -1, -1),
            )
            self.decode_sparse_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                q_label,
                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
                ds_req_to_token,
                forward_batch.seq_lens,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
                self.heavy_token_num,
                self.att_out_approx,
                self.mid_out,
                self.mid_o_logexpsum,
                self.BLOCK_SEQ,
            )

        return o
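
# A minimal usage sketch, not part of the backend itself. Assumptions: `model_runner`
# is an initialized sglang ModelRunner configured for double sparsity (so
# `sorted_channels` and the ds_* server args are populated), and `forward_batch` is a
# decode-mode ForwardBatch; `q`, `k`, `v`, and `layer` come from the attention layer
# being executed. In practice the ModelRunner constructs the backend and drives these
# calls; this only illustrates the expected call order:
#
#     backend = DoubleSparseAttnBackend(model_runner)
#     backend.init_forward_metadata(forward_batch)  # per-batch metadata and buffers
#     out = backend.forward_decode(q, k, v, layer, forward_batch)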