import random
import unittest

import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd_grouped,
)
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
    decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
from sglang.test.test_utils import CustomTestCase


class TestTritonAttentionMLA(CustomTestCase):
    """Compare the fused ROCm MLA decode kernel against the unfused Triton reference."""

    def _set_all_seeds(self, seed):
        """Set all random seeds for reproducibility."""
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def setUp(self):
        # Set seeds before each test method
        self._set_all_seeds(42)

    def preprocess_kv_cache(self, kv_cache, kv_lora_rank):
        # Split the latent cache into K/V inputs: V is the first kv_lora_rank
        # channels; K shares that prefix and keeps the rope channels at the tail.
        latent_cache = kv_cache
        v_input = latent_cache[..., :kv_lora_rank]
        v_input = v_input.contiguous().unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., :kv_lora_rank] = v_input
        return k_input, v_input

    def input_helper(
        self,
        B,
        H,
        S,
        kv_lora_rank,
        rotary_dim,
        qk_rope_head_dim,
        num_kv_splits,
        dtype,
        device,
        rope_base=10,
        rope_max_seq_len=16384,
        rope_scaling=1.0,
        is_neox_style=False,
    ):
        q = torch.randn(
            B, H, kv_lora_rank + qk_rope_head_dim, device=device, dtype=dtype
        )
        kv_cache = torch.randn(
            B * S, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device
        )

        # Flat KV layout: sequence i owns cache rows [i * S, (i + 1) * S).
        kv_indptr = torch.arange(B + 1, device=device) * S
        kv_indices = torch.arange(B * S, device=device)

        attn_logits = torch.empty(
            B, H, num_kv_splits, kv_lora_rank + 1, dtype=dtype, device=device
        )

        rotary_emb = DeepseekScalingRotaryEmbedding(
            qk_rope_head_dim,
            rotary_dim,
            rope_max_seq_len,
            rope_base,
            is_neox_style,
            rope_scaling,
            q.dtype,
            device="cpu",
        ).cuda()

        positions = torch.tensor([S], device=device).unsqueeze(0).repeat(B, 1)

        return kv_indptr, kv_indices, q, kv_cache, attn_logits, rotary_emb, positions

    def ref_compute_full_fwd(
        self,
        q,
        k_input,
        v_input,
        kv_lora_rank,
        kv_indptr,
        kv_indices,
        num_kv_splits,
        sm_scale,
        logit_cap,
        rotary_emb,
        positions,
        use_rope,
        device="cuda",
    ):
        B, H = q.shape[0], q.shape[1]
        S = kv_indptr[1].item()

        qk_rope_head_dim = k_input.shape[-1] - kv_lora_rank
        q_input = torch.empty(B, H, kv_lora_rank + qk_rope_head_dim, dtype=q.dtype).to(
            device
        )
        q_nope_out, q_pe = q.split([kv_lora_rank, qk_rope_head_dim], dim=-1)

        k_pe_t = k_input.view(B, 1, S, -1)[:, :, -1:, kv_lora_rank:]

        # Apply rope outside the kernel, then run the unfused grouped decode
        # kernel as the reference implementation.
        if use_rope:
            q_pe, k_pe_t = rotary_emb(positions, q_pe.unsqueeze(2), k_pe_t)
            q_pe = q_pe.squeeze()

        k_input.view(B, 1, S, -1)[:, :, -1:, kv_lora_rank:] = k_pe_t

        q_input[..., :kv_lora_rank] = q_nope_out
        q_input[..., kv_lora_rank:] = q_pe

        B, H = q_input.shape[0], q_input.shape[1]
        kv_lora_rank = v_input.shape[-1]
        device = q_input.device

        attn_logits = torch.empty(
            B, H, num_kv_splits, kv_lora_rank + 1, dtype=q_input.dtype, device=device
        )
        o = torch.empty(B, H, kv_lora_rank, dtype=q_input.dtype, device=device)

        decode_attention_fwd_grouped(
            q_input,
            k_input,
            v_input,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            num_kv_splits,
            sm_scale,
            logit_cap,
        )

        return attn_logits, o, k_pe_t.squeeze()

    def _test_rocm_fused_mla_kernel(
        self,
        B,
        H,
        S,
        kv_lora_rank,
        qk_rope_head_dim,
        rotary_dim,
        dtype,
        use_rope,
        is_neox_style,
        num_kv_splits=2,
        sm_scale=1.0,
        logit_cap=0.0,
        device="cuda",
    ):
        kv_indptr, kv_indices, q, kv_cache, attn_logits, rotary_emb, positions = (
            self.input_helper(
                B,
                H,
                S,
                kv_lora_rank,
                rotary_dim,
                qk_rope_head_dim,
                num_kv_splits,
                dtype,
                device=device,
                is_neox_style=is_neox_style,
            )
        )

        k_input, v_input = self.preprocess_kv_cache(kv_cache, kv_lora_rank)

        k_pe_tokens = torch.empty(
            B, qk_rope_head_dim, dtype=kv_cache.dtype, device=device
        )
        tri_o = torch.empty(B, H, kv_lora_rank, dtype=kv_cache.dtype, device=device)

        # Fused kernel under test: applies rope inside the kernel when use_rope.
        decode_attention_fwd_grouped_rope(
            q,
            k_input,
            v_input,
            tri_o,
            kv_indptr,
            kv_indices,
            k_pe_tokens if use_rope else None,
            kv_lora_rank,
            rotary_dim if use_rope else None,
            rotary_emb.cos_sin_cache if use_rope else None,
            positions if use_rope else None,
            attn_logits,
            num_kv_splits,
            sm_scale,
            logit_cap,
            use_rope,
            is_neox_style,
        )
        tri_logits = attn_logits

        # Reference: rope applied on the host, then the unfused grouped kernel.
        ref_logits, ref_o, ref_k_pe_tokens = self.ref_compute_full_fwd(
            q,
            k_input,
            v_input,
            kv_lora_rank,
            kv_indptr,
            kv_indices,
            num_kv_splits,
            sm_scale,
            logit_cap,
            rotary_emb,
            positions,
            use_rope,
            device="cuda",
        )

        if use_rope:
            torch.testing.assert_close(
                ref_k_pe_tokens, k_pe_tokens.squeeze(), atol=1e-2, rtol=1e-2
            )

        torch.testing.assert_close(ref_logits, tri_logits, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(ref_o, tri_o, atol=1e-2, rtol=1e-2)

    def test_grouped_rocm_fused_mla(self):
        # Each tuple is (B, H, S, kv_lora_rank, qk_rope_head_dim, rotary_dim).
        configs = [
            (1, 128, 2048, 512, 64, 64),
            (1, 128, 2048, 512, 128, 64),
            (1, 128, 2048, 512, 127, 64),
            (1, 128, 2050, 512, 127, 64),
            (1, 128, 2050, 512, 128, 64),
            (8, 128, 2048, 512, 64, 64),
            (8, 128, 2048, 512, 128, 64),
            (8, 128, 2048, 512, 127, 64),
            (8, 128, 2050, 512, 127, 64),
            (8, 128, 2050, 512, 128, 64),
        ]
        dtypes = [torch.bfloat16, torch.float32]
        use_rope_list = [True, False]
        is_neox_style_list = [True, False]

        for B, H, S, kv_lora_rank, qk_rope_head_dim, rotary_dim in configs:
            for dtype in dtypes:
                for use_rope in use_rope_list:
                    for is_neox_style in is_neox_style_list:
                        self._test_rocm_fused_mla_kernel(
                            B,
                            H,
                            S,
                            kv_lora_rank,
                            qk_rope_head_dim,
                            rotary_dim,
                            dtype,
                            use_rope,
                            is_neox_style,
                        )


if __name__ == "__main__":
    unittest.main()