"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""

import functools
from enum import Enum
from types import SimpleNamespace
from typing import List, Optional, Tuple

import torch

from .jit import JitSpec
from .jit import env as jit_env
from .jit import (
    gen_jit_spec,
    sm121a_nvcc_flags,
    sm120a_nvcc_flags,
    sm110a_nvcc_flags,
    sm103a_nvcc_flags,
    sm100a_nvcc_flags,
    sm90a_nvcc_flags,
)
from .jit.cpp_ext import is_cuda_version_at_least
from .utils import (
    device_support_pdl,
    get_shuffle_matrix_a_row_indices,
    get_shuffle_matrix_sf_a_row_indices,
    register_custom_op,
    register_fake_op,
    get_compute_capability,
)


def _pad_scale_factors(
    unswizzled_sf: torch.Tensor, m: int, n: int, sf_vec_size: int = 16
) -> torch.Tensor:
    """Pad scale factors tensor to meet alignment requirements.

    Args:
        unswizzled_sf (torch.Tensor): Input scale factors tensor with dtype uint8.
        m (int): M dimension.
        n (int): N dimension.
        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.

    Returns:
        torch.Tensor: Padded scale factors tensor.
    """
    factor = sf_vec_size * 4
    padded_row = ((m + 128 - 1) // 128) * 128  # Next multiple of 128
    padded_col = ((n + factor - 1) // factor) * factor  # Next multiple of sf_vec_size * 4

    # Pad the input tensor to [padded_row, padded_col // sf_vec_size]
    pad_rows = padded_row - m
    pad_cols = (padded_col - n) // sf_vec_size
    if pad_rows == 0 and pad_cols == 0:
        return unswizzled_sf
    else:
        return torch.nn.functional.pad(
            unswizzled_sf, (0, pad_cols, 0, pad_rows), mode="constant", value=0
        ).contiguous()
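

# Illustrative arithmetic for _pad_scale_factors (a sketch with assumed shapes,
# not part of the library API): with m=100, n=96, sf_vec_size=16 the scale
# factors occupy [100, 6]; rows pad up to 128 and columns up to
# ceil(96 / 64) * 64 = 128, i.e. (128 - 96) // 16 = 2 extra scale-factor columns.
#
#   sf = torch.zeros(100, 96 // 16, dtype=torch.uint8)
#   padded = _pad_scale_factors(sf, m=100, n=96, sf_vec_size=16)
#   assert padded.shape == (128, 8)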


def gen_fp4_quantization_sm100_module() -> JitSpec:
    return gen_fp4_quantization_module(sm100a_nvcc_flags, "100")


def gen_fp4_quantization_sm103_module() -> JitSpec:
    return gen_fp4_quantization_module(sm103a_nvcc_flags, "103")


def gen_fp4_quantization_sm90_module() -> JitSpec:
    return gen_fp4_quantization_module(sm90a_nvcc_flags, "90")


def gen_fp4_quantization_sm110_module() -> JitSpec:
    return gen_fp4_quantization_module(sm110a_nvcc_flags, "110")


def gen_fp4_quantization_sm120_module() -> JitSpec:
    return gen_fp4_quantization_module(sm120a_nvcc_flags, "120")


def gen_fp4_quantization_sm121_module() -> JitSpec:
    return gen_fp4_quantization_module(sm121a_nvcc_flags, "121")


def gen_fp4_quantization_module(nvcc_flags: List[str], device_arch: str) -> JitSpec:
    return gen_jit_spec(
        f"fp4_quantization_{device_arch}",
        [
            jit_env.FLASHINFER_CSRC_DIR
            / "nv_internal/tensorrt_llm/thop/fp4Quantize.cpp",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/tensorrt_llm/thop/fp4Op.cpp",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/kernels/quantization.cu",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp",
        ],
        extra_cuda_cflags=nvcc_flags
        + [
            "-DENABLE_BF16",
            "-DENABLE_FP8",
            "-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "",
        ],
        extra_cflags=[
            "-DENABLE_BF16",
            "-DENABLE_FP8",
            "-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "",
        ],
        extra_include_paths=[
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
        ],
    )


@functools.cache
def get_fp4_quantization_module(backend: str = "100"):
    backend_modules = {
        "121": gen_fp4_quantization_sm121_module,
        "120": gen_fp4_quantization_sm120_module,
        "110": gen_fp4_quantization_sm110_module,
        "103": gen_fp4_quantization_sm103_module,
        "100": gen_fp4_quantization_sm100_module,
        "90": gen_fp4_quantization_sm90_module,
    }

    if backend not in backend_modules:
        raise ValueError(f"Invalid backend: {backend}")

    module = backend_modules[backend]().build_and_load()

    @register_custom_op(
        "flashinfer::fp4_quantize_sm100",
        mutates_args=(""),
    )
    def fp4_quantize_sm100(
        input: torch.Tensor,
        global_scale: Optional[torch.Tensor] = None,
        sf_vec_size: int = 16,
        sf_use_ue8m0: bool = False,
        is_sf_swizzled_layout: bool = True,
        is_sf_8x4_layout: bool = False,
        enable_pdl: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize input tensor to FP4 format.

        Args:
            input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
            global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
            sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
            sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
            is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
            is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False.
            enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
                If None, automatically detects based on device capability. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2
                - Scale factors tensor with shape determined by layout and sf_vec_size
        """
        if enable_pdl is None:
            enable_pdl = device_support_pdl(input.device)
        return module.fp4_quantize(
            input,
            global_scale,
            sf_vec_size,
            sf_use_ue8m0,
            is_sf_swizzled_layout,
            is_sf_8x4_layout,
            enable_pdl,
        )

    @register_fake_op("flashinfer::fp4_quantize_sm100")
    def _fake_fp4_quantize_sm100(
        input: torch.Tensor,
        global_scale: Optional[torch.Tensor] = None,
        sf_vec_size: int = 16,
        sf_use_ue8m0: bool = False,
        is_sf_swizzled_layout: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        m, k = input.shape
        return (
            input.new_empty([m, k // 2], dtype=torch.int64),  # FLOAT4_E2M1X2
            input.new_empty([m * k // sf_vec_size], dtype=torch.int32),  # Scale factors
        )

    @register_custom_op(
        "flashinfer::mxfp4_dequantize_host",
        mutates_args=(""),
    )
    def mxfp4_dequantize_host(
        weight: torch.Tensor,
        scale: torch.Tensor,
        group_size: int = 32,
    ) -> torch.Tensor:
        return module.mxfp4_dequantize_host(
            weight,
            scale,
            group_size,
        )

    @register_fake_op("flashinfer::mxfp4_dequantize_host")
    def _fake_mxfp4_dequantize_host(
        weight: torch.Tensor,
        scale: torch.Tensor,
        group_size: int = 32,
    ) -> torch.Tensor:
        return weight.new_empty(
            [weight.shape[0], weight.shape[1] * 2], dtype=torch.float32
        )

    @register_custom_op(
        "flashinfer::block_scale_interleave_sm100",
        mutates_args=("",),
    )
    def block_scale_interleave_sm100(
        unswizzled_sf: torch.Tensor,
    ) -> torch.Tensor:
        """Swizzle block scale tensor for FP4 format.

        Args:
            unswizzled_sf (torch.Tensor): Unswizzled block scale tensor with dtype uint8.

        Returns:
            torch.Tensor: Output tensor for swizzled block scale with dtype uint8.
        """
        return module.block_scale_interleave_sm100(
            unswizzled_sf,
        )

    @register_fake_op("flashinfer::block_scale_interleave_sm100")
    def _fake_block_scale_interleave_sm100(
        unswizzled_sf: torch.Tensor,
    ) -> torch.Tensor:
        return unswizzled_sf.new_empty(
            [unswizzled_sf.shape[0] * unswizzled_sf.shape[1] // 16], dtype=torch.uint8
        )

    @register_custom_op(
        "flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100",
        mutates_args=(""),
    )
    def e2m1_and_ufp8sf_scale_to_float_sm100(
        e2m1_tensor: torch.Tensor,
        ufp8_scale_tensor: torch.Tensor,
        global_scale_tensor: Optional[torch.Tensor] = None,
        sf_vec_size: int = 16,
        ufp8_type: int = 1,
        is_sf_swizzled_layout: bool = True,
    ) -> torch.Tensor:
        """Convert E2M1 format tensor and UFP8 scale factors to float tensor.

        This function performs dequantization by converting a packed FP4 tensor in E2M1 format
        back to float values using the associated UFP8 scale factors and global scale.

        Args:
            e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
            ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
            global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
            sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
            ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
            is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.

        Returns:
            torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
        """
        return module.e2m1_and_ufp8sf_scale_to_float_sm100(
            e2m1_tensor.cpu(),
            ufp8_scale_tensor.cpu().reshape(-1),
            global_scale_tensor.cpu(),
            sf_vec_size,
            ufp8_type,
            is_sf_swizzled_layout,
        )

    @register_fake_op("flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100")
    def _fake_e2m1_and_ufp8sf_scale_to_float_sm100(
        e2m1_tensor: torch.Tensor,
        ufp8_scale_tensor: torch.Tensor,
        global_scale_tensor: Optional[torch.Tensor] = None,
        sf_vec_size: int = 16,
        ufp8_type: int = 1,
        is_sf_swizzled_layout: bool = True,
    ) -> torch.Tensor:
        return e2m1_tensor.new_empty(
            [e2m1_tensor.shape[0], e2m1_tensor.shape[1] * 2], dtype=torch.float32
        )

    # Register the module
    return SimpleNamespace(
        fp4_quantize_sm100=fp4_quantize_sm100,
        block_scale_interleave_sm100=block_scale_interleave_sm100,
        e2m1_and_ufp8sf_scale_to_float_sm100=e2m1_and_ufp8sf_scale_to_float_sm100,
        mxfp4_dequantize_host=mxfp4_dequantize_host,
    )
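

# The factory above is cached per backend string; a minimal usage sketch (the
# public wrappers below normally derive the arch string from the current device,
# so calling it directly like this is illustrative only):
#
#   major, minor = get_compute_capability(torch.device("cuda"))
#   ops = get_fp4_quantization_module(f"{major}{minor}")
#   # ops.fp4_quantize_sm100(...), ops.block_scale_interleave_sm100(...), etc.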


def fp4_quantize(
    input: torch.Tensor,
    global_scale: Optional[torch.Tensor] = None,
    sf_vec_size: int = 16,
    sf_use_ue8m0: bool = False,
    is_sf_swizzled_layout: bool = True,
    is_sf_8x4_layout: bool = False,
    enable_pdl: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize input tensor to FP4 format.

    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
    with associated scale factors. It supports various input data types and scale factor layouts.

    Args:
        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
        is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False.
        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
            If None, automatically detects based on device capability. Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2
            - Scale factors tensor with shape determined by layout and sf_vec_size

    Raises:
        NotImplementedError: If any of the following features are requested but not implemented:
            - BFloat16 input when BFloat16 is not enabled
            - FP8 input when FP8 is not enabled
            - sf_vec_size other than 16 or 32
    """
    if sf_vec_size != 16 and sf_vec_size != 32:
        raise NotImplementedError("sf_vec_size can only be 16 or 32")

    # For column-major input, we need to transpose the input
    is_column_major = input.stride(-2) == 1
    if is_column_major:
        input = input.transpose(-2, -1)

    assert input.shape[-1] % sf_vec_size == 0
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)
    # Get the SM version of the input device
    major, minor = get_compute_capability(input.device)
    x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
        input,
        global_scale,
        sf_vec_size,
        sf_use_ue8m0,
        is_sf_swizzled_layout,
        is_sf_8x4_layout,
        enable_pdl,
    )
    sf = sf.reshape((-1, input.shape[-1] // sf_vec_size))
    if is_column_major:
        x_q = x_q.transpose(-2, -1)
        sf = sf.transpose(-2, -1)

    return x_q, sf
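

# A minimal usage sketch for fp4_quantize (assumes a GPU with a supported SM and
# a built JIT module; the 256x256 bf16 input is illustrative only):
#
#   x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
#   x_global_sf = (448 * 6) / x.float().abs().nan_to_num().max()
#   x_fp4, x_sf = fp4_quantize(x, x_global_sf, sf_vec_size=16)
#   # x_fp4 packs two E2M1 values per byte -> shape [256, 128]
#   # x_sf holds one scale factor per 16 elements, swizzled by default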


def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
    """Swizzle block scale tensor for FP4 format.

    This function swizzles the block scale tensor to optimize memory access patterns
    for FP4 operations. The output needs to be padded in the M dimension to a multiple of 128.

    Args:
        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8.

    Returns:
        torch.Tensor: Swizzled tensor with the same shape as input.

    Raises:
        AssertionError: If input dtype is not uint8.
    """
    assert unswizzled_sf.dtype == torch.uint8, (
        f"Input dtype must be uint8, got {unswizzled_sf.dtype}"
    )

    major, minor = torch.cuda.get_device_capability()
    device_arch = f"{major * 10 + minor}"

    return get_fp4_quantization_module(device_arch).block_scale_interleave_sm100(
        unswizzled_sf,
    )


# Maintain compatibility with libraries using the old name
nvfp4_block_scale_interleave = block_scale_interleave
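

# Sketch: swizzling a linear-layout scale-factor tensor into the interleaved
# layout consumed by the FP4 GEMM kernels (shape and values are assumptions for
# illustration; rows here are already a multiple of 128):
#
#   sf_linear = torch.randint(0, 256, (128, 8), dtype=torch.uint8, device="cuda")
#   sf_swizzled = block_scale_interleave(sf_linear)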


def e2m1_and_ufp8sf_scale_to_float(
    e2m1_tensor: torch.Tensor,
    ufp8_scale_tensor: torch.Tensor,
    global_scale_tensor: Optional[torch.Tensor] = None,
    sf_vec_size: int = 16,
    ufp8_type: int = 1,
    is_sf_swizzled_layout: bool = True,
) -> torch.Tensor:
    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.

    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
    back to float values using the associated UFP8 scale factors and global scale.

    Args:
        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.

    Returns:
        torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
    """
    major, minor = torch.cuda.get_device_capability()
    device_arch = f"{major * 10 + minor}"
    return get_fp4_quantization_module(
        device_arch
    ).e2m1_and_ufp8sf_scale_to_float_sm100(
        e2m1_tensor,
        ufp8_scale_tensor,
        global_scale_tensor,
        sf_vec_size,
        ufp8_type,
        is_sf_swizzled_layout,
    )
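

# Hedged round-trip sketch pairing fp4_quantize with this dequantizer; the shapes,
# the view(torch.uint8) casts, and the reuse of the same global scale are
# assumptions for illustration, and the result matches the input only up to FP4
# rounding error:
#
#   x = torch.randn(128, 64, dtype=torch.half, device="cuda")
#   gs = (448 * 6) / x.float().abs().nan_to_num().max()
#   x_fp4, x_sf = fp4_quantize(x, gs, sf_vec_size=16)
#   x_deq = e2m1_and_ufp8sf_scale_to_float(
#       x_fp4.view(torch.uint8),
#       x_sf.view(torch.uint8).reshape(-1),
#       gs,
#       sf_vec_size=16,
#       ufp8_type=1,
#       is_sf_swizzled_layout=True,
#   )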


def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
    """
    PyTorch equivalent of trtllm-gen `shuffleMatrixA`.
    """
    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)

    return input_tensor[row_indices.to(input_tensor.device)]


def shuffle_matrix_sf_a(
    input_tensor: torch.Tensor,
    epilogue_tile_m: int,
    num_elts_per_sf: int = 16,
):
    """
    CUDA implementation of trtllm-gen `shuffleMatrixSfA`, with one caveat:
    `shuffleMatrixSfA` expects its input in 128x4 layout, applies the same
    shuffling as `shuffleMatrixA`, and writes the result out in 128x4 layout.
    This function instead expects the input in linear layout, because the
    scale factors in NVFP4 checkpoints are quantized and stored in linear
    layout. This function does not add padding.
    """

    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)

    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]

    # 128x4
    return block_scale_interleave(w_shuffled)
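

# These two helpers are typically applied together when preparing weights for
# the trtllm-gen backend, as nvfp4_quantize(do_shuffle=True) does below; a
# sketch with assumed uint8 inputs:
#
#   w_fp4_shuf = shuffle_matrix_a(w_fp4.view(torch.uint8), epilogue_tile_m=128)
#   w_sf_shuf = shuffle_matrix_sf_a(w_sf.view(torch.uint8), epilogue_tile_m=128)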


class SfLayout(Enum):
    """
    Layout of scale factors for NVFP4.
    """

    layout_128x4 = 0
    layout_8x4 = 1
    layout_linear = 2


def nvfp4_quantize(
    a,
    a_global_sf,
    sfLayout=SfLayout.layout_128x4,
    do_shuffle=False,
    sf_vec_size=16,
    enable_pdl=None,
):
    """
    Quantize input tensor to NVFP4 format.

    Parameters:
        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False.
            Only the TRTLLM backend needs the tensor B scale factors shuffled.
        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
            If None, automatically detects based on device capability. Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2
            - Scale factors tensor with shape determined by layout and sf_vec_size
    """

    if do_shuffle:
        # Weights use the 128x4 layout plus shuffling. This happens at model load
        # time, so performance is not critical here.
        assert sfLayout == SfLayout.layout_128x4
        a_fp4, a_sf = fp4_quantize(
            a.cuda(),
            a_global_sf.cuda(),
            sf_vec_size,
            sf_use_ue8m0=False,
            is_sf_swizzled_layout=False,
            is_sf_8x4_layout=False,
            enable_pdl=enable_pdl,
        )

        epilogue_tile_m = 128
        a_fp4 = shuffle_matrix_a(a_fp4.view(torch.uint8), epilogue_tile_m)
        a_sf = shuffle_matrix_sf_a(a_sf.view(torch.uint8), epilogue_tile_m).reshape(
            a_sf.shape
        )
    else:
        # Activations use the 8x4 layout for SFs (GEMM with small tileN)
        # or the 128x4 layout for SFs (GEMM with large tileN).
        a_fp4, a_sf = fp4_quantize(
            a.cuda(),
            a_global_sf.cuda(),
            sf_vec_size,
            sf_use_ue8m0=False,
            is_sf_swizzled_layout=True,
            is_sf_8x4_layout=sfLayout == SfLayout.layout_8x4,
            enable_pdl=enable_pdl,
        )

    return a_fp4, a_sf
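

# Usage sketch (shapes assumed; requires a GPU with a built FP4 JIT module).
# Activations typically keep a swizzled layout, while weights destined for the
# TRTLLM backend are quantized once at load time with do_shuffle=True:
#
#   a = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
#   a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
#   act_fp4, act_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_8x4)
#   w_fp4, w_sf = nvfp4_quantize(a, a_global_sf, do_shuffle=True)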


def mxfp4_quantize(a):
    """
    Quantize input tensor to MXFP4 format.

    Parameters:
        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
            - Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
    """
    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
    a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
    return a_fp4, a_sf


def mxfp4_dequantize(a_fp4, a_sf):
    """
    Dequantize input tensor from MXFP4 format.

    Parameters:
        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)

    Returns:
        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
    """
    return e2m1_and_ufp8sf_scale_to_float(
        a_fp4.cpu().view(torch.uint8),
        a_sf.cpu().view(torch.uint8).reshape(-1),
        torch.tensor([1.0], device=a_fp4.device),
        32,
        0,
        True,
    )
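

# Hedged MXFP4 round-trip sketch (group size 32, UE8M0 scales); the shape is an
# assumption and the dequantized values match the input only up to FP4 rounding:
#
#   x = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
#   x_fp4, x_sf = mxfp4_quantize(x)
#   x_deq = mxfp4_dequantize(x_fp4, x_sf)  # float32 tensor of shape [64, 128]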


def mxfp4_dequantize_host(
    weight: torch.Tensor,
    scale: torch.Tensor,
    group_size: int = 32,
) -> torch.Tensor:
    """
    Dequantize input tensor from MXFP4 format on host.

    Parameters:
        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
        group_size (int, optional): Group size for dequantization. Defaults to 32.

    Returns:
        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
    """
    major, minor = torch.cuda.get_device_capability()
    device_arch = f"{major * 10 + minor}"
    return get_fp4_quantization_module(device_arch).mxfp4_dequantize_host(
        weight,
        scale,
        group_size,
    )