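"""Unit test for the Triton fused MoE kernel with int4/int8 weight-only (wN16)
quantization, compared against a plain PyTorch float reference. Only the 8-bit
path is currently enabled in the parametrization (weight_bits=[8])."""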
from typing import Optional

import pytest
import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]


def quantize_weights(
    w: torch.Tensor,
    quant_type: str,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert quant_type in ["w4a16", "w4a16b8", "w8a16", "w8a16b128"]
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

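    # Integer ranges per format: "w4a16"/"w8a16" are unsigned ranges (used with
    # zero-points), while "w4a16b8"/"w8a16b128" are signed ranges that are later
    # shifted by a bias of 8/128 so they can be stored as unsigned ints.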
    if quant_type == "w4a16":
        max_q_val = 15
        min_q_val = 0
    elif quant_type == "w4a16b8":
        max_q_val = 7
        min_q_val = -8
    elif quant_type == "w8a16":
        max_q_val = 255
        min_q_val = 0
    elif quant_type == "w8a16b128":
        max_q_val = 127
        min_q_val = -128

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            w_s = (max_val - min_val).clamp(min=1e-5) / max_q_val
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in a similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

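    # Store the biased (signed) formats as unsigned ints:
    # [-8, 7] -> [0, 15] and [-128, 127] -> [0, 255].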
    if quant_type == "w4a16b8":
        w_q += 8
    elif quant_type == "w8a16b128":
        w_q += 128

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
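

# Illustrative usage of quantize_weights (not exercised directly by the test;
# the shapes and values below are made up for demonstration):
#
#   w = torch.randn(1024, 512, dtype=torch.half, device="cuda")
#   w_ref, w_q, w_s, w_zp = quantize_weights(
#       w, "w8a16", group_size=128, zero_points=True
#   )
#   # w_ref is the dequantized reference, i.e. roughly (w_q - w_zp) * w_s with
#   # the scale/zero-point broadcast over each group of 128 rows of w.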
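# Float reference MoE: softmax routing scores, top-k expert selection, and a
# SiluAndMul gated MLP per expert, combined with the top-k routing weights.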
def torch_moe(a, w1, w2, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                0, 1
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


# fork from https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_moe.py
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [8])  # [4, 8])
def test_fused_moe_wn16(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    group_size: int,
    has_zp: bool,
    weight_bits: int,
):
    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        pack_factor = 2
        quant_type = "w4a16" if has_zp else "w4a16b8"
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = "w8a16" if has_zp else "w8a16b128"

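    # Per-expert quantized buffers: qweights are packed along the input
    # dimension (pack_factor values per uint8 byte), and scales/zero-points
    # have one entry per group_size input features (zero-points are
    # additionally packed along the output dimension for the 4-bit case).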
    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty(
        (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
    )
    w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
    w1_qzeros = torch.empty(
        (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
    )
    w2_qzeros = torch.empty(
        (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
    )

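    # Quantize every expert weight: the first e iterations handle w1, the next
    # e handle w2. w1_ref/w2_ref are overwritten with the dequantized weights so
    # the torch reference sees exactly what the kernel will dequantize.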
    for i in range(e * 2):
        expert_id = i % e
        if i // e == 0:
            w, w_ref, w_qweight, w_scales, w_qzeros = (
                w1,
                w1_ref,
                w1_qweight,
                w1_scales,
                w1_qzeros,
            )
        else:
            w, w_ref, w_qweight, w_scales, w_qzeros = (
                w2,
                w2_ref,
                w2_qweight,
                w2_scales,
                w2_qzeros,
            )
        weight, qweight, scales, qzeros = quantize_weights(
            w[expert_id].T, quant_type, group_size, has_zp, False
        )
        weight = weight.T
        qweight = qweight.T.contiguous().to(torch.uint8)
        scales = scales.T
        if has_zp:
            qzeros = qzeros.T.contiguous().to(torch.uint8)
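        # Pack two 4-bit values per byte: the even column goes into the low
        # nibble and the odd column into the high nibble.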
        if weight_bits == 4:
            qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
            if has_zp:
                qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

        w_ref[expert_id] = weight
        w_qweight[expert_id] = qweight
        w_scales[expert_id] = scales
        if has_zp:
            w_qzeros[expert_id] = qzeros

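    # Run the Triton fused MoE kernel on the quantized weights. The
    # block_shape=[0, group_size] argument appears to carry the per-group
    # quantization granularity for the wN16 path (an assumption based on this
    # call site, not verified against the kernel documentation).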
    triton_output = fused_moe(
        a,
        w1_qweight,
        w2_qweight,
        score,
        topk,
        renormalize=False,
        use_int4_w4a16=weight_bits == 4,
        use_int8_w8a16=weight_bits == 8,
        w1_scale=w1_scales,
        w2_scale=w2_scales,
        w1_zp=w1_qzeros if has_zp else None,
        w2_zp=w2_qzeros if has_zp else None,
        block_shape=[0, group_size],
    )
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
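

# To run only this test (illustrative invocation; a CUDA device is required
# since all tensors are allocated on "cuda"):
#
#   pytest -q -k test_fused_moe_wn16 <path to this test file>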