import itertools
import unittest

import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.test.test_utils import CustomTestCase


# For test
def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8):
    """Function to perform per-token-group quantization on an input tensor `x` using native torch.

    It converts the tensor values into int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    Note that only `torch.int8` is supported for now.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    # Per-group scale: s = max(|x|) / 127, so that x / s fits into [int8_min, int8_max].
    amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_ / x_s).clamp(min=int8_min, max=int8_max).to(dtype)
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))

    return x_q, x_s
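

# Illustrative sketch, not exercised by the test suite: round-trip a small tensor
# through the reference quantizer and dequantize it again. The helper name, shapes,
# and the group size of 128 are arbitrary example values.
def _example_int8_group_quant_roundtrip(group_size=128):
    x = torch.randn(4, 2 * group_size, dtype=torch.float32)
    x_q, x_s = native_per_token_group_quant_int8(x, group_size)
    # Each group of `group_size` values shares one float32 scale, so dequantization
    # is a per-group multiply.
    x_dq = (
        x_q.to(torch.float32).reshape(-1, group_size) * x_s.reshape(-1, 1)
    ).reshape(x.shape)
    return torch.mean(torch.abs(x - x_dq))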


# For test
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N,)
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
    B_tiles = [
        [
            B[
                j * block_n : min((j + 1) * block_n, N),
                i * block_k : min((i + 1) * block_k, K),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]
    C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
    As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs[j][i]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C
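

# Illustrative sketch, not exercised by the test suite: a W8A8 matmul with a
# per-token-group quantized activation and a block-quantized int8 weight.
# The helper name, shapes, and the (128, 128) block size are arbitrary example values.
def _example_w8a8_block_int8_matmul(block_n=128, block_k=128):
    a = torch.randn(4, 2 * block_k, dtype=torch.float32)
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    w = torch.randint(-128, 128, (4 * block_n, 2 * block_k), dtype=torch.int8)
    # One float32 scale per (block_n, block_k) weight tile: 4 row tiles x 2 column tiles.
    w_s = torch.rand(4, 2, dtype=torch.float32) * 1e-2
    return native_w8a8_block_int8_matmul(a_q, w, a_s, w_s, [block_n, block_k])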


# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using native torch."""

    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_int8_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
            )
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
            act_out = act_out.to(torch.float32)
            out[mask] = native_w8a8_block_int8_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
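

# Shape conventions used by torch_w8a8_block_int8_moe, read off the call sites in the
# test below (a summary for orientation, not a spec):
#   a:    (M, K) activations; w1: (E, 2N, K) gate/up projection; w2: (E, K, N) down projection.
#   w1_s: (E, ceil(2N / block_n), ceil(K / block_k)) per-block weight scales, likewise w2_s.
# Each token is routed to its top-k experts, pushed through w1 -> SiluAndMul -> w2 with
# per-token-group activation quantization between the two matmuls, and the expert outputs
# are combined with the softmax top-k weights.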


class TestW8A8BlockINT8FusedMoE(CustomTestCase):
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 33, 64, 222]
    N = [128, 1024]
    K = [256, 4096]
    E = [8, 24]
    TOP_KS = [2, 6]
    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
        torch.manual_seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        int8_info = torch.iinfo(torch.int8)
        int8_max, int8_min = int8_info.max, int8_info.min

        a = torch.randn((M, K), dtype=dtype) / 10

        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
        w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
        w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k

        w1_s = (
            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
            * factor_for_scale
        )
        w2_s = (
            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
            * factor_for_scale
        )

        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            out = fused_moe(
                a,
                w1,
                w2,
                score,
                topk,
                renormalize=False,
                use_int8_w8a8=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=block_size,
            )
            ref_out = torch_w8a8_block_int8_moe(
                a, w1, w2, w1_s, w2_s, score, topk, block_size
            )

        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.02
        )

    def test_w8a8_block_int8_fused_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
            ):
                self._w8a8_block_int8_fused_moe(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)