# Source: sglang0.4.5.post1/test/srt/test_fp8_kernel.py
#
# 129 lines
# 4.0 KiB
# Python

import unittest
import torch
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from sglang.test.test_utils import CustomTestCase
class TestFP8Base(CustomTestCase):
    """Shared sizes and ground-truth generators for the FP8 kernel tests.

    Subclasses use ``_make_A`` / ``_make_B`` to synthesize matrices whose
    quantized values and scales are known exactly. Kernel outputs can then be
    compared against them.
    """

    @classmethod
    def setUpClass(cls):
        # N is deliberately not a multiple of group_size so that the
        # non-aligned path is exercised.
        cls.M, cls.N, cls.K = 256, 1024 + 64, 512
        cls.group_size = 128
        cls.quant_type = torch.float8_e4m3fn
        cls.output_type = torch.bfloat16

    @staticmethod
    def _make_A(M, K, group_size, out_dtype):
        """Build an (M, K) matrix with a known per-token-group quantization.

        The last axis is split into groups of ``group_size`` elements; the
        reshape back to (M, K) requires ``K % group_size == 0``.

        Returns:
            (A, quant_A, scale): ``A`` is the float32 dequantized matrix
            (M, K), ``quant_A`` is its (M, K) quantized form in ``out_dtype``,
            and ``scale`` is the float32 per-group scale (M, K // group_size).
        """
        n_groups = K // group_size
        fp_max = torch.finfo(out_dtype).max

        # Uniform values in [-1, 1), laid out as (M, group, element-in-group).
        grouped = (
            torch.rand(M, n_groups, group_size, dtype=torch.float32, device="cuda")
            .mul(2)
            .sub(1)
        )
        # Stretch every group so its abs-max reaches the largest finite value of
        # out_dtype, then round-trip through out_dtype so each value is exactly
        # representable in the quantized format.
        grouped = grouped * (fp_max / grouped.abs().amax(-1, keepdim=True))
        grouped = grouped.to(out_dtype).to(torch.float32)

        group_scale = (
            torch.rand(M, n_groups, dtype=torch.float32, device="cuda") / fp_max
        )
        dequant = (grouped * group_scale.unsqueeze(-1)).reshape(M, K)
        return dequant, grouped.reshape(M, K).to(out_dtype), group_scale

    @staticmethod
    def _make_B(K, N, group_size, out_dtype):
        """Build a (K, N) matrix with a known 2-D block quantization.

        Blocks are ``group_size x group_size``. K and N are padded up to a
        multiple of ``group_size`` while generating the data. The returned
        matrices are sliced back to (K, N), but the scale keeps the padded
        block grid of shape (ceil(K / group_size), ceil(N / group_size)).

        Returns:
            (B, quant_B, scale): float32 dequantized (K, N) matrix, its (K, N)
            quantized form in ``out_dtype``, and the float32 block scales.
        """
        k_blocks = (K + group_size - 1) // group_size
        n_blocks = (N + group_size - 1) // group_size
        k_pad, n_pad = k_blocks * group_size, n_blocks * group_size
        fp_max = torch.finfo(out_dtype).max

        # Axes: (k_block, row-in-block, n_block, col-in-block); dims 1 and 3
        # index inside one block.
        blocks = (
            torch.rand(
                k_blocks,
                group_size,
                n_blocks,
                group_size,
                dtype=torch.float32,
                device="cuda",
            )
            .mul(2)
            .sub(1)
        )
        # Per-block abs-max scaled to the dtype's max, then made exactly
        # representable in out_dtype.
        blocks = blocks * (fp_max / blocks.abs().amax((1, 3), keepdim=True))
        blocks = blocks.to(out_dtype).to(torch.float32)

        block_scale = (
            torch.rand(
                k_blocks, 1, n_blocks, 1, dtype=torch.float32, device="cuda"
            )
            / fp_max
        )
        dequant = (blocks * block_scale).reshape(k_pad, n_pad)[:K, :N]
        quantized = blocks.reshape(k_pad, n_pad).to(out_dtype)[:K, :N]
        return dequant, quantized, block_scale.reshape(k_blocks, n_blocks)
class TestPerTokenGroupQuantFP8(TestFP8Base):
    """Check ``per_token_group_quant_fp8`` against a synthetic ground truth."""

    def test_per_token_group_quant_fp8(self):
        # Report unsupported environments as skipped rather than silently
        # passing (the original test simply returned).
        if not torch.cuda.is_available():
            self.skipTest("requires CUDA")
        if torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires CUDA compute capability major version >= 9")

        A, A_quant_gt, scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        A_quant, scale = per_token_group_quant_fp8(
            x=A, group_size=self.group_size, dtype=self.quant_type
        )
        torch.testing.assert_close(scale, scale_gt)

        # Compare the quantized values in float16. The cast presumably exists
        # because FP8 tensors lack arithmetic support. A tiny fraction of
        # mismatches is tolerated.
        diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
        mismatch_ratio = (diff > 1e-5).count_nonzero().item() / diff.numel()
        self.assertLess(mismatch_ratio, 1e-4)
class TestW8A8BlockFP8Matmul(TestFP8Base):
    """Check ``w8a8_block_fp8_matmul`` against a dequantized reference matmul."""

    def test_w8a8_block_fp8_matmul(self):
        # Report unsupported environments as skipped rather than silently
        # passing (the original test simply returned).
        if not torch.cuda.is_available():
            self.skipTest("requires CUDA")
        if torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires CUDA compute capability major version >= 9")

        A, A_quant_gt, A_scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        B, B_quant_gt, B_scale_gt = self._make_B(
            K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type
        )
        # Reference result: ordinary matmul of the dequantized inputs in the
        # output dtype.
        C_gt = A.to(self.output_type) @ B.to(self.output_type)
        C = w8a8_block_fp8_matmul(
            A=A_quant_gt,
            # B_quant_gt is (K, N); the kernel is given its (N, K) transpose.
            B=B_quant_gt.T.contiguous(),
            As=A_scale_gt,
            Bs=B_scale_gt.T.contiguous(),
            # The quantization block must match how _make_A/_make_B grouped
            # the data, so derive it from group_size instead of hard-coding it.
            block_size=[self.group_size, self.group_size],
            output_dtype=self.output_type,
        )
        torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4)
if __name__ == "__main__":
    # Allow running this file directly; discovers tests in this module.
    unittest.main(module="__main__")