from typing import Optional

import pytest
import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]


def quantize_weights(
    w: torch.Tensor,
    quant_type: str,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert quant_type in ["w4a16", "w4a16b8", "w8a16", "w8a16b128"]
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [group_size, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    if quant_type == "w4a16":
        max_q_val = 15
        min_q_val = 0
    elif quant_type == "w4a16b8":
        max_q_val = 7
        min_q_val = -8
    elif quant_type == "w8a16":
        max_q_val = 255
        min_q_val = 0
    elif quant_type == "w8a16b128":
        max_q_val = 127
        min_q_val = -128

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            w_s = (max_val - min_val).clamp(min=1e-5) / max_q_val
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied; computing the reference the same way for that case
    # lets us use tighter error tolerances in our unit tests.
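    # Concretely (illustrative note, not from the original comment): with
    # quantized value q, scale s, and zero point z, the standard reference
    # below computes (q - z) * s, while the zero-points-after-scales variant
    # computes q * s - z * s. The two are algebraically identical but round
    # differently in low precision, so mirroring the kernel's order of
    # operations keeps the comparison tight.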
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type == "w4a16b8":
        w_q += 8
    elif quant_type == "w8a16b128":
        w_q += 128

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )


def torch_moe(a, w1, w2, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                0, 1
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


# fork from https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_moe.py
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [8])  # [4, 8])
def test_fused_moe_wn16(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    group_size: int,
    has_zp: bool,
    weight_bits: int,
):
    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        pack_factor = 2
        quant_type = "w4a16" if has_zp else "w4a16b8"
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = "w8a16" if has_zp else "w8a16b128"

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty(
        (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
    )
    w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
    w1_qzeros = torch.empty(
        (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
    )
    w2_qzeros = torch.empty(
        (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
    )

    for i in range(e * 2):
        expert_id = i % e
        if i // e == 0:
            w, w_ref, w_qweight, w_scales, w_qzeros = (
                w1,
                w1_ref,
                w1_qweight,
                w1_scales,
                w1_qzeros,
            )
        else:
            w, w_ref, w_qweight, w_scales, w_qzeros = (
                w2,
                w2_ref,
                w2_qweight,
                w2_scales,
                w2_qzeros,
            )
        weight, qweight, scales, qzeros = quantize_weights(
            w[expert_id].T, quant_type, group_size, has_zp, False
        )
        weight = weight.T
        qweight = qweight.T.contiguous().to(torch.uint8)
        scales = scales.T
        if has_zp:
            qzeros = qzeros.T.contiguous().to(torch.uint8)
        if weight_bits == 4:
            # Pack two 4-bit values per byte: the odd column goes into the
            # high nibble, the even column into the low nibble.
            qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
            if has_zp:
                qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

        w_ref[expert_id] = weight
        w_qweight[expert_id] = qweight
        w_scales[expert_id] = scales
        if has_zp:
            w_qzeros[expert_id] = qzeros

    triton_output = fused_moe(
        a,
        w1_qweight,
        w2_qweight,
        score,
        topk,
        renormalize=False,
        use_int4_w4a16=weight_bits == 4,
        use_int8_w8a16=weight_bits == 8,
        w1_scale=w1_scales,
        w2_scale=w2_scales,
        w1_zp=w1_qzeros if has_zp else None,
        w2_zp=w2_qzeros if has_zp else None,
        block_shape=[0, group_size],
    )
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
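

# The sketch below is not part of the original test forked from vllm; it is a
# small, hedged sanity check of the quantize_weights() helper defined above.
# It runs on CPU (no CUDA needed) and only asserts that group-wise
# dequantization stays close to the input; the tolerances are illustrative,
# not derived from the kernel.
@pytest.mark.parametrize("quant_type", ["w8a16", "w8a16b128"])
def test_quantize_weights_roundtrip(quant_type: str):
    torch.manual_seed(0)
    w = torch.randn((256, 128), dtype=torch.float32) / 10
    group_size = 64
    has_zp = quant_type == "w8a16"

    w_ref, w_q, w_s, w_zp = quantize_weights(w, quant_type, group_size, has_zp)

    # Shapes: quantized weights match the input; scales (and zero points, if
    # any) have one row per group along the k dimension.
    assert w_q.shape == w.shape
    assert w_s.shape == (w.shape[0] // group_size, w.shape[1])
    if has_zp:
        assert w_zp.shape == w_s.shape
    else:
        assert w_zp is None

    # 8-bit group quantization should reproduce the weights closely; the bound
    # here is deliberately loose.
    torch.testing.assert_close(w_ref, w, atol=5e-2, rtol=0)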