import functools from types import SimpleNamespace from typing import Optional, Tuple import torch from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec, current_compilation_context from .utils import ( device_support_pdl, register_custom_op, register_fake_op, ) def gen_mxfp8_quantization_sm100_module() -> JitSpec: return gen_jit_spec( "mxfp8_quantization_sm100", [ jit_env.FLASHINFER_CSRC_DIR / "nv_internal/tensorrt_llm/thop/fp8Quantize.cpp", jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/kernels/quantization.cu", jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp", jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp", jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", ], extra_cuda_cflags=current_compilation_context.get_nvcc_flags_list( supported_major_versions=None ) + [ "-DENABLE_BF16", "-DENABLE_FP8", "-DENABLE_FP4", ], extra_cflags=[ "-DENABLE_BF16", "-DENABLE_FP8", "-DENABLE_FP4", ], extra_include_paths=[ jit_env.FLASHINFER_CSRC_DIR / "nv_internal", jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", ], ) @functools.cache def get_mxfp8_quantization_sm100_module(): module = gen_mxfp8_quantization_sm100_module().build_and_load() @register_custom_op( "flashinfer::mxfp8_quantize_sm100", mutates_args=(""), ) def mxfp8_quantize_sm100( input: torch.Tensor, is_sf_swizzled_layout: bool = True, alignment: int = 32, enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize input tensor to MxFP8 format. Args: input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. alignment (int, optional): sfVecSize. Defaults to 32. Note that alignment is not used in the host kernel. enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). If None, automatically detects based on device capability. Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 - Scale factors tensor with shape determined by layout and sf_vec_size """ if input.device.type == "cpu": return module.mxfp8_quantize_host( input, is_sf_swizzled_layout, ) else: if enable_pdl is None: enable_pdl = device_support_pdl(input.device) return module.mxfp8_quantize( input, is_sf_swizzled_layout, alignment, enable_pdl, ) @register_fake_op("flashinfer::mxfp8_quantize_sm100") def _fake_mxfp8_quantize_sm100( input: torch.Tensor, is_sf_swizzled_layout: bool = True, alignment: int = 32, ) -> Tuple[torch.Tensor, torch.Tensor]: m, k = input.shape return ( input.new_empty([m, k], dtype=torch.int64), # FLOAT8_E4M3 input.new_empty([m * k // 32], dtype=torch.int32), # Scale factors ) @register_custom_op( "flashinfer::mxfp8_dequantize_host_sm100", mutates_args=("",), ) def mxfp8_dequantize_host_sm100( input: torch.Tensor, scale_tensor: torch.Tensor, is_sf_swizzled_layout: bool = True, ) -> torch.Tensor: """Dequantize input tensor from MxFP8 format. Args: input (torch.Tensor): Input tensor of shape [M, K] with dtype FLOAT8_E4M3. scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. Returns: torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. """ return module.mxfp8_dequantize_host( input, scale_tensor, is_sf_swizzled_layout, ) @register_fake_op("flashinfer::mxfp8_dequantize_host_sm100") def _fake_mxfp8_dequantize_host_sm100( input: torch.Tensor, scale_tensor: torch.Tensor, is_sf_swizzled_layout: bool = True, ) -> torch.Tensor: return input.new_empty([input.shape[0], input.shape[1]], dtype=torch.float32) # Register the module return SimpleNamespace( mxfp8_quantize_sm100=mxfp8_quantize_sm100, mxfp8_dequantize_host_sm100=mxfp8_dequantize_host_sm100, ) def mxfp8_quantize( input: torch.Tensor, is_sf_swizzled_layout: bool = True, alignment: int = 32, enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize input tensor to MxFP8 format. This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format with associated scale factors. It supports various input data types and scale factor layouts. Args: input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. alignment (int, optional): sfVecSize. Defaults to 32. enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). If None, automatically detects based on device capability. Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 - Scale factors tensor with shape determined by layout and sf_vec_size """ sf_vec_size = 32 assert input.shape[-1] % sf_vec_size == 0 if enable_pdl is None: enable_pdl = device_support_pdl(input.device) x_q, sf = get_mxfp8_quantization_sm100_module().mxfp8_quantize_sm100( input, is_sf_swizzled_layout, alignment, enable_pdl, ) return x_q, sf def mxfp8_dequantize_host( input: torch.Tensor, scale_tensor: torch.Tensor, is_sf_swizzled_layout: bool = True, ) -> torch.Tensor: """Dequantize input tensor from MxFP8 format. This function performs dequantization by converting a packed FP8 tensor in MxFP8 format back to float values using the associated scale factors. Args: input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. Returns: torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. """ return get_mxfp8_quantization_sm100_module().mxfp8_dequantize_host_sm100( input, scale_tensor, is_sf_swizzled_layout, )