139 lines
4.4 KiB
Python
139 lines
4.4 KiB
Python
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
|
|
|
|
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Dispatch to the ``sgl_kernel.lightning_attention_decode`` custom op.

    Every argument is forwarded positionally and unchanged, and the op's
    result is discarded, so this function returns ``None``. Results are
    presumably written in place into ``output`` / ``new_kv`` by the kernel
    (TODO confirm against the op's registered schema).
    """
    kernel = torch.ops.sgl_kernel.lightning_attention_decode.default
    kernel(q, k, v, past_kv, slope, output, new_kv)
|
|
|
|
|
|
def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial states ``(v_a, s_a)`` and ``(v_b, s_b)`` via ``sgl_kernel.merge_state``.

    ``s_a`` and ``s_b`` are cast to float32 before the op is called (the
    ``s_*`` inputs are presumably log-sum-exp values — confirm with callers).

    Args:
        v_a, s_a: first partial state.
        v_b, s_b: second partial state.
        v_merged: optional preallocated output; when omitted it is created
            with ``torch.empty_like(v_a)``.
        s_merged: optional preallocated output; when omitted it is created
            with ``torch.empty_like`` of the float32 copy of ``s_a``.

    Returns:
        The ``(v_merged, s_merged)`` buffers passed to the op, presumably
        filled in place by the kernel.
    """
    s_a, s_b = s_a.to(torch.float32), s_b.to(torch.float32)
    # Reuse caller-provided output buffers; only allocate what is missing.
    v_out = torch.empty_like(v_a) if v_merged is None else v_merged
    s_out = torch.empty_like(s_a) if s_merged is None else s_merged
    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_out, s_out)
    return v_out, s_out
