sglang_v0.5.2/flashinfer_0.3.1/tests/test_fp4_quantize.py

import functools
import pytest
import torch
from utils_fp4 import cast_from_fp4, ref_fp4_quant
from flashinfer import (
block_scale_interleave,
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
mxfp4_quantize,
mxfp4_dequantize,
)
from flashinfer.utils import is_sm100a_supported

DTYPES = [torch.float16, torch.bfloat16]
# The batch dimension doesn't need to be a multiple of 128
SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256)]
SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
BLOCK_SIZE = 16


def swizzle_sf(
unswizzled_sf: torch.Tensor,
original_row: int,
original_col: int,
scaling_vector_size: int = 16,
) -> torch.Tensor:
"""
Inverse of `unswizzle_sf`. Converts an unswizzled tensor back to swizzled form.
Args:
unswizzled_sf: Tensor of shape [row, col // scaling_vector_size].
original_row: Original row dimension (e.g., 120).
original_col: Original column dimension (e.g., 64).
        scaling_vector_size: Number of elements covered by each scale factor (default 16).
Returns:
Swizzled tensor of shape [padded_row, padded_col // scaling_vector_size].
"""
unswizzled_sf = unswizzled_sf.contiguous()
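    # The swizzled layout groups 4 scale factors per k tile, so columns are
    # padded in units of 4 * scaling_vector_size input elements.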
factor = scaling_vector_size * 4
    padded_row = ((original_row + 128 - 1) // 128) * 128  # Next multiple of 128
    padded_col = ((original_col + factor - 1) // factor) * factor  # Next multiple of factor
# Pad the input tensor to [padded_row, padded_col // scaling_vector_size]
pad_rows = padded_row - original_row
pad_cols = (padded_col - original_col) // scaling_vector_size
padded_sf = torch.nn.functional.pad(
unswizzled_sf,
(0, pad_cols, 0, pad_rows),
mode="constant",
value=0,
).contiguous()
# Reshape and transpose to reverse unswizzle_sf
num_m_tiles = padded_row // 128
num_k_tiles = padded_col // factor
sf_reshaped = padded_sf.view(num_m_tiles, 4, 32, num_k_tiles, 4) # Reverse reshape
sf_swizzled = sf_reshaped.transpose(
1, 3
) # Reverse transpose [num_m_tiles, num_k_tiles, 32, 4, 4]
    sf_swizzled = sf_swizzled.reshape(
        padded_row, padded_col // scaling_vector_size
    )  # Flatten to [padded_row, padded_col // scaling_vector_size]
    return sf_swizzled.contiguous()


def unswizzle_sf(
sf: torch.Tensor, row: int, col: int, scaling_vector_size: int = 16
) -> torch.Tensor:
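    """
    Inverse of `swizzle_sf`. Recovers the row-major scale-factor layout from the
    swizzled 128-row tiled layout and slices off any padding.

    Args:
        sf: Swizzled scale-factor tensor.
        row: Original row dimension.
        col: Original column dimension.
        scaling_vector_size: Number of elements covered by each scale factor (default 16).

    Returns:
        Unswizzled tensor of shape [row, col // scaling_vector_size].
    """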
factor = scaling_vector_size * 4
num_m_tiles = (row + 128 - 1) // 128
num_k_tiles = (col + factor - 1) // factor
    # SF layout: [num_m_tiles, num_k_tiles, 32 (m tile, column major), 4 (m tile, column major), 4 (k tile)]
sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
sf_unswizzle = sf_reshaped.transpose(1, 3)
sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4)
sf_unswizzle_sliced = sf_unswizzle[:row, : (col // scaling_vector_size)]
    return sf_unswizzle_sliced.contiguous()


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sf_use_ue8m0", [False, True])
@pytest.mark.parametrize("is_swizzled", [False, True])
@torch.inference_mode()
def test_fp4_quantization(
dtype: torch.dtype,
shape: tuple[int, int],
seed: int,
device: str,
sf_use_ue8m0: bool,
is_swizzled: bool,
) -> None:
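    """Quantize a random tensor with fp4_quantize and compare the unpacked FP4
    values and scale factors against the ref_fp4_quant reference, unswizzling
    the scales when a swizzled layout is requested."""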
if not is_sm100a_supported(torch.device(device)):
pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8")
torch.set_default_device(device)
torch.manual_seed(seed)
m, n = shape
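    # UE8M0 scales use 32-element blocks; E4M3 scales use 16-element blocks.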
sf_vec_size = 32 if sf_use_ue8m0 else 16
x = torch.randn((m, n), dtype=dtype)
tensor_amax = torch.abs(x).max().to(torch.float32)
if sf_use_ue8m0:
global_scale = torch.tensor(1.0, dtype=torch.float32)
else:
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
out_ref, scale_ref = ref_fp4_quant(x, global_scale, sf_vec_size, sf_use_ue8m0)
out, out_scale = fp4_quantize(
x, global_scale, sf_vec_size, sf_use_ue8m0, is_swizzled
)
    assert n % sf_vec_size == 0, f"cols must be divisible by {sf_vec_size}"
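    # Reconstruct the scale values for comparison: UE8M0 stores only a biased
    # exponent, so shifting it into the float32 exponent field yields the
    # power-of-two scale; E4M3 scales can be reinterpreted directly.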
if sf_use_ue8m0:
out_scale = (out_scale.to(torch.int32) << 23).view(torch.float32)
else:
out_scale = out_scale.view(torch.float8_e4m3fn).to(torch.float32)
if is_swizzled:
scale_ans = unswizzle_sf(
out_scale.reshape(-1, n // sf_vec_size), m, n, sf_vec_size
)
else:
scale_ans = out_scale
out_ans = cast_from_fp4(out).reshape(m, n)
torch.testing.assert_close(out_ans, out_ref, rtol=1e0, atol=1e-1)
    torch.testing.assert_close(scale_ans, scale_ref, rtol=1e-1, atol=1e-1)


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_scale_swizzling(
dtype: torch.dtype,
shape: tuple[int, int],
seed: int,
device: str,
) -> None:
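    """Check that swizzle_sf and unswizzle_sf round-trip the scale factors from
    fp4_quantize and agree with the kernel's swizzled output."""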
    if not is_sm100a_supported(torch.device(device)):
        pytest.skip("NVFP4 requires compute capability >= 10 and CUDA >= 12.8")
torch.set_default_device(device)
torch.manual_seed(seed)
m, n = shape
x = torch.randn((m, n), dtype=dtype)
tensor_amax = torch.abs(x).max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
_, unswizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, False)
_, swizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, True)
    assert n % BLOCK_SIZE == 0, f"cols must be divisible by {BLOCK_SIZE}"
recovered_unswizzled_scale = unswizzle_sf(
swizzle_sf(unswizzled_scale, m, n),
m,
n,
)
    # Due to padding, swizzle_sf(unswizzled_scale) == swizzled_scale does not hold
    # in general, so compare the scales after unswizzling instead.
ref_unswizzled_scale = unswizzle_sf(swizzled_scale, m, n)
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
assert_equal(recovered_unswizzled_scale, unswizzled_scale)
    assert_equal(ref_unswizzled_scale, unswizzled_scale)


@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_block_scale_interleave(
shape: tuple[int, int],
seed: int,
device: str,
) -> None:
"""Test the block_scale_interleave function directly."""
    if not is_sm100a_supported(torch.device(device)):
        pytest.skip("NVFP4 requires compute capability >= 10 and CUDA >= 12.8")
torch.set_default_device(device)
torch.manual_seed(seed)
m, n = shape
sf_vec_size = BLOCK_SIZE
# Create a test scale factors tensor with uint8 dtype
# The shape should be [m, n // sf_vec_size] for scale factors
scale_shape = (m, n // sf_vec_size)
unswizzled_sf = torch.randint(0, 256, scale_shape, dtype=torch.uint8, device=device)
# Test the swizzling function
swizzled_sf = block_scale_interleave(unswizzled_sf)
# Compare against the reference implementation
ref_swizzled_sf = swizzle_sf(unswizzled_sf, m, n, sf_vec_size)
# Basic checks
assert swizzled_sf.dtype == torch.uint8, f"Expected uint8, got {swizzled_sf.dtype}"
assert swizzled_sf.device == unswizzled_sf.device, "Device mismatch"
# Check that the output has the expected padded shape
factor = sf_vec_size * 4
    padded_row = ((m + 128 - 1) // 128) * 128  # Next multiple of 128
    padded_col = ((n + factor - 1) // factor) * factor  # Next multiple of factor
expected_shape = (padded_row, padded_col // sf_vec_size)
expected_size = expected_shape[0] * expected_shape[1]
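    # block_scale_interleave returns a flattened buffer, so compare the total
    # element count and reshape before the elementwise comparison.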
assert expected_size == swizzled_sf.shape[0], (
f"Expected size {expected_size}, got {swizzled_sf.shape[0]}"
)
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
    assert_equal(swizzled_sf.reshape(expected_shape), ref_swizzled_sf)


@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sf_use_ue8m0", [True, False])
@torch.inference_mode()
def test_e2m1_dequantization(
shape: tuple[int, int],
seed: int,
device: str,
sf_use_ue8m0: bool,
) -> None:
"""Test roundtrip: fp4_quantize -> e2m1_and_ufp8sf_scale_to_float."""
    if not is_sm100a_supported(torch.device(device)):
        pytest.skip("NVFP4 requires compute capability >= 10 and CUDA >= 12.8")
torch.set_default_device(device)
torch.manual_seed(seed)
# Create a reasonable test tensor
m, n = shape
x = torch.randn((m, n), dtype=torch.float16)
# Calculate global scale as in the other tests
tensor_amax = torch.abs(x).max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
# Test with default common settings
is_sf_swizzled_layout = True
block_size = 32 if sf_use_ue8m0 else 16
# Step 1: Quantize with fp4_quantize
quantized_tensor, scale_factors = fp4_quantize(
x, global_scale, block_size, sf_use_ue8m0, is_sf_swizzled_layout
)
# Step 2: Dequantize with e2m1_and_ufp8sf_scale_to_float
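    # ufp8_type selects the scale-factor format (0 -> UE8M0, 1 -> E4M3),
    # matching the scale format produced by fp4_quantize above.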
ufp8_type = 0 if sf_use_ue8m0 else 1
dequantized_tensor = e2m1_and_ufp8sf_scale_to_float(
quantized_tensor,
scale_factors,
1 / global_scale,
sf_vec_size=block_size,
ufp8_type=ufp8_type,
is_sf_swizzled_layout=is_sf_swizzled_layout,
)
# Move back to device for comparison
dequantized_tensor = dequantized_tensor.to(device)
x_float32 = x.to(torch.float32)
# Step 3: Compare results
assert dequantized_tensor.shape == x.shape, (
f"Shape mismatch: expected {x.shape}, got {dequantized_tensor.shape}"
)
assert dequantized_tensor.dtype == torch.float32, (
f"Expected float32, got {dequantized_tensor.dtype}"
)
# Check for invalid values
assert not torch.isnan(dequantized_tensor).any(), (
"Dequantized tensor contains NaN values"
)
assert not torch.isinf(dequantized_tensor).any(), (
"Dequantized tensor contains Inf values"
)
    # Compare with the original; FP4 quantization is lossy, so only a loose tolerance is expected
torch.testing.assert_close(
dequantized_tensor,
x_float32,
rtol=0.3,
atol=0.5, # Reasonable tolerance for FP4 quantization
msg="Quantize -> dequantize roundtrip failed",
    )


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mxfp4_quantize_roundtrip(device: str):
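    """Quantize with mxfp4_quantize and check that mxfp4_dequantize recovers the
    input within the loose tolerance expected of FP4."""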
if not is_sm100a_supported(torch.device(device)):
pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8")
x = torch.randn((128, 64), device="cuda", dtype=torch.bfloat16) / 10
quant_a, sfs = mxfp4_quantize(x)
dq_a = mxfp4_dequantize(quant_a, sfs)
torch.testing.assert_close(
dq_a.cpu().to(torch.float32),
x.cpu().to(torch.float32),
rtol=0.3,
atol=0.5,
msg="Quantize -> dequantize mxfp4 roundtrip failed",
    )


if __name__ == "__main__":
pytest.main([__file__, "-v"])