sglang_v0.5.2/sglang/sgl-kernel/tests/test_dsv3_router_gemm.py

36 lines
1.1 KiB
Python

import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_router_gemm
@pytest.mark.parametrize("num_tokens", list(range(1, 17)))
@pytest.mark.parametrize("num_experts", [256, 384])
def test_dsv3_router_gemm(num_tokens, num_experts):
    """Compare ``dsv3_router_gemm`` with ``torch.nn.functional.linear``.

    Two random bfloat16 CUDA matrices are built, ``(num_tokens, 7168)`` and
    ``(num_experts, 7168)``, and the kernel is invoked once per requested
    output dtype (bfloat16, then float32). The bfloat16 result is checked
    against ``F.linear`` directly; the float32 result is checked against that
    same bfloat16 reference upcast to float32, so the float32 comparison
    inherits the reference's bfloat16 rounding.

    Args:
        num_tokens: Row count of the first operand (parametrized 1 through 16).
        num_experts: Row count of the second operand (parametrized 256 and 384).
    """
    hidden_dim = 7168
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    tolerances = {"rtol": 1e-2, "atol": 1e-3}

    # Draw the token matrix before the expert matrix so the sequence of RNG
    # calls stays the same from run to run under a fixed seed.
    tokens = torch.randn((num_tokens, hidden_dim), **tensor_opts).contiguous()
    experts = torch.randn((num_experts, hidden_dim), **tensor_opts).contiguous()

    # Reference results: F.linear in bfloat16, plus its float32 upcast.
    expected_bf16 = F.linear(tokens, experts)
    expected_fp32 = expected_bf16.to(torch.float32)

    # Run the kernel for both output dtypes before asserting anything, so both
    # code paths execute even when the first comparison fails.
    actual_bf16 = dsv3_router_gemm(tokens, experts, out_dtype=torch.bfloat16)
    actual_fp32 = dsv3_router_gemm(tokens, experts, out_dtype=torch.float32)

    bf16_matches = torch.allclose(actual_bf16, expected_bf16, **tolerances)
    assert (
        bf16_matches
    ), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference"

    fp32_matches = torch.allclose(actual_fp32, expected_fp32, **tolerances)
    assert (
        fp32_matches
    ), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference"
if __name__ == "__main__":
    # Running the file as a script hands this one file to pytest, so the
    # parametrized cases above execute without a separate pytest command.
    pytest.main(args=[__file__])