# sglang.0.4.8.post1/sglang/sgl-kernel/tests/test_merge_state.py

# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/triton/kernels/cascade.py
from typing import List

import pytest
import torch
import triton
import triton.language as tl

from sgl_kernel import merge_state


def check_input(x: torch.Tensor):
    assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
    assert x.is_contiguous(), f"{str(x)} must be contiguous"


def check_dim(d, x: torch.Tensor):
    assert x.dim() == d, f"{str(x)} must be a {d}D tensor"


def check_shape(a: torch.Tensor, b: torch.Tensor):
    assert a.dim() == b.dim(), "tensors should have same dim"
    for i in range(a.dim()):
        assert a.size(i) == b.size(
            i
        ), f"tensors shape mismatch, {a.size()} and {b.size()}"


def check_device(tensors: List[torch.Tensor]):
    device = tensors[0].device
    for t in tensors:
        assert (
            t.device == device
        ), f"All tensors should be on the same device, but got {device} and {t.device}"


@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    m_max = tl.maximum(m, other_m)
    d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
    o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
    return o, m_max, d


@triton.jit
def state_normalize(o, m, d):
    o = o / d
    return o, m, d


@triton.jit
def state_get_lse(o, m, d):
    return m + tl.log2(d)
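

# The three @triton.jit helpers above implement the usual log-sum-exp merge of
# two attention partial states (o, m, d): each state is rescaled by
# 2^(s - m_max) with m_max = max(s_a, s_b), the rescaled denominators are
# summed, and the merged LSE is recovered as m_max + log2(d). The
# merge_state_kernel below applies exactly this per (position, head) pair.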


@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
            s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)
            offsets = (pos * num_heads + head_idx) * head_dim + tx
            v_a = tl.load(v_a_ptr + offsets)
            v_b = tl.load(v_b_ptr + offsets)
            v_merged, s_max, d = state_merge(
                o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
            )
            v_merged, s_max, d = state_normalize(v_merged, s_max, d)
            v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
            tl.store(v_merged_ptr + v_merged_offset, v_merged)
            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx,
                    tl.log2(d) + s_max,
                )
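
# Note: the kernel launches one program per token position and iterates over
# head_dim x num_heads serially inside that program, so it serves as a simple
# correctness reference rather than a performance-tuned kernel.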


def merge_state_triton(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    assert v_a.size(1) == s_b.size(1)
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len = v_a.size(0)
    num_heads = v_a.size(1)
    head_dim = v_a.size(2)
    v_merged = torch.empty_like(v_a).to(s_a.device)
    s_merged = torch.empty((seq_len, num_heads)).to(s_a.device)
    bdx = head_dim
    bdy = num_heads
    merge_state_kernel[lambda meta: (seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )
    return v_merged, s_merged
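

# Illustrative pure-PyTorch version of the same merge, kept here only as a
# sketch to make the math above easy to cross-check; the helper name
# merge_state_torch is our own and is not part of sgl_kernel or the original
# test.
def merge_state_torch(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # Rescale both states by 2^(s - s_max) so the sums stay numerically safe.
    s_max = torch.maximum(s_a, s_b)
    w_a = torch.exp2(s_a - s_max)  # [seq_len, num_heads]
    w_b = torch.exp2(s_b - s_max)
    d = w_a + w_b
    v = v_a.float() * w_a.unsqueeze(-1) + v_b.float() * w_b.unsqueeze(-1)
    v = v / d.unsqueeze(-1)
    s = s_max + torch.log2(d)  # merged log-sum-exp, base 2 as in the kernel
    return v.to(v_a.dtype), s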
@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_merge_state(seq_len, num_heads, head_dim):
va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
v_merged, s_merged = merge_state_triton(va, sa, vb, sb)
v_merged_std, s_merged_std = merge_state(va, sa, vb, sb)
assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
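    # Optional extra cross-check against the pure-PyTorch sketch above (our
    # own addition, left commented out so the original assertions are the only
    # ones that run):
    # v_ref, s_ref = merge_state_torch(va, sa, vb, sb)
    # assert torch.allclose(v_merged_std, v_ref, atol=1e-2)
    # assert torch.allclose(s_merged_std, s_ref, atol=1e-2)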


if __name__ == "__main__":
    pytest.main([__file__])