import pytest
import torch

from flashinfer import mxfp8_dequantize_host, mxfp8_quantize
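
# These tests round-trip matrices through FlashInfer's MXFP8 quantization:
# `mxfp8_quantize` produces an FP8 payload plus packed scale factors ("sf"),
# and `mxfp8_dequantize_host` reconstructs the values on the CPU. The
# `is_sf_swizzled_layout` flag selects between a linear and a swizzled
# scale-factor layout; both must round-trip to the same values.
#
# Minimal sketch of the round trip under test (assumes a CUDA device; the
# tensor names are illustrative only):
#
#     x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
#     x_fp8, x_sf = mxfp8_quantize(x, True)
#     x_rt = mxfp8_dequantize_host(
#         x_fp8.cpu().view(torch.uint8), x_sf.cpu().view(torch.uint8), True
#     )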


@pytest.mark.parametrize("m", [1, 1024])
@pytest.mark.parametrize("k", [1024])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_mxfp8_quantize_torch(m, k, dtype, is_sf_swizzled_layout, device):
    """Quantize on either device, dequantize on the host, and compare."""
    torch.random.manual_seed(0)
    a = 16 * torch.randn([m, k], dtype=dtype).to(device).contiguous()
    # The CPU quantize path expects float32 input.
    if device == "cpu":
        a = a.float()

    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # Dequantization runs on the host, so stage the quantized tensors there.
    if device == "cuda":
        a_fp8 = a_fp8.cpu()
        a_sf = a_sf.cpu()

    a_pt = mxfp8_dequantize_host(
        a_fp8.view(torch.uint8),
        a_sf.view(torch.uint8).reshape(-1),
        is_sf_swizzled_layout,
    )

    # Move the reference back so the accuracy check compares on one device.
    if device == "cuda":
        a_pt = a_pt.cuda()
        torch.cuda.synchronize()
    # Allow |a_pt - a| <= 8 with zero relative tolerance; at most 0.1% of
    # elements (percent=0.999) may fall outside it.
    mxfp8_quantize_check_accuracy(a_pt, a, 8, 0, 0.999)
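

# Shared accuracy check used by all tests in this file.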
def mxfp8_quantize_check_accuracy(a, b, atol, rtol, percent):
    """Require at least `percent` of elements to satisfy |a - b| <= atol + rtol * |b|."""
    if torch.any(torch.isnan(a)):
        raise Exception("NaN in a")
    if torch.any(torch.isnan(b)):
        raise Exception("NaN in b")
    assert a.shape == b.shape
    left = torch.abs(a - b)
    right = atol + rtol * torch.abs(b)
    count = torch.sum(left > right)
    mismatch_percent = count / a.numel()
    if mismatch_percent > 1 - percent:
        raise Exception(
            "Mismatch percentage is %f (atol=%f, rtol=%f)"
            % (mismatch_percent, atol, rtol)
        )


@pytest.mark.parametrize("m", [1, 2, 16, 1024])
@pytest.mark.parametrize("k", [512, 1024])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_torch_host(m, k, dtype, is_sf_swizzled_layout):
    """Quantize and dequantize entirely on the host."""
    torch.random.manual_seed(0)
    # NOTE: `dtype` is parametrized but unused; the host path quantizes a
    # float32 input directly.
    a = (torch.randn([m, k], dtype=torch.float) * 16).cpu().contiguous()

    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    a_pt = mxfp8_dequantize_host(
        a_fp8.view(torch.uint8), a_sf.view(torch.uint8), is_sf_swizzled_layout
    )

    mxfp8_quantize_check_accuracy(a_pt, a, 8, 0, 0.999)


@pytest.mark.parametrize("m", [1, 2, 16, 1024])
@pytest.mark.parametrize("k", [512, 1024])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout):
    """Quantize on the GPU, then dequantize and verify on the host."""
    torch.random.manual_seed(0)
    a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous()

    # The third argument is the column alignment (both k values here are
    # already multiples of 32, so no padding is added).
    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, 32)
    a_pt = mxfp8_dequantize_host(
        a_fp8.cpu().view(torch.uint8),
        a_sf.cpu().view(torch.uint8),
        is_sf_swizzled_layout,
    )

    torch.cuda.synchronize()
    # Compare in float32 so both operands share a common dtype.
    mxfp8_quantize_check_accuracy(
        a_pt.cpu().to(torch.float32), a.cpu().to(torch.float32), 8, 0, 0.999
    )


@pytest.mark.parametrize("m", [1, 2, 16, 1024])
@pytest.mark.parametrize("k", [1568])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@pytest.mark.parametrize("alignment", [64, 128])
def test_mxfp8_quantize_alignment_torch_device(
    m, k, dtype, is_sf_swizzled_layout, alignment
):
    """Quantize on the GPU with k padded up to a multiple of `alignment`."""
    torch.random.manual_seed(0)
    a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous()
    # k = 1568 is not a multiple of 64 or 128, so padding is always exercised.
    padded_k = ((k + alignment - 1) // alignment) * alignment

    # Quantize it on device.
    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, alignment)
    assert a_fp8.shape[1] == padded_k

    # Dequantize it on host.
    a_pt = mxfp8_dequantize_host(
        a_fp8.cpu().view(torch.uint8),
        a_sf.cpu().view(torch.uint8),
        is_sf_swizzled_layout,
    )

    # The padding columns must be zero-filled.
    paddings = a_fp8.view(torch.int8)[:, k:]
    assert torch.all(paddings == 0), "Paddings should be zero"

    torch.cuda.synchronize()

    # Compare only the original k columns against the input.
    mxfp8_quantize_check_accuracy(
        a_pt[:, :k].cpu().to(torch.float32), a.cpu().to(torch.float32), 8, 0, 0.999
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])