from typing import Optional, Tuple

import torch


def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Run one lightning-attention decode step, writing the attention result
    into ``output`` and the updated KV state into ``new_kv`` in place."""
    torch.ops.sgl_kernel.lightning_attention_decode.default(
        q, k, v, past_kv, slope, output, new_kv
    )


def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial attention results (values and their log-sum-exp
    scores) into a single result."""
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # Avoid creating new tensors if they are already provided.
    if v_merged is None:
        v_merged = torch.empty_like(v_a)
    if s_merged is None:
        s_merged = torch.empty_like(s_a)
    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged


def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
    # does not support the FP8 data type and non-CUDA devices.
    # It may be necessary to fall back to using the Triton kernel.
    # Avoid creating new tensors if they are already provided.
    if v_merged is None:
        v_merged = torch.empty_like(v_a)
    if s_merged is None:
        s_merged = torch.empty_like(s_a)
    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged
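# Illustrative usage sketch for merge_state, not part of the kernel API. It
# assumes v_a/v_b hold partial attention outputs for the same queries computed
# over two different KV chunks, and s_a/s_b hold the matching per-head
# log-sum-exp scores; all shapes below are made up for demonstration.
def _example_merge_state() -> None:
    num_tokens, num_heads, head_dim = 32, 8, 128
    v_a = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.half, device="cuda")
    s_a = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")
    v_b = torch.randn_like(v_a)
    s_b = torch.randn_like(s_a)
    # Combine the two partial results into one output/score pair.
    v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)
    assert v_merged.shape == v_a.shape and s_merged.shape == s_a.shape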
def cutlass_mla_decode(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    sm_scale: float,
    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
) -> torch.Tensor:
    """Multi-head latent attention (MLA) decode using the CUTLASS kernel."""
    assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
    assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
    assert (
        kv_c_and_k_pe_cache.ndim == 3
    ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"

    B_q, H, D_q_nope = q_nope.shape
    B_q_2, H_2, D_q_pe = q_pe.shape
    assert (B_q == B_q_2) and (H == H_2)

    _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

    D_latent = 512
    D_rope = 64
    assert D_q_nope == D_latent
    assert D_q_pe == D_rope
    assert D_ckv == D_latent + D_rope

    MAX_HEADS = 128
    assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
    if H < MAX_HEADS:
        # The kernel expects exactly MAX_HEADS heads; pad the query tensors
        # with uninitialized rows and slice the extra heads off the output.
        q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
        q_nope_padded[:, :H] = q_nope
        q_nope = q_nope_padded

        q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
        q_pe_padded[:, :H] = q_pe
        q_pe = q_pe_padded

    assert page_table.ndim == 2
    B_block_table, block_num = page_table.shape
    assert B_block_table == B_q
    assert block_num > 0, f"block num must be greater than 0, got {block_num}"
    assert (
        block_num % (128 // PAGE_SIZE) == 0
    ), f"block_num must be a multiple of {128 // PAGE_SIZE}, got {block_num}"

    # TODO(kaixih@nvidia): support fp8
    assert q_nope.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
    assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
    assert (
        seq_lens.dtype == torch.int32
    ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
    assert (
        page_table.dtype == torch.int32
    ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

    out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
    torch.ops.sgl_kernel.cutlass_mla_decode.default(
        out,
        q_nope,
        q_pe,
        kv_c_and_k_pe_cache,
        seq_lens,
        page_table,
        workspace,
        sm_scale,
        num_kv_splits,
    )
    return out[:, :H].contiguous()


def cutlass_mla_get_workspace_size(
    max_seq_len: int,
    num_batches: int,
    sm_count: int = 0,
    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
) -> int:
    """Return the workspace size required by cutlass_mla_decode."""
    assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
    assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"

    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
        max_seq_len, num_batches, sm_count, num_kv_splits
    )
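# Illustrative usage sketch for the CUTLASS MLA decode path, not part of the
# kernel API. The batch size, head count, page size, sequence lengths, the
# softmax scale, and the workspace dtype (assumed here to be a raw byte
# buffer) are all assumptions made up for demonstration; D_latent = 512 and
# D_rope = 64 are fixed by the asserts in cutlass_mla_decode above.
def _example_cutlass_mla_decode() -> None:
    B, H, PAGE_SIZE, max_seq_len = 4, 16, 128, 1024
    dtype, device = torch.bfloat16, "cuda"
    D_latent, D_rope = 512, 64

    num_pages_per_seq = max_seq_len // PAGE_SIZE
    q_nope = torch.randn(B, H, D_latent, dtype=dtype, device=device)
    q_pe = torch.randn(B, H, D_rope, dtype=dtype, device=device)
    kv_c_and_k_pe_cache = torch.randn(
        B * num_pages_per_seq, PAGE_SIZE, D_latent + D_rope, dtype=dtype, device=device
    )
    seq_lens = torch.full((B,), max_seq_len, dtype=torch.int32, device=device)
    page_table = torch.arange(
        B * num_pages_per_seq, dtype=torch.int32, device=device
    ).view(B, num_pages_per_seq)

    # Size the workspace first, then run the decode kernel.
    workspace_size = cutlass_mla_get_workspace_size(max_seq_len, B)
    workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device)
    sm_scale = 1.0 / (192**0.5)  # placeholder; use your model's softmax scale
    out = cutlass_mla_decode(
        q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale
    )
    assert out.shape == (B, H, D_latent)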