# sglang_v0.5.2/sglang/test/srt/cpu/test_moe.py

import itertools
import math
import unittest

# TODO: use interface in cpu.py
import sgl_kernel
import torch
from utils import (
    BLOCK_K,
    BLOCK_N,
    factor_for_scale,
    fp8_max,
    fp8_min,
    native_fp8_fused_moe,
    precision,
    scaled_weight,
    torch_naive_fused_moe,
    torch_w8a8_per_column_fused_moe,
)

from sglang.test.test_utils import CustomTestCase

kernel = torch.ops.sgl_kernel
torch.manual_seed(1234)
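
# This file exercises three sgl_kernel CPU entry points:
#   - torch.ops.sgl_kernel.grouped_topk_cpu: top-k expert routing
#   - torch.ops.sgl_kernel.convert_weight_packed: optional weight pre-packing
#   - torch.ops.sgl_kernel.fused_experts_cpu: the fused MoE expert kernel
# Each precision path (bf16, int8 W8A8, fp8) is compared against a
# pure-torch reference implementation imported from utils.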


def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
    # Route each token to `topk` experts (a single expert group here), then
    # run the fused expert kernel, optionally on pre-packed weights.
    G = 1
    topk_group = 1
    topk_weights, topk_ids = kernel.grouped_topk_cpu(
        a, score, topk, renormalize, G, topk_group, 0, None, None
    )
    packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
    packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
    inplace = True
    return kernel.fused_experts_cpu(
        a,
        packed_w1,
        packed_w2,
        topk_weights,
        topk_ids,
        inplace,
        False,  # int8 w8a8 path disabled (enabled in the int8 test below)
        False,  # fp8 path disabled (enabled in the fp8 test below)
        None,  # no w1 scales
        None,  # no w2 scales
        None,  # no quantization block size
        None,
        None,
        prepack,
    )
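

# Example invocation of the helper above, using the bf16 test shapes below
# (M=2 tokens, E=4 experts, K=32 hidden size, N=32 intermediate size; w1 has
# 2 * N rows, i.e. gate and up projections concatenated -- an assumption
# based on the shapes used in this file):
#   a = torch.randn(2, 32, dtype=torch.bfloat16)
#   w1 = torch.randn(4, 64, 32, dtype=torch.bfloat16)  # [E, 2N, K]
#   w2 = torch.randn(4, 32, 32, dtype=torch.bfloat16)  # [E, K, N]
#   score = torch.randn(2, 4, dtype=torch.bfloat16)    # [M, E]
#   out = fused_moe(a, w1, w2, score, topk=2, renormalize=True, prepack=True)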
class TestFusedExperts(CustomTestCase):
    # Shape grids: M tokens, N intermediate size, K hidden size, E experts.
    M = [2, 114]
    N = [32]
    K = [32]
    E = [4]
    topk = [2]
    renormalize = [False, True]

    M_int8 = [1, 39]
    N_int8 = [128]
    K_int8 = [256]
    E_int8 = [8]
    topk_int8 = [3]

    M_fp8 = [2, 121]
    N_fp8 = [352, 512]
    K_fp8 = [256, 320]
    E_fp8 = [8]
    topk_fp8 = [4]

    def _bf16_moe(self, m, n, k, e, topk, renormalize):
        dtype = torch.bfloat16
        prepack = True

        a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
        score = torch.randn((m, e), device="cpu", dtype=dtype)

        torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
        fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)

        atol = rtol = precision[torch_output.dtype]
        torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol)

    def test_bf16_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.topk,
            self.renormalize,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                e=params[3],
                topk=params[4],
                renormalize=params[5],
            ):
                self._bf16_moe(*params)

    def _int8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16
        prepack = True

        # int8 quantization parameters
        int8_factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Input activations, M x K
        a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)

        # Generate int8 weights
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # One scale per output column (per-column quantization)
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale

        # Compute routing with plain softmax + top-k
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        ref_out = torch_w8a8_per_column_fused_moe(
            a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
        )

        inplace = True
        packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
        packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
        out = kernel.fused_experts_cpu(
            a,
            packed_w1,
            packed_w2,
            topk_weight,
            topk_ids.to(torch.int32),
            inplace,
            True,  # int8 w8a8 path enabled
            False,  # fp8 path disabled
            w1_s,
            w2_s,
            None,
            None,
            None,
            prepack,
        )

        atol = rtol = precision[ref_out.dtype]
        # Increase the tolerance for large input shapes
        if M > 35:
            atol = rtol = 0.02
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
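
    # Per-column W8A8 dequantization, as implied by the tensor shapes above:
    # w1_s has shape [E, 2N] (one fp32 scale per output column), so per
    # expert e the reference effectively computes
    #   w1_fp32[e] ≈ w1[e].float() * w1_s[e][:, None]  # broadcast over K
    #   w2_fp32[e] ≈ w2[e].float() * w2_s[e][:, None]  # broadcast over N
    # The exact math lives in torch_w8a8_per_column_fused_moe in utils.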

    def test_int8_moe(self):
        for params in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.E_int8,
            self.topk_int8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._int8_moe(*params)

    def _fp8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16
        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        # Generate fp8 (e4m3) weights
        w1_fp32 = torch.randn(E, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        w2_fp32 = torch.randn(E, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # Block-wise scales: one scale per BLOCK_N x BLOCK_K weight tile
        w1s = (
            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
            * factor_for_scale
        )
        w2s = (
            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
            * factor_for_scale
        )

        # Dequantized weights for the reference path
        w1_scaled = scaled_weight(w1, w1s)
        w2_scaled = scaled_weight(w2, w2s)

        # Routing via softmax + top-k
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        w1 = kernel.convert_weight_packed(w1)
        w2 = kernel.convert_weight_packed(w2)

        ref_out = native_fp8_fused_moe(
            a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk
        )
        out = kernel.fused_experts_cpu(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids.to(torch.int32),
            False,  # not inplace
            False,  # int8 w8a8 path disabled
            True,  # fp8 path enabled
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,
            None,
            True,  # weights are pre-packed
        )

        atol = rtol = precision[dtype]
        torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol)
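
    # Note on the fp8 scale layout above: w1s has shape
    # [E, ceil(2N / BLOCK_N), ceil(K / BLOCK_K)], so scale [e, i, j] applies
    # to the tile w1[e, i*BLOCK_N:(i+1)*BLOCK_N, j*BLOCK_K:(j+1)*BLOCK_K].
    # scaled_weight() presumably expands this scale grid back to the full
    # weight shape for the reference computation (an assumption; the exact
    # behavior is defined in utils).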

    def test_fp8_moe(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.E_fp8,
            self.topk_fp8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._fp8_moe(*params)


if __name__ == "__main__":
    unittest.main()