# sglang_v0.5.2/flashinfer_0.3.1/tests/test_groupwise_scaled_gemm_...
"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from enum import Enum, auto
from itertools import product

import pytest
import torch
from einops import einsum, rearrange

from flashinfer.fp4_quantization import (
    _pad_scale_factors,
    get_fp4_quantization_module,
)
from flashinfer.gemm import group_gemm_mxfp4_nt_groupwise
from flashinfer.utils import get_compute_capability


class QuantMode(Enum):
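    """Quantized data formats exercised by this test."""
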
MXFP4 = auto()
MXFP8_E4M3 = auto()
MXFP8_E5M2 = auto()


def swizzle_blockscale(
unswizzled_sf: torch.Tensor, b: int, m: int, n: int, sf_vec_size: int = 32
) -> torch.Tensor:
r"""Swizzle block scale tensor for MXFP4/MXFP8 format.
This function swizzles the block scale tensor to optimize memory access patterns
for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
Args:
unswizzled_sf (torch.Tensor): Input tensor with dtype uint8.
b (int): Batch dimension.
m (int): M dimension.
n (int): N dimension.
sf_vec_size (int, optional): Scale factor vector size. Defaults to 32.
Returns:
torch.Tensor: Swizzled tensor with the same shape as input.
"""
assert unswizzled_sf.dtype == torch.uint8, (
f"Input dtype must be uint8, got {unswizzled_sf.dtype}"
)
assert unswizzled_sf.ndim == 3, f"Input must be 3D, got {unswizzled_sf.ndim}"
assert unswizzled_sf.shape[0] == b, (
f"Batch dimension must equal b, got {unswizzled_sf.shape[0]} != {b}"
)
padded_input_sf_chunked = [
_pad_scale_factors(unswizzled_sf[i], m, n, sf_vec_size) for i in range(b)
]
padded_input_sf = torch.stack(padded_input_sf_chunked)
major, minor = get_compute_capability(unswizzled_sf.device)
out = get_fp4_quantization_module(f"{major}{minor}").block_scale_interleave_sm100(
padded_input_sf
)
out = out.view(padded_input_sf.shape)
return out
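

# Usage sketch (illustrative only; the concrete sizes are assumptions taken from the
# call sites in the test below, not an API guarantee): the scales for one group of a
# (128, 4096) operand with sf_vec_size=32 form a (1, 128, 128) uint8 tensor, and
# swizzling preserves that shape because m=128 is already a multiple of 128.
def _swizzle_blockscale_sketch():
    sf = torch.randint(0, 255, (1, 128, 4096 // 32), dtype=torch.uint8, device="cuda")
    sf_swizzled = swizzle_blockscale(sf, b=1, m=128, n=4096, sf_vec_size=32)
    assert sf_swizzled.shape == sf.shape

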
# Vanilla reference implementation, used only by this unit test
def quantize_e2m1(x):
r"""
Quantizes a tensor to FP4.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The quantized tensor.
"""
assert x.shape[-1] % 2 == 0
x = x.clamp(-6, 6)
x_sign_bit = torch.lt(x, 0)
x_abs = torch.abs(x)
log_x_quant = torch.floor(torch.log2(x_abs)).clamp(0, 2)
x_quant_e_fp32 = torch.exp2(log_x_quant)
m_scale = 2
x_quant_m_scaled_fp32 = torch.round(x_abs * m_scale / x_quant_e_fp32)
mask = torch.ge(x_quant_m_scaled_fp32, m_scale)
x_quant_data_raw_e = log_x_quant + mask
x_quant_data_raw_m = x_quant_m_scaled_fp32 - mask * m_scale
x_quant_data_raw = (
x_sign_bit * 8 + x_quant_data_raw_e * m_scale + x_quant_data_raw_m
).to(torch.uint8)
x_quant_data = x_quant_data_raw[..., ::2] + x_quant_data_raw[..., 1::2] * 16
return x_quant_data


# Vanilla reference implementation, used only by this unit test
def dequantize_e2m1(x):
r"""
Dequantizes a tensor from FP4.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The dequantized tensor.
"""
x_quant_data_raw_1 = x % 16
x_quant_data_raw_2 = x // 16
x_quant_data_raw = torch.stack(
[x_quant_data_raw_1, x_quant_data_raw_2], dim=-1
).flatten(start_dim=-2)
x_sign_bit = x_quant_data_raw // 8
x = x_quant_data_raw % 8
m_scale = 2
x_quant_data_raw_e = x // m_scale
x_quant_data_raw_m = x % m_scale
mask = torch.gt(x_quant_data_raw_e, 0).to(torch.float32)
log_x_quant = x_quant_data_raw_e - mask
x_quant_m_scaled_fp32 = x_quant_data_raw_m + mask * m_scale
x_dequant_abs = x_quant_m_scaled_fp32 / m_scale * torch.exp2(log_x_quant)
x_dequant = (0.5 - x_sign_bit) * 2 * x_dequant_abs
return x_dequant
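

# Round-trip sketch (illustrative, not part of the test suite; the helper name is an
# assumption): every value on the E2M1 grid, signs included, is exactly representable,
# so packing with quantize_e2m1 and unpacking with dequantize_e2m1 reproduces the
# input bit-exactly.
def _e2m1_roundtrip_sketch():
    grid = torch.tensor([[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, -6.0]])
    packed = quantize_e2m1(grid)  # shape (1, 4): two nibbles per byte
    restored = dequantize_e2m1(packed)  # shape (1, 8), float32
    assert torch.equal(restored, grid)

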
def gemm_mxfp8_mxfp4_nt_groupwise_ref(
A, B, As, Bs, tile_size, n, k, output_dtype=torch.bfloat16
):
r"""
A: (m, k), torch.float8_e4m3fn or torch.float8_e5m2
B: (n // 2, k), e2m1 packed as torch.uint8
A_scale: (m, k // tile_size), ue8m0 saved as torch.uint8
B_scale: (n, k // tile_size), ue8m0 saved as torch.uint8
"""
ue8m0_bias = 127
A_f32 = A.to(torch.float32)
B_f32 = dequantize_e2m1(B)
A_f32_reshape = rearrange(A_f32, "m (k b) -> m k b", b=tile_size)
A_f32_scale_reshape = A_f32_reshape * rearrange(
torch.exp2(As.to(torch.float32) - ue8m0_bias), "m k -> m k 1"
)
A_f32_scale = rearrange(A_f32_scale_reshape, "m k b -> m (k b)")[:, :k]
B_f32_reshape = rearrange(B_f32, "n (k b) -> n k b", b=tile_size)
B_f32_scale_reshape = B_f32_reshape * rearrange(
torch.exp2(Bs.to(torch.float32) - ue8m0_bias), "n k -> n k 1"
)
B_f32_scale = rearrange(B_f32_scale_reshape, "n k b -> n (k b)")[:n, :k]
return einsum(A_f32_scale, B_f32_scale, "m k, n k -> m n").to(output_dtype)
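

# Worked example of the UE8M0 scaling applied above (numbers chosen for
# illustration): a stored scale byte of 130 decodes to 2 ** (130 - 127) = 8, so a
# dequantized E2M1 value of 1.5 in that 32-wide tile contributes 1.5 * 8 = 12.0 to
# the accumulation.

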
def quantize_tensor(x, tile_size, n_padded, k_padded, quant_mode):
r"""
Quantizes a tensor to MXFP4 or MXFP8.
Args:
x (torch.Tensor): The input tensor.
tile_size (int): The tile size.
n_padded (int): The padded n dimension, None if not needed.
k_padded (int): The padded k dimension.
quant_mode (QuantMode): The quantization mode.
Returns:
tuple: A tuple containing the quantized tensor and the
calculated scales.
"""
# 1. Initial Setup
ue8m0_bias = 127
if quant_mode == QuantMode.MXFP8_E4M3:
fp8_info = torch.finfo(torch.float8_e4m3fn)
quant_amax = torch.tensor(fp8_info.max, dtype=torch.float32, device=x.device)
elif quant_mode == QuantMode.MXFP8_E5M2:
fp8_info = torch.finfo(torch.float8_e5m2)
quant_amax = torch.tensor(fp8_info.max, dtype=torch.float32, device=x.device)
elif quant_mode == QuantMode.MXFP4:
quant_amax = torch.tensor(6, dtype=torch.float32, device=x.device)
else:
raise ValueError(f"Unsupported quantization mode: {quant_mode}")
if n_padded is not None and x.shape[-2] != n_padded:
x = torch.cat(
[
x,
torch.zeros(
(*x.shape[:-2], n_padded - x.shape[-2], x.shape[-1]),
dtype=x.dtype,
device=x.device,
),
],
dim=-2,
)
if x.shape[-1] != k_padded:
x = torch.cat(
[
x,
torch.zeros(
(*x.shape[:-1], k_padded - x.shape[-1]),
dtype=x.dtype,
device=x.device,
),
],
dim=-1,
)
# 2. Tiling and Scale Calculation
x_tiled = x.unflatten(-1, (-1, tile_size))
x_tiled_abs = x_tiled.abs()
log2_x_scale = (
torch.floor(torch.log2(x_tiled_abs.amax(dim=-1)))
- torch.floor(torch.log2(quant_amax))
).clamp(-ue8m0_bias, ue8m0_bias)
# 3. Final Quantization
# Divide the original tensor by the broadcasted scales
x_tiled_quant = (
torch.exp2(torch.log2(x_tiled_abs) - log2_x_scale[..., None]).clamp(
0, quant_amax
)
* x_tiled.sign()
)
x_quant = x_tiled_quant.flatten(-2, -1)
# Convert the result to the target format
if quant_mode == QuantMode.MXFP8_E4M3:
x_quant_data = x_quant.to(torch.float8_e4m3fn)
elif quant_mode == QuantMode.MXFP8_E5M2:
x_quant_data = x_quant.to(torch.float8_e5m2)
elif quant_mode == QuantMode.MXFP4:
x_quant_data = quantize_e2m1(x_quant)
else:
raise ValueError(f"Unsupported quantization mode: {quant_mode}")
x_scale_data = (log2_x_scale + ue8m0_bias).to(torch.uint8)
return x_quant_data, x_scale_data
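

# Shape/dtype sketch (illustrative only; the helper name and sizes are assumptions):
# quantizing a (64, 128) tensor with 32-wide tiles yields one FP8 value per input
# element and one UE8M0 scale byte per 32-element tile.
def _quantize_tensor_sketch():
    x = torch.randn(64, 128, device="cuda")
    x_q, x_s = quantize_tensor(
        x, tile_size=32, n_padded=None, k_padded=128, quant_mode=QuantMode.MXFP8_E4M3
    )
    assert x_q.dtype == torch.float8_e4m3fn and x_q.shape == (64, 128)
    assert x_s.dtype == torch.uint8 and x_s.shape == (64, 4)

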
@pytest.mark.parametrize("m", [4, 128, 256, 512, 4096, 8192])
@pytest.mark.parametrize("n", [128, 256, 512, 2879, 4096, 8192])
@pytest.mark.parametrize("k", [128, 256, 512, 2880, 4096, 8192])
@pytest.mark.parametrize("group_size", [1, 2, 4, 8])
@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_mxfp8_mxfp4_groupwise_group_gemm(
m,
n,
k,
group_size,
fp8_dtype,
out_dtype,
):
torch.random.manual_seed(0)
tile_size = 32
alignment_n = 8
alignment_k = 128
a_val = torch.randn((group_size * m, k), dtype=torch.float32, device="cuda")
b_val = torch.randn(
(group_size, n, k), dtype=torch.float32, device="cuda"
) / math.sqrt(k)
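    # Round n up to a multiple of alignment_n and k up to a multiple of alignment_k
    # before quantization; zero padding keeps the GEMM result unchanged and the extra
    # n columns are cropped after the kernel call.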
n_padded = (n + alignment_n - 1) // alignment_n * alignment_n
k_padded = (k + alignment_k - 1) // alignment_k * alignment_k
if fp8_dtype == torch.float8_e4m3fn:
a_quant_mode = QuantMode.MXFP8_E4M3
elif fp8_dtype == torch.float8_e5m2:
a_quant_mode = QuantMode.MXFP8_E5M2
else:
raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
a_fp8, a_scale = quantize_tensor(a_val, tile_size, None, k_padded, a_quant_mode)
b_fp4, b_scale = quantize_tensor(
b_val, tile_size, n_padded, k_padded, QuantMode.MXFP4
)
a_scale_swizzled = swizzle_blockscale(
a_scale.unflatten(0, (group_size, m)), group_size, m, k_padded, tile_size
).flatten(0, 1)
b_scale_swizzled = swizzle_blockscale(
b_scale, group_size, n_padded, k_padded, tile_size
)
group_arange = torch.arange(0, group_size + 1, dtype=torch.int32, device="cuda")
m_indptr = group_arange * m
# Pad a_scale_swizzled according to the function compute_sm100_cutlass_group_gemm_args
# in group_gemm_mxfp4_groupwise_sm100.cuh
alignment_m_sf = 128
m_indptr_padded = (
(m_indptr + group_arange * (alignment_m_sf - 1))
// alignment_m_sf
* alignment_m_sf
)
m_sf = m_indptr_padded[1:] - m_indptr_padded[:-1]
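    # Example (values for illustration): with group_size=2 and m=4, m_indptr is
    # [0, 4, 8]; rounding each group boundary up to a multiple of 128 gives
    # m_indptr_padded = [0, 128, 256], so each group's swizzled A scales are
    # zero-padded below from 4 rows up to m_sf[i] = 128 rows.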
a_scale_chunked = a_scale_swizzled.chunk(group_size, dim=0)
a_scale_chunked = [
torch.cat(
[
x,
torch.zeros(
m_sf[i] - x.shape[0], *x.shape[1:], dtype=x.dtype, device=x.device
),
]
)
for i, x in enumerate(a_scale_chunked)
]
a_scale_swizzled = torch.cat(a_scale_chunked)
out_ref = torch.empty((group_size * m, n), dtype=out_dtype, device="cuda")
for i in range(group_size):
out_ref[m * i : m * (i + 1)] = gemm_mxfp8_mxfp4_nt_groupwise_ref(
a_fp8[m * i : m * (i + 1)],
b_fp4[i],
a_scale[m * i : m * (i + 1)],
b_scale[i],
tile_size,
n,
k,
out_dtype,
)
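    # Sweep a few kernel schedule knobs (MMA SM count, tile shape, operand swap) and
    # check every configuration against the same reference output.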
mma_sm_list = [1, 2]
tile_m_list = [128]
tile_n_list = [64, 128, 192, 256]
tile_k_list = [128, 256]
swap_ab_list = [True, False]
for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(
mma_sm_list, tile_m_list, tile_n_list, tile_k_list, swap_ab_list
):
out = group_gemm_mxfp4_nt_groupwise(
a_fp8,
b_fp4,
a_scale_swizzled,
b_scale_swizzled,
m_indptr,
mma_sm=mma_sm,
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
swap_ab=swap_ab,
out_dtype=out_dtype,
)[:, :n]
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
for fp8_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
for out_dtype in [torch.bfloat16, torch.float16]:
test_mxfp8_mxfp4_groupwise_group_gemm(
4, 2879, 2880, 2, fp8_dtype, out_dtype
)