import unittest

import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_fp8,
    w8a8_block_fp8_matmul,
)
from sglang.test.test_utils import CustomTestCase


class TestFP8Base(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.M = 256
        # test non-aligned
        cls.N = 1024 + 64
        cls.K = 512
        cls.group_size = 128
        cls.quant_type = torch.float8_e4m3fn
        cls.output_type = torch.bfloat16

    @staticmethod
    def _make_A(M, K, group_size, out_dtype):
        """Build a float32 activation matrix A plus its ground-truth
        per-token-group FP8 quantization (quant_A) and scales."""
        quant_A = torch.rand(
            M, K // group_size, group_size, dtype=torch.float32, device="cuda"
        )
        # map from [0, 1] to [-1, 1]
        quant_A = quant_A * 2 - 1

        # scaling abs max to fmax
        finfo = torch.finfo(out_dtype)
        fmax = finfo.max
        scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
        quant_A *= scaling
        quant_A = quant_A.to(out_dtype).to(torch.float32)

        # create scale and A
        scale = torch.rand(M, K // group_size, dtype=torch.float32, device="cuda")
        scale /= fmax
        A = quant_A * scale[..., None]

        A = A.reshape(M, K)
        quant_A = quant_A.reshape(M, K).to(out_dtype)
        return A, quant_A, scale

    @staticmethod
    def _make_B(K, N, group_size, out_dtype):
        """Build a float32 weight matrix B plus its ground-truth block-wise
        FP8 quantization (quant_B) and per-block scales."""

        def _aligned_size(a, b):
            return (a + b - 1) // b * b

        K_aligned = _aligned_size(K, group_size)
        N_aligned = _aligned_size(N, group_size)

        quant_B = torch.rand(
            K_aligned // group_size,
            group_size,
            N_aligned // group_size,
            group_size,
            dtype=torch.float32,
            device="cuda",
        )
        # map from [0, 1] to [-1, 1]
        quant_B = quant_B * 2 - 1

        # scaling abs max to fmax
        finfo = torch.finfo(out_dtype)
        fmax = finfo.max
        scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True)
        quant_B *= scaling
        quant_B = quant_B.to(out_dtype).to(torch.float32)

        scale = torch.rand(
            K_aligned // group_size,
            1,
            N_aligned // group_size,
            1,
            dtype=torch.float32,
            device="cuda",
        )
        scale /= fmax

        B = quant_B * scale
        B = B.reshape(K_aligned, N_aligned)[:K, :N]
        quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N]
        scale = scale.reshape(K_aligned // group_size, N_aligned // group_size)
        return B, quant_B, scale


class TestPerTokenGroupQuantFP8(TestFP8Base):
    def test_per_token_group_quant_fp8(self):
        # only run on GPUs with compute capability >= 9.0
        if torch.cuda.get_device_capability()[0] < 9:
            return
        A, A_quant_gt, scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        A_quant, scale = per_token_group_quant_fp8(
            x=A, group_size=self.group_size, dtype=self.quant_type
        )
        torch.testing.assert_close(scale, scale_gt)
        # allow a tiny fraction of mismatched elements after FP8 rounding
        diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
        diff_count = (diff > 1e-5).count_nonzero()
        assert diff_count / diff.numel() < 1e-4


class TestW8A8BlockFP8Matmul(TestFP8Base):
    def test_w8a8_block_fp8_matmul(self):
        # only run on GPUs with compute capability >= 9.0
        if torch.cuda.get_device_capability()[0] < 9:
            return
        A, A_quant_gt, A_scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        B, B_quant_gt, B_scale_gt = self._make_B(
            K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type
        )
        # reference result in the output dtype
        C_gt = A.to(self.output_type) @ B.to(self.output_type)
        C = w8a8_block_fp8_matmul(
            A=A_quant_gt,
            B=B_quant_gt.T.contiguous(),
            As=A_scale_gt,
            Bs=B_scale_gt.T.contiguous(),
            block_size=[128, 128],
            output_dtype=self.output_type,
        )
        torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4)


if __name__ == "__main__":
    unittest.main()