# sglang.0.4.8.post1/sglang/test/srt/cpu/test_shared_expert.py

import itertools
import math
import unittest

# TODO: use interface in cpu.py
import sgl_kernel
import torch
import torch.nn as nn
from utils import (
    BLOCK_K,
    BLOCK_N,
    SiluAndMul,
    factor_for_scale,
    fp8_max,
    fp8_min,
    per_token_quant_int8,
    precision,
    scaled_weight,
    torch_naive_moe,
    torch_w8a8_per_column_moe,
)

from sglang.test.test_utils import CustomTestCase

torch.manual_seed(1234)
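

# Sanity tests for the fused shared-expert CPU kernel: each case builds a
# reference result in plain PyTorch and compares it against
# torch.ops.sgl_kernel.shared_expert_cpu for bf16, int8 (per-channel w8a8),
# and fp8 (block-wise w8a16) weights.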
class TestSharedExpert(CustomTestCase):
    # Shape grids for the bf16 and int8 cases: M tokens, intermediate size N,
    # hidden size K.
    M = [2, 121]
    N = [32, 32 * 4]
    K = [32, 32 * 2]
    routed_scaling_factor = [16]

    # Separate grids for fp8; N and K must be divisible by the quantization
    # block sizes (BLOCK_N, BLOCK_K).
    M_fp8 = [2, 12]
    N_fp8 = [512]
    K_fp8 = [256]

    def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
        dtype = torch.bfloat16
        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k

        # The fused kernel runs in place and mutates its inputs, so compute
        # the reference from fp32 copies first and cast back to bf16.
        ref = torch_naive_moe(
            hidden_states.float(),
            w1.float(),
            w2.float(),
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
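        # Positional arguments of shared_expert_cpu after the scaling factor
        # (flag names inferred from the three call sites in this file, not
        # confirmed against the kernel source): inplace, use_int8_w8a8,
        # use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale,
        # is_vnni.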
        res = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            w1,
            w2,
            fused_output,
            routed_scaling_factor,
            True,  # inplace
            False,  # use_int8_w8a8
            False,  # use_fp8_w8a16
            None,  # w1_scale
            None,  # w2_scale
            None,  # block_size
            None,  # a1_scale
            None,  # a2_scale
            False,  # is_vnni
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_bf16_shared_expert(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                routed_scaling_factor=params[3],
            ):
                self._bf16_shared_expert(*params)

    def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
        dtype = torch.bfloat16
        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k

        # Quantize the weights to int8 with one scale per row (per output
        # channel).
        w1_q, w1_s = per_token_quant_int8(w1)
        w2_q, w2_s = per_token_quant_int8(w2)
        # Reference: w8a8 per-channel MoE computed in fp32, cast back to bf16
        # before the kernel mutates its inputs in place.
        ref = torch_w8a8_per_column_moe(
            hidden_states.float(),
            w1_q,
            w2_q,
            w1_s,
            w2_s,
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
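        # Same positional layout as the bf16 call above, with use_int8_w8a8
        # set and the per-row weight scales passed as w1_scale/w2_scale.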
        res = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            w1_q,
            w2_q,
            fused_output,
            routed_scaling_factor,
            True,  # inplace
            True,  # use_int8_w8a8
            False,  # use_fp8_w8a16
            w1_s,  # w1_scale
            w2_s,  # w2_scale
            None,  # block_size
            None,  # a1_scale
            None,  # a2_scale
            False,  # is_vnni
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_int8_shared_expert(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                routed_scaling_factor=params[3],
            ):
                self._int8_shared_expert(*params)

    def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
        dtype = torch.bfloat16
        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        # Random fp8 weights: scale up by fp8_max and clamp into the e4m3
        # range. The leading singleton dim mirrors the expert dimension of
        # MoE weight tensors.
        w1_fp32 = torch.randn(1, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        w2_fp32 = torch.randn(1, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        # Block-wise weight scales: one scale per (BLOCK_N, BLOCK_K) tile.
        w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
        w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
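        # scaled_weight (from utils) dequantizes the fp8 blocks back to fp32
        # by broadcasting each block scale over its tile; these copies feed
        # the fp32 reference path.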
        w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
        w2_scaled = scaled_weight(w2, w2s).view(K, N)

        # Drop the expert dimension; the shared expert is a single expert.
        w1, w2 = w1.squeeze(0), w2.squeeze(0)
        w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)
        fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

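        # Reference in fp32: a SiLU-and-mul MLP through the dequantized
        # weights, plus the routed-experts output scaled by
        # routed_scaling_factor.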
        ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
        ic1 = SiluAndMul(ic0)
        shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1))
        ref_out = shared_out + fused_out.float() * routed_scaling_factor
        ref_out = ref_out.to(dtype=dtype)
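
        # convert_weight_packed prepacks the fp8 weights into the kernel's
        # VNNI layout; the final is_vnni=True argument tells shared_expert_cpu
        # the weights arrive prepacked (flag name assumed, as above).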
        w1 = torch.ops.sgl_kernel.convert_weight_packed(w1)  # [2N, K]
        w2 = torch.ops.sgl_kernel.convert_weight_packed(w2)  # [K, N]
        out = torch.ops.sgl_kernel.shared_expert_cpu(
            a,
            w1,
            w2,
            fused_out,
            routed_scaling_factor,
            True,  # inplace
            False,  # use_int8_w8a8
            True,  # use_fp8_w8a16
            w1s,  # w1_scale
            w2s,  # w2_scale
            [BLOCK_N, BLOCK_K],  # block_size
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_fp8_shared_expert(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                routed_scaling_factor=params[3],
            ):
                self._fp8_shared_expert(*params)


if __name__ == "__main__":
    unittest.main()