import unittest

import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention

from sglang.test.test_utils import CustomTestCase

torch.manual_seed(1234)


class TestExtendAttention(CustomTestCase):
    def _run_sdpa_forward_extend(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        extend_prefix_lens: torch.Tensor,
        extend_seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
        assert seq_lens.shape[0] == extend_seq_lens.shape[0]

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            extend_seq_len_q = extend_seq_lens[seq_idx]
            prefill_seq_len_q = extend_prefix_lens[seq_idx]
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + extend_seq_len_q
            end_kv = start_kv + seq_len_kv

            # Build a "redundant" query spanning the full kv length so causal SDPA
            # can be applied directly; only the extend part holds real query tokens.
            per_req_query = query[:, start_q:end_q, :]
            per_req_query_redundant = torch.empty(
                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
                dtype=per_req_query.dtype,
                device=per_req_query.device,
            )
            per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query

            # Get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out_redundant = (
                scaled_dot_product_attention(
                    per_req_query_redundant.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out_redundant[
                prefill_seq_len_q:, :, :
            ]
            start_q, start_kv = end_q, end_kv
        return output

    def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
        dtype = torch.bfloat16

        # Per-request prefix (cached) and extend (new) lengths; MLA mode uses no prefix.
        b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        if mla:
            b_seq_len_prefix.zero_()
        b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        b_seq_len = b_seq_len_prefix + b_seq_len_extend
        max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

        b_req_idx = torch.arange(B, dtype=torch.int32)
        req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32)
        b_start_loc = torch.zeros((B,), dtype=torch.int32)
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32)
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
        for i in range(B):
            req_to_tokens[i, : b_seq_len[i]] = torch.arange(
                b_start_loc[i], b_start_loc[i] + b_seq_len[i]
            )

        total_token_num = torch.sum(b_seq_len).item()
        extend_token_num = torch.sum(b_seq_len_extend).item()

        # The MLA kv buffer holds a single shared head.
        H_BUF = 1 if mla else H_KV
        k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype)
        v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype)
        v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype)
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype)
        for i in range(B):
            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
            extend_start = b_start_loc_extend[i]
            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
            k_extend[extend_start:extend_end] = k_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            v_extend[extend_start:extend_end] = v_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            q_extend[extend_start:extend_end] = torch.randn(
                (b_seq_len_extend[i], H_Q, D), dtype=dtype
            )

        # k_extend, v_extend, k_buffer and v_buffer support non-contiguous tensors
        k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
        v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)

        b_seq_len_extend = b_seq_len - b_seq_len_prefix
        b_start_loc_extend = torch.zeros_like(b_seq_len)
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0

        # handle index type
        b_req_idx = b_req_idx.to(torch.int64)
        b_seq_len = b_seq_len.to(torch.int64)

        enable_gqa = H_Q != H_KV

        # Reference output computed with torch's scaled_dot_product_attention.
        o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        self._run_sdpa_forward_extend(
            q_extend,
            o_ref,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_prefix,
            b_seq_len_extend,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
            causal=True,
        )

        # Output of the sgl_kernel CPU extend attention kernel under test.
        o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        torch.ops.sgl_kernel.extend_attention_cpu(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_extend,
            b_start_loc_extend,
            max_len_extend,
            sm_scale,
            logit_cap,
        )

        torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2)

    def test_extend_attention(self):
        for is_mla in [True, False]:
            self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla)
            self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla)
            self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla)


if __name__ == "__main__":
    unittest.main()