"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|

import math
from enum import Enum, auto
from itertools import product

import pytest
import torch
from einops import einsum, rearrange

from flashinfer.fp4_quantization import (
    _pad_scale_factors,
    get_fp4_quantization_module,
)
from flashinfer.gemm import group_gemm_mxfp4_nt_groupwise
from flashinfer.utils import get_compute_capability


class QuantMode(Enum):
    MXFP4 = auto()
    MXFP8_E4M3 = auto()
    MXFP8_E5M2 = auto()


def swizzle_blockscale(
    unswizzled_sf: torch.Tensor, b: int, m: int, n: int, sf_vec_size: int = 32
) -> torch.Tensor:
    r"""Swizzle a block scale tensor for the MXFP4/MXFP8 format.

    This function swizzles the block scale tensor to optimize memory access
    patterns for FP4 operations. Each batch is padded so that the m dimension
    is a multiple of 128 before being interleaved into the SM100 block-scale
    layout.

    Args:
        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8.
        b (int): Batch dimension.
        m (int): M dimension.
        n (int): N dimension.
        sf_vec_size (int, optional): Scale factor vector size. Defaults to 32.

    Returns:
        torch.Tensor: Swizzled tensor with the same shape as the padded input.
    """
    assert unswizzled_sf.dtype == torch.uint8, (
        f"Input dtype must be uint8, got {unswizzled_sf.dtype}"
    )
    assert unswizzled_sf.ndim == 3, f"Input must be 3D, got {unswizzled_sf.ndim}"
    assert unswizzled_sf.shape[0] == b, (
        f"Batch dimension must equal b, got {unswizzled_sf.shape[0]} != {b}"
    )
    padded_input_sf_chunked = [
        _pad_scale_factors(unswizzled_sf[i], m, n, sf_vec_size) for i in range(b)
    ]
    padded_input_sf = torch.stack(padded_input_sf_chunked)
    major, minor = get_compute_capability(unswizzled_sf.device)
    out = get_fp4_quantization_module(f"{major}{minor}").block_scale_interleave_sm100(
        padded_input_sf
    )
    out = out.view(padded_input_sf.shape)
    return out
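

# A minimal usage sketch for ``swizzle_blockscale`` (illustrative only, not part
# of the original test flow). It assumes a CUDA device supported by FlashInfer's
# FP4 quantization module (SM100-class hardware); the helper name is ours.
def _example_swizzle_blockscale_usage(g: int = 2, m: int = 128, k: int = 256):
    # ue8m0 scale exponents for g matrices of shape (m, k) with 32-element tiles.
    unswizzled = torch.randint(
        120, 136, (g, m, k // 32), dtype=torch.uint8, device="cuda"
    )
    # Same call pattern as the test below: the third size argument is the
    # dimension whose tiles the scales cover (k here).
    return swizzle_blockscale(unswizzled, g, m, k, sf_vec_size=32)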


# Vanilla implementation only for unit test
def quantize_e2m1(x):
    r"""
    Quantizes a tensor to FP4 (e2m1), packing two values per uint8.

    A round-trip sanity sketch appears after ``dequantize_e2m1`` below.

    Args:
        x (torch.Tensor): The input tensor. The last dimension must be even.

    Returns:
        torch.Tensor: The quantized tensor.
    """
    assert x.shape[-1] % 2 == 0
    # e2m1 can represent magnitudes up to 6.
    x = x.clamp(-6, 6)
    x_sign_bit = torch.lt(x, 0)
    x_abs = torch.abs(x)
    # Exponent of the power-of-two interval containing |x|, clamped to [0, 2].
    log_x_quant = torch.floor(torch.log2(x_abs)).clamp(0, 2)
    x_quant_e_fp32 = torch.exp2(log_x_quant)
    m_scale = 2
    # Mantissa expressed in half-steps of the enclosing power of two.
    x_quant_m_scaled_fp32 = torch.round(x_abs * m_scale / x_quant_e_fp32)
    # Values whose rounded mantissa reaches m_scale are encoded as normals;
    # fold the implicit leading one into the exponent field.
    mask = torch.ge(x_quant_m_scaled_fp32, m_scale)
    x_quant_data_raw_e = log_x_quant + mask
    x_quant_data_raw_m = x_quant_m_scaled_fp32 - mask * m_scale
    # Assemble the 4-bit code: sign in bit 3, exponent field in bits 2:1,
    # mantissa in bit 0.
    x_quant_data_raw = (
        x_sign_bit * 8 + x_quant_data_raw_e * m_scale + x_quant_data_raw_m
    ).to(torch.uint8)
    # Pack two 4-bit codes per byte: even index in the low nibble, odd in the high.
    x_quant_data = x_quant_data_raw[..., ::2] + x_quant_data_raw[..., 1::2] * 16
    return x_quant_data


# Vanilla implementation only for unit test
def dequantize_e2m1(x):
    r"""
    Dequantizes a tensor from packed FP4 (e2m1) to float32.

    Args:
        x (torch.Tensor): The input tensor of packed e2m1 codes (uint8).

    Returns:
        torch.Tensor: The dequantized tensor.
    """
    # Unpack the two 4-bit codes stored in each byte (low nibble first).
    x_quant_data_raw_1 = x % 16
    x_quant_data_raw_2 = x // 16
    x_quant_data_raw = torch.stack(
        [x_quant_data_raw_1, x_quant_data_raw_2], dim=-1
    ).flatten(start_dim=-2)
    # Split the 4-bit code into sign, exponent, and mantissa fields.
    x_sign_bit = x_quant_data_raw // 8
    x = x_quant_data_raw % 8
    m_scale = 2
    x_quant_data_raw_e = x // m_scale
    x_quant_data_raw_m = x % m_scale
    # A non-zero exponent field marks a normal value; restore its implicit
    # leading one.
    mask = torch.gt(x_quant_data_raw_e, 0).to(torch.float32)
    log_x_quant = x_quant_data_raw_e - mask
    x_quant_m_scaled_fp32 = x_quant_data_raw_m + mask * m_scale
    x_dequant_abs = x_quant_m_scaled_fp32 / m_scale * torch.exp2(log_x_quant)
    # Map the sign bit {0, 1} to {+1, -1}.
    x_dequant = (0.5 - x_sign_bit) * 2 * x_dequant_abs
    return x_dequant
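

# A minimal round-trip sanity sketch (ours, not part of the original test suite):
# every representable e2m1 value should survive quantize_e2m1 -> dequantize_e2m1
# exactly, since both helpers use the same sign/exponent/mantissa layout.
def _example_e2m1_roundtrip():
    values = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.5, -1.0, -3.0, -6.0],
        dtype=torch.float32,
    )
    packed = quantize_e2m1(values)  # two e2m1 nibbles per uint8 -> shape (6,)
    recovered = dequantize_e2m1(packed)  # unpacked back to float32
    torch.testing.assert_close(recovered, values)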


def gemm_mxfp8_mxfp4_nt_groupwise_ref(
    A, B, As, Bs, tile_size, n, k, output_dtype=torch.bfloat16
):
    r"""
    Reference (dequantize-then-matmul) implementation of the MXFP8 x MXFP4 GEMM.

    A: (m, k), torch.float8_e4m3fn or torch.float8_e5m2
    B: (n, k // 2), e2m1 packed along k as torch.uint8
    As: (m, k // tile_size), ue8m0 saved as torch.uint8
    Bs: (n, k // tile_size), ue8m0 saved as torch.uint8

    Inputs may be padded along n and k; the output is cropped to (m, n).
    """
    ue8m0_bias = 127
    A_f32 = A.to(torch.float32)
    B_f32 = dequantize_e2m1(B)
    A_f32_reshape = rearrange(A_f32, "m (k b) -> m k b", b=tile_size)
    A_f32_scale_reshape = A_f32_reshape * rearrange(
        torch.exp2(As.to(torch.float32) - ue8m0_bias), "m k -> m k 1"
    )
    A_f32_scale = rearrange(A_f32_scale_reshape, "m k b -> m (k b)")[:, :k]
    B_f32_reshape = rearrange(B_f32, "n (k b) -> n k b", b=tile_size)
    B_f32_scale_reshape = B_f32_reshape * rearrange(
        torch.exp2(Bs.to(torch.float32) - ue8m0_bias), "n k -> n k 1"
    )
    B_f32_scale = rearrange(B_f32_scale_reshape, "n k b -> n (k b)")[:n, :k]
    return einsum(A_f32_scale, B_f32_scale, "m k, n k -> m n").to(output_dtype)
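

# Worked example of the scaling in the reference above (explanatory note, ours):
# each stored scale byte is a ue8m0 exponent biased by 127, so byte 127 decodes
# to 2**(127 - 127) = 1.0 and byte 130 decodes to 2**(130 - 127) = 8.0. The
# reference output is
#     C[i, j] = sum_k  A[i, k] * 2**(As[i, k // tile_size] - 127)
#                    * E2M1_decode(B)[j, k] * 2**(Bs[j, k // tile_size] - 127)
# i.e., one shared scale per tile_size-element tile along k for each operand.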


def quantize_tensor(x, tile_size, n_padded, k_padded, quant_mode):
    r"""
    Quantizes a tensor to MXFP4 or MXFP8.

    Args:
        x (torch.Tensor): The input tensor.
        tile_size (int): The tile size (number of elements sharing one scale).
        n_padded (int): The padded n dimension, or None if no n padding is needed.
        k_padded (int): The padded k dimension.
        quant_mode (QuantMode): The quantization mode.

    Returns:
        tuple: A tuple containing the quantized tensor and the calculated
            ue8m0 scales.
    """
    # 1. Initial Setup
    ue8m0_bias = 127
    if quant_mode == QuantMode.MXFP8_E4M3:
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        quant_amax = torch.tensor(fp8_info.max, dtype=torch.float32, device=x.device)
    elif quant_mode == QuantMode.MXFP8_E5M2:
        fp8_info = torch.finfo(torch.float8_e5m2)
        quant_amax = torch.tensor(fp8_info.max, dtype=torch.float32, device=x.device)
    elif quant_mode == QuantMode.MXFP4:
        quant_amax = torch.tensor(6, dtype=torch.float32, device=x.device)
    else:
        raise ValueError(f"Unsupported quantization mode: {quant_mode}")
    if n_padded is not None and x.shape[-2] != n_padded:
        x = torch.cat(
            [
                x,
                torch.zeros(
                    (*x.shape[:-2], n_padded - x.shape[-2], x.shape[-1]),
                    dtype=x.dtype,
                    device=x.device,
                ),
            ],
            dim=-2,
        )
    if x.shape[-1] != k_padded:
        x = torch.cat(
            [
                x,
                torch.zeros(
                    (*x.shape[:-1], k_padded - x.shape[-1]),
                    dtype=x.dtype,
                    device=x.device,
                ),
            ],
            dim=-1,
        )

    # 2. Tiling and Scale Calculation
    x_tiled = x.unflatten(-1, (-1, tile_size))
    x_tiled_abs = x_tiled.abs()
    log2_x_scale = (
        torch.floor(torch.log2(x_tiled_abs.amax(dim=-1)))
        - torch.floor(torch.log2(quant_amax))
    ).clamp(-ue8m0_bias, ue8m0_bias)

    # 3. Final Quantization
    # Divide the original tensor by the broadcasted scales
    x_tiled_quant = (
        torch.exp2(torch.log2(x_tiled_abs) - log2_x_scale[..., None]).clamp(
            0, quant_amax
        )
        * x_tiled.sign()
    )
    x_quant = x_tiled_quant.flatten(-2, -1)

    # Convert the result to the target format
    if quant_mode == QuantMode.MXFP8_E4M3:
        x_quant_data = x_quant.to(torch.float8_e4m3fn)
    elif quant_mode == QuantMode.MXFP8_E5M2:
        x_quant_data = x_quant.to(torch.float8_e5m2)
    elif quant_mode == QuantMode.MXFP4:
        x_quant_data = quantize_e2m1(x_quant)
    else:
        raise ValueError(f"Unsupported quantization mode: {quant_mode}")
    x_scale_data = (log2_x_scale + ue8m0_bias).to(torch.uint8)

    return x_quant_data, x_scale_data
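

# A minimal shape sketch for quantize_tensor (ours, illustrative only; runs on
# CPU): a (4, 128) float tensor quantized to MXFP4 with 32-element tiles yields
# packed e2m1 data of shape (4, 64) (two values per byte) and ue8m0 scale bytes
# of shape (4, 4) (one scale per 32-element tile along k).
def _example_quantize_tensor_shapes():
    x = torch.randn(4, 128, dtype=torch.float32)
    x_fp4, x_sf = quantize_tensor(
        x, tile_size=32, n_padded=None, k_padded=128, quant_mode=QuantMode.MXFP4
    )
    assert x_fp4.shape == (4, 64) and x_fp4.dtype == torch.uint8
    assert x_sf.shape == (4, 4) and x_sf.dtype == torch.uint8
    return x_fp4, x_sf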
@pytest.mark.parametrize("m", [4, 128, 256, 512, 4096, 8192])
|
|
@pytest.mark.parametrize("n", [128, 256, 512, 2879, 4096, 8192])
|
|
@pytest.mark.parametrize("k", [128, 256, 512, 2880, 4096, 8192])
|
|
@pytest.mark.parametrize("group_size", [1, 2, 4, 8])
|
|
@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
|
|
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
|
def test_mxfp8_mxfp4_groupwise_group_gemm(
|
|
m,
|
|
n,
|
|
k,
|
|
group_size,
|
|
fp8_dtype,
|
|
out_dtype,
|
|
):
|
|
torch.random.manual_seed(0)
|
|
tile_size = 32
|
|
alignment_n = 8
|
|
alignment_k = 128
|
|
|
|
a_val = torch.randn((group_size * m, k), dtype=torch.float32, device="cuda")
|
|
b_val = torch.randn(
|
|
(group_size, n, k), dtype=torch.float32, device="cuda"
|
|
) / math.sqrt(k)
|
|
n_padded = (n + alignment_n - 1) // alignment_n * alignment_n
|
|
k_padded = (k + alignment_k - 1) // alignment_k * alignment_k
|
|
|
|
if fp8_dtype == torch.float8_e4m3fn:
|
|
a_quant_mode = QuantMode.MXFP8_E4M3
|
|
elif fp8_dtype == torch.float8_e5m2:
|
|
a_quant_mode = QuantMode.MXFP8_E5M2
|
|
else:
|
|
raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
|
|
a_fp8, a_scale = quantize_tensor(a_val, tile_size, None, k_padded, a_quant_mode)
|
|
b_fp4, b_scale = quantize_tensor(
|
|
b_val, tile_size, n_padded, k_padded, QuantMode.MXFP4
|
|
)
|
|
|
|
a_scale_swizzled = swizzle_blockscale(
|
|
a_scale.unflatten(0, (group_size, m)), group_size, m, k_padded, tile_size
|
|
).flatten(0, 1)
|
|
b_scale_swizzled = swizzle_blockscale(
|
|
b_scale, group_size, n_padded, k_padded, tile_size
|
|
)
|
|
|
|
group_arange = torch.arange(0, group_size + 1, dtype=torch.int32, device="cuda")
|
|
m_indptr = group_arange * m
|
|
|
|
# Pad a_scale_swizzled according to the function compute_sm100_cutlass_group_gemm_args
|
|
# in group_gemm_mxfp4_groupwise_sm100.cuh
|
|
alignment_m_sf = 128
|
|
m_indptr_padded = (
|
|
(m_indptr + group_arange * (alignment_m_sf - 1))
|
|
// alignment_m_sf
|
|
* alignment_m_sf
|
|
)
|
|
m_sf = m_indptr_padded[1:] - m_indptr_padded[:-1]
|
|
a_scale_chunked = a_scale_swizzled.chunk(group_size, dim=0)
|
|
a_scale_chunked = [
|
|
torch.cat(
|
|
[
|
|
x,
|
|
torch.zeros(
|
|
m_sf[i] - x.shape[0], *x.shape[1:], dtype=x.dtype, device=x.device
|
|
),
|
|
]
|
|
)
|
|
for i, x in enumerate(a_scale_chunked)
|
|
]
|
|
a_scale_swizzled = torch.cat(a_scale_chunked)
|
|
|
|
out_ref = torch.empty((group_size * m, n), dtype=out_dtype, device="cuda")
|
|
for i in range(group_size):
|
|
out_ref[m * i : m * (i + 1)] = gemm_mxfp8_mxfp4_nt_groupwise_ref(
|
|
a_fp8[m * i : m * (i + 1)],
|
|
b_fp4[i],
|
|
a_scale[m * i : m * (i + 1)],
|
|
b_scale[i],
|
|
tile_size,
|
|
n,
|
|
k,
|
|
out_dtype,
|
|
)
|
|
|
|
mma_sm_list = [1, 2]
|
|
tile_m_list = [128]
|
|
tile_n_list = [64, 128, 192, 256]
|
|
tile_k_list = [128, 256]
|
|
swap_ab_list = [True, False]
|
|
for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(
|
|
mma_sm_list, tile_m_list, tile_n_list, tile_k_list, swap_ab_list
|
|
):
|
|
out = group_gemm_mxfp4_nt_groupwise(
|
|
a_fp8,
|
|
b_fp4,
|
|
a_scale_swizzled,
|
|
b_scale_swizzled,
|
|
m_indptr,
|
|
mma_sm=mma_sm,
|
|
tile_m=tile_m,
|
|
tile_n=tile_n,
|
|
tile_k=tile_k,
|
|
swap_ab=swap_ab,
|
|
out_dtype=out_dtype,
|
|
)[:, :n]
|
|
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
for fp8_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
for out_dtype in [torch.bfloat16, torch.float16]:
|
|
test_mxfp8_mxfp4_groupwise_group_gemm(
|
|
4, 2879, 2880, 2, fp8_dtype, out_dtype
|
|
)
|