|
|
|
|
|
|
def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial states ``(v_a, s_a)`` and ``(v_b, s_b)`` via ``sgl_kernel.merge_state_v2``.

    ``s_a`` and ``s_b`` are cast to float32 before the op is called (the
    ``s_*`` inputs are presumably log-sum-exp values — confirm with callers).

    Args:
        v_a, s_a: first partial state.
        v_b, s_b: second partial state.
        v_merged: optional preallocated output; when omitted it is created
            with ``torch.empty_like(v_a)``.
        s_merged: optional preallocated output; when omitted it is created
            with ``torch.empty_like`` of the float32 copy of ``s_a``.

    Returns:
        The ``(v_merged, s_merged)`` buffers passed to the op, presumably
        filled in place by the kernel.
    """
    s_a, s_b = s_a.to(torch.float32), s_b.to(torch.float32)
    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
    # does not support the FP8 data type or non-CUDA devices.
    # It may be necessary to fall back to using the Triton kernel.

    # Reuse caller-provided output buffers; only allocate what is missing.
    v_out = torch.empty_like(v_a) if v_merged is None else v_merged
    s_out = torch.empty_like(s_a) if s_merged is None else s_merged
    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_out, s_out)
    return v_out, s_out
|
|
|
|
|
|
def cutlass_mla_decode(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    sm_scale: float,
    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
) -> torch.Tensor:
    """Validate inputs and run the ``sgl_kernel.cutlass_mla_decode`` op.

    Every precondition is checked before any tensor is allocated, so invalid
    input is rejected without doing the head-padding work.

    Args:
        q_nope: ``(B, H, 512)`` tensor, fp16 or bf16, with ``H <= 128``.
        q_pe: ``(B, H, 64)`` tensor with the same dtype as ``q_nope``.
        kv_c_and_k_pe_cache: ``(N, page_size, 576)`` tensor (576 = 512 + 64)
            with the same dtype as ``q_nope``. The leading dim is presumably
            the number of pages (TODO confirm); ``page_size`` must be > 0.
        seq_lens: int32 tensor, forwarded to the kernel unchanged.
        page_table: ``(B, block_num)`` int32 tensor with ``block_num > 0`` and
            ``block_num * page_size`` a multiple of 128.
        workspace: scratch tensor forwarded to the kernel unchanged
            (presumably sized with ``cutlass_mla_get_workspace_size``).
        sm_scale: scale forwarded to the kernel unchanged.
        num_kv_splits: forwarded to the kernel unchanged.

    Returns:
        A contiguous ``(B, H, 512)`` tensor with ``q_nope``'s dtype and device.

    Raises:
        AssertionError: if any shape, dtype or page-table precondition fails.
    """
    # ---- Validation: nothing is allocated until every check has passed. ----
    assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
    assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
    assert (
        kv_c_and_k_pe_cache.ndim == 3
    ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"

    B_q, H, D_q_nope = q_nope.shape
    B_q_2, H_2, D_q_pe = q_pe.shape
    assert (B_q == B_q_2) and (H == H_2), (
        "q_nope and q_pe must agree on batch and head dims, "
        f"got {tuple(q_nope.shape)} and {tuple(q_pe.shape)}"
    )

    _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

    D_latent = 512
    D_rope = 64
    assert D_q_nope == D_latent, f"q_nope last dim must be {D_latent}, got {D_q_nope}"
    assert D_q_pe == D_rope, f"q_pe last dim must be {D_rope}, got {D_q_pe}"
    assert D_ckv == D_latent + D_rope, (
        f"kv_c_and_k_pe_cache last dim must be {D_latent + D_rope}, got {D_ckv}"
    )

    MAX_HEADS = 128
    assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"

    assert page_table.ndim == 2, f"page_table must be a 2D tensor, but got {page_table.ndim}"
    B_block_table, block_num = page_table.shape
    assert B_block_table == B_q, (
        f"page_table batch dim must equal {B_q}, got {B_block_table}"
    )
    assert block_num > 0, f"block num must be greater than 0, got {block_num}"
    assert PAGE_SIZE > 0, f"page size must be greater than 0, got {PAGE_SIZE}"
    # Each page-table row must cover a whole multiple of 128 tokens (128 is
    # presumably the kernel's KV tile length — TODO confirm). Integer math
    # states this exactly; `block_num % (128 / PAGE_SIZE)` relied on float
    # division, which rounds for sizes not dividing 128 and fails on 0.
    assert (block_num * PAGE_SIZE) % 128 == 0, (
        f"block_num * page_size must be a multiple of 128, "
        f"got {block_num} * {PAGE_SIZE}"
    )

    # TODO(kaixih@nvidia): support fp8
    assert q_nope.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
    assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype, (
        "q_nope, q_pe and kv_c_and_k_pe_cache must share a dtype, got "
        f"{q_nope.dtype}, {q_pe.dtype}, {kv_c_and_k_pe_cache.dtype}"
    )
    assert (
        seq_lens.dtype == torch.int32
    ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
    assert (
        page_table.dtype == torch.int32
    ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

    # ---- Launch. ----
    # The op is always invoked with exactly MAX_HEADS heads here (presumably a
    # kernel requirement). Smaller inputs are padded with uninitialized heads,
    # whose outputs are sliced away before returning.
    if H < MAX_HEADS:
        q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
        q_nope_padded[:, :H] = q_nope
        q_nope = q_nope_padded

        q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
        q_pe_padded[:, :H] = q_pe
        q_pe = q_pe_padded

    out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))

    torch.ops.sgl_kernel.cutlass_mla_decode.default(
        out,
        q_nope,
        q_pe,
        kv_c_and_k_pe_cache,
        seq_lens,
        page_table,
        workspace,
        sm_scale,
        num_kv_splits,
    )
    return out[:, :H].contiguous()
|
|
|
|
|
|
def cutlass_mla_get_workspace_size(
    max_seq_len: int,
    num_batches: int,
    sm_count: int = 0,
    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
) -> int:
    """Ask ``sgl_kernel.cutlass_mla_get_workspace_size`` for a workspace size.

    ``max_seq_len`` and ``num_batches`` must be strictly positive; all four
    arguments are then forwarded positionally to the op and its result is
    returned as-is. The value is presumably a byte count for the
    ``workspace`` tensor of ``cutlass_mla_decode`` (TODO confirm).

    Raises:
        AssertionError: if ``max_seq_len`` or ``num_batches`` is not > 0.
    """
    for label, value in (("max_seq_len", max_seq_len), ("num_batches", num_batches)):
        assert value > 0, f"{label} must be greater than 0, got {value}"
    size_op = torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default
    return size_op(max_seq_len, num_batches, sm_count, num_kv_splits)
|