# sglang_v0.5.2/flashinfer_0.3.1/tests/test_trtllm_gen_fused_moe.py
"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABC, abstractmethod
from enum import IntEnum
from typing import Dict
import pytest
import torch
from cuda.bindings import runtime
from torch.nn import functional as F
from flashinfer import (
RoutingMethodType,
GatedActType,
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
mxfp8_dequantize_host,
mxfp8_quantize,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
)
from flashinfer.fp4_quantization import block_scale_interleave
from flashinfer.fused_moe import (
WeightLayout,
convert_to_block_layout,
trtllm_fp4_block_scale_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices,
_maybe_get_cached_w3_w1_permute_indices,
)
from flashinfer.utils import calculate_tile_tokens_dim
def check_cuda(err):
"""Unified CUDA error checking function used throughout the file."""
if err != runtime.cudaError_t.cudaSuccess:
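        # Note: cuda.bindings.runtime calls return (err, value) tuples, hence the [1]
        # indexing on the name/string results below.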
error_name = runtime.cudaGetErrorName(err)
error_string = runtime.cudaGetErrorString(err)
raise RuntimeError(f"CUDA error: {error_name[1]}: {error_string[1]}")
class CUDAGraphMoE:
"""
Simple CUDA Graph wrapper for MoE operations.
    The graph holds references to the captured input/output tensors; launch() copies new
    data into the captured input and returns the output tensor, which graph replay updates in place.
Three core methods: capture(), launch(), cleanup()
Usage:
cuda_graph = CUDAGraphMoE(moe_impl, static_data, **config)
cuda_graph.capture(hidden_states_sample, expert_logits=logits, routing_bias=bias)
output = cuda_graph.launch(new_hidden_states) # Repeat as needed
cuda_graph.cleanup()
"""
def __init__(self, moe_impl, static_data, **config):
self.moe_impl = moe_impl
self.static_data = static_data
self.config = config
self.graph = None
self.graph_exec = None
self.stream = None
self.input_tensor = None
self.output_tensor = None
self.is_captured = False
def capture(self, hidden_states_sample, **runtime_args):
"""Capture CUDA graph with the given sample input."""
if self.is_captured:
raise RuntimeError(
"Graph already captured. Call cleanup() first to re-capture."
)
if not isinstance(self.moe_impl, FP4Moe):
raise NotImplementedError(
f"CUDA graph capture not yet implemented for {type(self.moe_impl)}"
)
# Create stream
err, self.stream = runtime.cudaStreamCreate()
check_cuda(err)
# Get the raw stream pointer for PyTorch
stream_ptr = int(self.stream)
torch_stream = torch.cuda.ExternalStream(stream_ptr)
# Store input tensor reference (will be updated in place during launch)
self.input_tensor = hidden_states_sample.clone()
# Warmup
with torch.cuda.stream(torch_stream):
for _ in range(1):
self._run_moe_computation(runtime_args)
# Synchronize our stream after warmup
err = runtime.cudaStreamSynchronize(self.stream)[0]
check_cuda(err)
# Begin capture
err, self.graph = runtime.cudaGraphCreate(0)
check_cuda(err)
err = runtime.cudaStreamBeginCapture(
self.stream, runtime.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal
)[0]
check_cuda(err)
try:
# Capture computation on our stream
with torch.cuda.stream(torch_stream):
self.output_tensor = self._run_moe_computation(runtime_args)
err, self.graph = runtime.cudaStreamEndCapture(self.stream)
check_cuda(err)
err, self.graph_exec = runtime.cudaGraphInstantiate(self.graph, 0)
check_cuda(err)
self.is_captured = True
except Exception as e:
self.cleanup()
raise RuntimeError(f"CUDA graph capture failed: {e}") from e
def launch(self, hidden_states_new):
"""Launch captured CUDA graph with new input."""
if not self.is_captured:
raise RuntimeError("Graph not captured. Call capture() first.")
# Update input tensor in place
self.input_tensor.copy_(hidden_states_new)
# Launch graph
err = runtime.cudaGraphLaunch(self.graph_exec, self.stream)[0]
check_cuda(err)
err = runtime.cudaStreamSynchronize(self.stream)[0]
check_cuda(err)
# Return output tensor (automatically updated by graph execution)
return self.output_tensor
def cleanup(self):
"""Clean up all CUDA graph resources."""
if self.graph_exec is not None:
err = runtime.cudaGraphExecDestroy(self.graph_exec)[0]
check_cuda(err)
self.graph_exec = None
if self.graph is not None:
err = runtime.cudaGraphDestroy(self.graph)[0]
check_cuda(err)
self.graph = None
if self.stream is not None:
err = runtime.cudaStreamDestroy(self.stream)[0]
check_cuda(err)
self.stream = None
self.input_tensor = None
self.output_tensor = None
self.is_captured = False
def _run_moe_computation(self, runtime_args):
"""Run the MoE computation."""
input_quantized = self.moe_impl.quantize_inputs(
self.input_tensor,
self.config["hidden_states_scale_global"],
is_swizzling=False,
)
output = trtllm_fp4_block_scale_moe(
routing_logits=runtime_args["expert_logits"],
routing_bias=runtime_args["routing_bias"],
hidden_states=input_quantized["hidden_states"],
hidden_states_scale=input_quantized["hidden_states_scale"],
gemm1_weights=self.static_data["gemm1_weights_fp4_shuffled"],
gemm1_weights_scale=self.static_data["gemm1_scales_fp4_shuffled"],
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=self.static_data["gemm2_weights_fp4_shuffled"],
gemm2_weights_scale=self.static_data["gemm2_scales_fp4_shuffled"],
gemm2_bias=None,
output1_scale_scalar=self.static_data["scale_c_fc1"],
output1_scale_gate_scalar=self.static_data["scale_gate_fc1"],
output2_scale_scalar=self.static_data["scale_c_fc2"],
num_experts=self.config["num_experts"],
top_k=self.config["top_k"],
n_group=self.config["n_groups"],
topk_group=self.config["top_k_groups"],
intermediate_size=self.config["intermediate_size"],
local_expert_offset=0,
local_num_experts=self.config["num_experts"],
routed_scaling_factor=self.config["routed_scaling"],
tile_tokens_dim=self.config["tile_tokens_dim"],
routing_method_type=self.config["routing_method_type"],
gated_act_type=self.config["gated_act_type"],
do_finalize=True,
)
        return output  # Tuple of kernel outputs; call_moe indexes the finalized tensor
class QuantMode(IntEnum):
"""Supported quantization modes for MoE testing."""
FP4_NVFP4_NVFP4 = 1
FP4_MXFP4_MXFP8 = 2
FP4_MXFP4_Bf16 = 3
FP8_BLOCK_SCALE = 4
FP8_PER_TENSOR = 5
# ====================================================================================
# Abstract Base Class for MoE Implementations
# ====================================================================================
class Moe(ABC):
"""Abstract base class for MoE implementations."""
def __init__(self):
self.name = self.__class__.__name__
@abstractmethod
def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
"""Quantize static weights and compute global scale factors (done offline)."""
pass
@abstractmethod
def quantize_inputs(self, hidden_states, hidden_states_scale_global):
"""Quantize dynamic inputs/hidden states using pre-computed global scale (done at runtime)."""
pass
@abstractmethod
def prepare_static_weights_for_kernel(
self,
args_dequant,
args,
gemm1_weights_orig,
gemm2_weights_orig,
hidden_size,
intermediate_size,
num_experts,
weight_processing,
):
"""
Prepare quantized weights for kernel (done offline with weights).
Args:
args_dequant: Contains c_global_sf and other dequantization parameters
args: Contains already quantized weights (gemm1_weights, gemm2_weights) and scales
gemm1_weights_orig: Original unquantized FC1 weights (used by FP4 for re-quantization)
gemm2_weights_orig: Original unquantized FC2 weights (used by FP4 for re-quantization)
Note:
- FP4 implementations use both original weights (for linear layout quantization)
and args.gemm*_weights (for swizzled layout)
- FP8 implementations typically only use args.gemm*_weights (already quantized)
"""
pass
@abstractmethod
def call_moe(
self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
):
"""Call MoE with runtime input quantization + kernel execution (done at runtime)."""
pass
@abstractmethod
def compute_reference(self, args):
"""Compute reference output using dequantized operations."""
pass
def compute_production(self, args_dequant, args, **kwargs):
"""Unified actual computation that delegates to implementation-specific methods."""
return _compute_moe_actual_unified(self, args_dequant, args, **kwargs)
@abstractmethod
def get_tolerances(self):
"""Get accuracy tolerances for this quantization mode."""
pass
def __str__(self):
return self.name
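# Illustrative flow (a rough sketch of how the test below exercises these hooks; the
# real wiring lives in test_moe_quantization_classes):
#   impl = get_moe_impl(QuantMode.FP8_PER_TENSOR)
#   weights_data = impl.quantize_weights(gemm1_weights, gemm2_weights, hidden_states)
#   inputs_data = impl.quantize_inputs(
#       hidden_states, weights_data["hidden_states_scale_global"]
#   )
#   reference, args_dequant = impl.compute_reference(args)     # dequantized reference
#   actual = impl.compute_production(args_dequant, args, ...)  # kernel execution path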
# ====================================================================================
# FP4 Quantization Implementation
# ====================================================================================
class FP4Moe(Moe):
"""
FP4 NvFP4 / MxFP4 MoE implementation with block scaling.
Args:
is_mxfp4: Whether to use MxFP4 or NvFP4 weight quantization
If True, the activation is quantized to MxFP8, else the activation is quantized to NvFP4
"""
def __init__(self, quant_mode: QuantMode):
super().__init__()
self.quant_mode = quant_mode
self.is_mxfp4 = (
quant_mode == QuantMode.FP4_MXFP4_MXFP8
or quant_mode == QuantMode.FP4_MXFP4_Bf16
)
self.sf_vec_size = 32 if self.is_mxfp4 else 16
def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
"""Quantize weights to FP4 format and compute global scale factors."""
num_experts = gemm1_weights.shape[0]
# Compute global scale factor for hidden states (offline calibration)
if self.quant_mode == QuantMode.FP4_NVFP4_NVFP4:
# nvfp4 hidden states
hidden_states_scale_global = calculate_fp4_global_scale_factor(
hidden_states_sample,
False,
)
else:
# mxfp8 / bf16 hidden states
hidden_states_scale_global = 1.0
# Quantize the weights for FC1
gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
quant_fp4_batches(gemm1_weights, num_experts, self.is_mxfp4, True)
)
# Quantize the weights for FC2
gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
quant_fp4_batches(gemm2_weights, num_experts, self.is_mxfp4, True)
)
return {
"hidden_states_scale_global": hidden_states_scale_global,
"gemm1_weights": gemm1_weights_fp4_bytes,
"gemm1_scales": gemm1_scales_fp4_bytes,
"gemm1_scales_global": gemm1_scales_global,
"gemm2_weights": gemm2_weights_fp4_bytes,
"gemm2_scales": gemm2_scales_fp4_bytes,
"gemm2_scales_global": gemm2_scales_global,
}
    def quantize_inputs(
        self, hidden_states, hidden_states_scale_global, is_swizzling=True
    ):
        """Quantize hidden states to MxFP8, NvFP4, or bf16 depending on quant_mode."""
        if self.quant_mode == QuantMode.FP4_MXFP4_MXFP8:
            # Quantize hidden states to MxFP8 format.
hidden_states_quant, hidden_states_scale = mxfp8_quantize(
hidden_states, is_swizzling
)
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
)
return {
"hidden_states": hidden_states_quant,
"hidden_states_scale": hidden_states_scale,
}
        elif self.quant_mode == QuantMode.FP4_NVFP4_NVFP4:
            # Quantize hidden states to NvFP4 format using the pre-computed global scale.
(
hidden_states_fp4_bytes,
hidden_states_scale_fp4_bytes,
_,
) = quant_fp4(
hidden_states, hidden_states_scale_global, False, is_swizzling
)
hidden_states_scale_fp4_bytes = hidden_states_scale_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(*hidden_states.shape[:-1], -1)
return {
"hidden_states": hidden_states_fp4_bytes,
"hidden_states_scale": hidden_states_scale_fp4_bytes,
}
else: # bf16
return {
"hidden_states": hidden_states.to(torch.bfloat16),
"hidden_states_scale": None,
}
def prepare_static_weights_for_kernel(
self,
args_dequant,
args,
gemm1_weights_orig,
gemm2_weights_orig,
hidden_size,
intermediate_size,
num_experts,
weight_processing,
):
"""Prepare quantized weights for kernel (done offline with weights)."""
use_ue8m0 = self.is_mxfp4
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
# Quantize weights with linear layout for kernels
_, gemm1_scales_linear_fp4_bytes, _ = quant_fp4_batches(
gemm1_weights_orig, num_experts, use_ue8m0, False
)
_, gemm2_scales_linear_fp4_bytes, _ = quant_fp4_batches(
gemm2_weights_orig, num_experts, use_ue8m0, False
)
# Convert quantized weights to proper formats
gemm1_weights_fp4 = args.gemm1_weights.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2
) # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size
) # fp8 scaling factors
gemm2_weights_fp4 = args.gemm2_weights.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 2
) # packed fp4
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, hidden_size, intermediate_size // self.sf_vec_size
) # fp8 scaling factors
        # Caching the permute index calculation speeds up weight preprocessing
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
# Calculate the permute indices for the following:
# 1. Reorder rows of W1 and scales for fused gated activation
# 2. Shuffle weights and scaling factors for transposed mma output
# for both w3_w1 and w2 weights and scale factors
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_fp4_shuffled.append(
gemm1_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
.contiguous()
)
permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_fp4_shuffled.append(
block_scale_interleave(
gemm1_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
]
.contiguous()
)
)
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_fp4_shuffled.append(
gemm2_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
.contiguous()
)
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_fp4_shuffled.append(
block_scale_interleave(
gemm2_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
]
.contiguous()
)
)
# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
gemm1_scales_fp4_shuffled = (
torch.stack(gemm1_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(
num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size
)
)
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
gemm2_scales_fp4_shuffled = (
torch.stack(gemm2_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, hidden_size, intermediate_size // self.sf_vec_size)
)
# Calculate scaling factors that depend on weights
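        # Roughly: the quantized GEMM output carries a factor of
        # (hidden_states_scale_global * gemm1_scales_global), so both reciprocals undo
        # it, while c_global_sf re-scales the FC1 result into the quantized activation
        # domain consumed by FC2 (and is divided back out in scale_c_fc2).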
scale_c_fc1 = (
args_dequant.c_global_sf
* (1.0 / args.gemm1_scales_global)
* (1.0 / args.hidden_states_scale_global)
)
scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * (
1.0 / args.hidden_states_scale_global
)
scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * (
1.0 / args.gemm2_scales_global
)
return {
"gemm1_weights_fp4_shuffled": gemm1_weights_fp4_shuffled,
"gemm1_scales_fp4_shuffled": gemm1_scales_fp4_shuffled,
"gemm2_weights_fp4_shuffled": gemm2_weights_fp4_shuffled,
"gemm2_scales_fp4_shuffled": gemm2_scales_fp4_shuffled,
"scale_c_fc1": scale_c_fc1,
"scale_gate_fc1": scale_gate_fc1,
"scale_c_fc2": scale_c_fc2,
}
def call_moe(
self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
):
"""Call MoE using CUDA graph for maximum performance (create, capture, launch)."""
# Extract runtime arguments
expert_logits = kwargs["expert_logits"]
routing_bias = kwargs["routing_bias"]
num_experts = kwargs["num_experts"]
top_k = kwargs["top_k"]
n_groups = kwargs["n_groups"]
top_k_groups = kwargs["top_k_groups"]
intermediate_size = kwargs["intermediate_size"]
routed_scaling = kwargs["routed_scaling"]
gated_act_type = kwargs["gated_act_type"]
routing_method_type = kwargs["routing_method_type"]
tile_tokens_dim = kwargs["tile_tokens_dim"]
# Create CUDA graph configuration
config = {
"hidden_states_scale_global": hidden_states_scale_global,
"num_experts": num_experts,
"top_k": top_k,
"n_groups": n_groups,
"top_k_groups": top_k_groups,
"intermediate_size": intermediate_size,
"routed_scaling": routed_scaling,
"tile_tokens_dim": tile_tokens_dim,
"gated_act_type": gated_act_type,
"routing_method_type": routing_method_type,
}
runtime_args = {
"expert_logits": expert_logits,
"routing_bias": routing_bias,
}
# Create, capture and launch CUDA graph in one shot
cuda_graph = CUDAGraphMoE(self, static_data, **config)
try:
cuda_graph.capture(hidden_states_orig, **runtime_args)
output = cuda_graph.launch(hidden_states_orig)
return output[0].to(torch.float)
finally:
cuda_graph.cleanup()
def compute_reference(self, args):
return run_moe_reference_fp4(args, self.quant_mode)
def get_tolerances(self):
"""Get FP4-specific accuracy tolerances."""
return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}
# ====================================================================================
# FP8 Block Scale Quantization Implementation
# ====================================================================================
class FP8BlockScaleMoe(Moe):
"""FP8 MoE implementation with block scaling (DeepSeek style)."""
def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
"""Quantize weights to FP8 with block scaling."""
num_experts = gemm1_weights.shape[0]
intermediate_size = gemm1_weights.shape[1] // 2
hidden_size = gemm1_weights.shape[
2
] # [num_experts, 2*intermediate_size, hidden_size]
# Quantize weights to FP8
gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn)
gemm1_scales = 2 * torch.rand(
(num_experts, 2 * intermediate_size // 128, hidden_size // 128),
device="cuda",
).to(torch.float)
gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn)
gemm2_scales = 2 * torch.rand(
(num_experts, hidden_size // 128, intermediate_size // 128), device="cuda"
).to(torch.float)
return {
"hidden_states_scale_global": None, # Block scales computed at runtime
"gemm1_weights": gemm1_weights_fp8,
"gemm1_scales": gemm1_scales,
"gemm1_scales_global": None,
"gemm2_weights": gemm2_weights_fp8,
"gemm2_scales": gemm2_scales,
"gemm2_scales_global": None,
}
def quantize_inputs(self, hidden_states, hidden_states_scale_global):
"""For FP8 block scaling, no pre-quantization - everything happens at runtime."""
return {
"hidden_states": hidden_states, # Keep original
"hidden_states_scale": None, # No pre-computed scales
}
def prepare_static_weights_for_kernel(
self,
args_dequant,
args,
gemm1_weights_orig,
gemm2_weights_orig,
hidden_size,
intermediate_size,
num_experts,
weight_processing,
):
"""Prepare quantized weights for kernel (done offline with weights)."""
# Use shuffled weights with BlockMajorK layout for better performance
use_shuffled_weight = weight_processing["use_shuffled_weight"]
weight_layout = weight_processing["layout"]
if use_shuffled_weight:
# FIXME: this depends on the kernel internals
epilogue_tile_m = 64
gemm1_weights_fp8_shuffled = []
gemm2_weights_fp8_shuffled = []
for i in range(num_experts):
tmp_weights1 = shuffle_matrix_a(
args.gemm1_weights[i].view(torch.uint8), epilogue_tile_m
)
tmp_weights2 = shuffle_matrix_a(
args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m
)
if weight_layout == WeightLayout.BlockMajorK:
block_k = 128
tmp_weights1 = convert_to_block_layout(tmp_weights1, block_k)
tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k)
gemm1_weights_fp8_shuffled.append(tmp_weights1)
gemm2_weights_fp8_shuffled.append(tmp_weights2)
kernel_gemm1_weights = torch.stack(gemm1_weights_fp8_shuffled).view(
torch.float8_e4m3fn
)
kernel_gemm2_weights = torch.stack(gemm2_weights_fp8_shuffled).view(
torch.float8_e4m3fn
)
else:
kernel_gemm1_weights = args.gemm1_weights
kernel_gemm2_weights = args.gemm2_weights
return {
"gemm1_weights": kernel_gemm1_weights,
"gemm1_scales": args.gemm1_scales,
"gemm2_weights": kernel_gemm2_weights,
"gemm2_scales": args.gemm2_scales,
"use_shuffled_weight": use_shuffled_weight,
"weight_layout": weight_layout,
}
def call_moe(
self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
):
"""Call MoE with runtime block scale generation + kernel execution."""
expert_logits = kwargs["expert_logits"]
routing_bias = kwargs["routing_bias"]
num_experts = kwargs["num_experts"]
num_tokens = kwargs["num_tokens"]
hidden_size = kwargs["hidden_size"]
top_k = kwargs["top_k"]
n_groups = kwargs["n_groups"]
top_k_groups = kwargs["top_k_groups"]
intermediate_size = kwargs["intermediate_size"]
routed_scaling = kwargs["routed_scaling"]
routing_method_type = kwargs["routing_method_type"]
tile_tokens_dim = kwargs["tile_tokens_dim"]
enable_pdl = kwargs.get("enable_pdl")
# Generate block scales and quantize hidden states at runtime
hidden_states_fp8 = hidden_states_orig.to(torch.float8_e4m3fn)
# Use deterministic scales for testing consistency
hidden_states_scale = 2.0 * torch.ones(
(hidden_size // 128, num_tokens), device="cuda", dtype=torch.float
)
output = trtllm_fp8_block_scale_moe(
expert_logits,
routing_bias,
hidden_states_fp8,
hidden_states_scale,
static_data["gemm1_weights"],
static_data["gemm1_scales"],
static_data["gemm2_weights"],
static_data["gemm2_scales"],
num_experts,
top_k,
n_groups,
top_k_groups,
intermediate_size,
0,
num_experts,
routed_scaling,
tile_tokens_dim,
routing_method_type,
use_shuffled_weight=static_data["use_shuffled_weight"],
weight_layout=static_data["weight_layout"],
enable_pdl=enable_pdl,
)
return output.to(torch.float)
def compute_reference(self, args):
"""FP8 block-scale reference implementation."""
return run_moe_reference_dsfp8(args)
def get_tolerances(self):
"""Get FP8 block-scale accuracy tolerances."""
return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}
# ====================================================================================
# FP8 Per-Tensor Quantization Implementation
# ====================================================================================
class FP8PerTensorMoe(Moe):
"""FP8 MoE implementation with per-tensor scaling (Llama4 style)."""
def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
"""Quantize weights to FP8 per-tensor and compute global scale factors."""
# Compute global scale factor for hidden states (offline calibration)
hidden_states_global_scale = calculate_fp8_global_scale_factor(
hidden_states_sample
)
# Quantize to FP8 per-tensor
gemm1_weights_quant, gemm1_global_scales = quant_fp8_per_tensor_batches(
gemm1_weights
)
gemm2_weights_quant, gemm2_global_scales = quant_fp8_per_tensor_batches(
gemm2_weights
)
return {
"hidden_states_scale_global": hidden_states_global_scale,
"gemm1_weights": gemm1_weights_quant,
"gemm1_scales": None,
"gemm1_scales_global": gemm1_global_scales,
"gemm2_weights": gemm2_weights_quant,
"gemm2_scales": None,
"gemm2_scales_global": gemm2_global_scales,
}
def quantize_inputs(self, hidden_states, hidden_states_scale_global):
"""Quantize hidden states to FP8 per-tensor using pre-computed global scale."""
# Quantize to FP8 per-tensor using pre-computed global scale factor
hidden_states_quant, _ = quant_fp8_per_tensor(
hidden_states, hidden_states_scale_global
)
return {
"hidden_states": hidden_states_quant,
"hidden_states_scale": None,
}
def prepare_static_weights_for_kernel(
self,
args_dequant,
args,
gemm1_weights_orig,
gemm2_weights_orig,
hidden_size,
intermediate_size,
num_experts,
weight_processing,
):
"""Prepare quantized weights for kernel (done offline with weights)."""
# FIXME: this depends on the kernel internals
epilogue_tile_m = 128
# Reorder rows of W1 for fused gated activation
gemm1_weights_fp8_interleaved = []
for i in range(num_experts):
gemm1_weights_fp8_interleaved.append(
reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone())
)
# Stack weights and scales for all experts
gemm1_weights_fp8_interleaved = torch.stack(
gemm1_weights_fp8_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp8_shuffled = []
gemm2_weights_fp8_shuffled = []
for i in range(num_experts):
gemm1_weights_fp8_shuffled.append(
shuffle_matrix_a(
gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m
)
)
gemm2_weights_fp8_shuffled.append(
shuffle_matrix_a(
args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m
)
)
# Stack weights for all experts
gemm1_weights_fp8_shuffled = torch.stack(gemm1_weights_fp8_shuffled).view(
torch.float8_e4m3fn
)
gemm2_weights_fp8_shuffled = torch.stack(gemm2_weights_fp8_shuffled).view(
torch.float8_e4m3fn
)
# Calculate scaling factors that depend on weights
scale_c_fc1 = (
args_dequant.c_global_sf
* (1.0 / args.gemm1_scales_global)
* (1.0 / args.hidden_states_scale_global)
)
scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * (
1.0 / args.hidden_states_scale_global
)
scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * (
1.0 / args.gemm2_scales_global
)
return {
"gemm1_weights": gemm1_weights_fp8_shuffled,
"gemm2_weights": gemm2_weights_fp8_shuffled,
"scale_c_fc1": scale_c_fc1,
"scale_gate_fc1": scale_gate_fc1,
"scale_c_fc2": scale_c_fc2,
}
def call_moe(
self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
):
"""Call MoE with runtime input quantization + kernel execution (done at runtime)."""
expert_logits = kwargs["expert_logits"]
routing_bias = kwargs["routing_bias"]
num_experts = kwargs["num_experts"]
top_k = kwargs["top_k"]
n_groups = kwargs["n_groups"]
top_k_groups = kwargs["top_k_groups"]
intermediate_size = kwargs["intermediate_size"]
routed_scaling = kwargs["routed_scaling"]
routing_method_type = kwargs["routing_method_type"]
tile_tokens_dim = kwargs["tile_tokens_dim"]
# Quantize to FP8 per-tensor using pre-computed global scale factor
hidden_states_fp8, _ = quant_fp8_per_tensor(
hidden_states_orig, hidden_states_scale_global
)
output = trtllm_fp8_per_tensor_scale_moe(
(
expert_logits.to(torch.bfloat16)
if routing_method_type == RoutingMethodType.Llama4
else expert_logits
),
routing_bias,
hidden_states_fp8,
static_data["gemm1_weights"],
static_data["scale_c_fc1"],
static_data["scale_gate_fc1"],
static_data["gemm2_weights"],
static_data["scale_c_fc2"],
num_experts,
top_k,
n_groups,
top_k_groups,
intermediate_size,
0,
num_experts,
routed_scaling,
routing_method_type
== RoutingMethodType.Llama4, # Use_routing_scales_on_input
tile_tokens_dim,
routing_method_type,
)
return output.to(torch.float)
def compute_reference(self, args):
"""FP8 per-tensor reference implementation."""
return run_moe_reference_per_tensor_scale_fp8(args)
def get_tolerances(self):
"""Get FP8 per-tensor accuracy tolerances."""
return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}
# ====================================================================================
# Quantizer Factory
# ====================================================================================
def get_moe_impl(quant_mode: QuantMode):
"""Factory function to get the appropriate MoE implementation."""
if quant_mode == QuantMode.FP8_BLOCK_SCALE:
return FP8BlockScaleMoe()
elif quant_mode == QuantMode.FP8_PER_TENSOR:
return FP8PerTensorMoe()
else:
return FP4Moe(quant_mode)
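# Example (illustrative): get_moe_impl(QuantMode.FP8_PER_TENSOR) returns an
# FP8PerTensorMoe, get_moe_impl(QuantMode.FP8_BLOCK_SCALE) an FP8BlockScaleMoe, and any
# FP4_* mode an FP4Moe carrying that quant_mode.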
class moe_args:
"""Arguments container for MoE operations."""
def __init__(
self,
num_tokens,
num_experts,
hidden_size,
intermediate_size,
top_k,
padding,
hidden_states,
hidden_states_scale,
hidden_states_scale_global,
expert_logits,
gemm1_weights,
gemm1_scales,
gemm1_scales_global,
gemm2_weights,
gemm2_scales,
gemm2_scales_global,
permute_info,
use_routing_scales_on_input,
gated_act_type,
):
self.num_tokens = num_tokens
self.num_experts = num_experts
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.top_k = top_k
self.padding = padding
self.hidden_states = hidden_states
self.hidden_states_scale = hidden_states_scale
self.hidden_states_scale_global = hidden_states_scale_global
self.expert_logits = expert_logits
self.gemm1_weights = gemm1_weights
self.gemm1_scales = gemm1_scales
self.gemm1_scales_global = gemm1_scales_global
self.gemm2_weights = gemm2_weights
self.gemm2_scales = gemm2_scales
self.gemm2_scales_global = gemm2_scales_global
self.permute_info = permute_info
self.use_routing_scales_on_input = use_routing_scales_on_input
self.gated_act_type = gated_act_type
class moe_args_dequant:
"""Arguments container for dequantized MoE operations."""
def __init__(
self,
num_tokens,
num_experts,
hidden_size,
intermediate_size,
top_k,
padding,
hidden_states,
expert_logits,
gemm1_weights,
gemm2_weights,
permute_info,
use_routing_scales_on_input,
gated_act_type,
):
self.num_tokens = num_tokens
self.num_experts = num_experts
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.top_k = top_k
self.padding = padding
self.hidden_states = hidden_states
self.expert_logits = expert_logits
self.gemm1_weights = gemm1_weights
self.gemm2_weights = gemm2_weights
self.permute_info = permute_info
self.use_routing_scales_on_input = use_routing_scales_on_input
self.gated_act_type = gated_act_type
def routing_reference(expertLogits, topK, padding):
"""Reference routing implementation for permutation calculation."""
originalDevice = expertLogits.device
expertLogits = expertLogits.cpu()
numTokens, numExperts = expertLogits.shape
assert topK <= numExperts
numTokensPerExpert = torch.zeros(numExperts, dtype=torch.int64)
expandedTokenIdxToExpert = -torch.ones(numTokens * topK, dtype=torch.int64)
expandedTokenIdxToIdxInExpert = -torch.ones(numTokens * topK, dtype=torch.int64)
topKLogits, topKIndices = torch.topk(expertLogits, topK, dim=1)
for tokenIdx in range(numTokens):
for k in range(topK):
expandedIdx = tokenIdx * topK + k
expertIndex = topKIndices[tokenIdx, k]
expandedTokenIdxToExpert[expandedIdx] = expertIndex
expandedTokenIdxToIdxInExpert[expandedIdx] = numTokensPerExpert[expertIndex]
numTokensPerExpert[expertIndex] += 1
paddedTokensPerExpertPrefixSum = torch.zeros(numExperts + 1, dtype=torch.int64)
for ii in range(numExperts):
def divUpMul(a, b):
return (a + b - 1) // b * b
paddedTokensPerExpertPrefixSum[ii + 1] = paddedTokensPerExpertPrefixSum[
ii
] + divUpMul(numTokensPerExpert[ii], padding)
permutedBufferSize = paddedTokensPerExpertPrefixSum[numExperts]
expandedTokenIdxToPermutedIdx = -torch.ones(numTokens * topK, dtype=torch.int64)
permutedIdxToExpandedIdx = -torch.ones(permutedBufferSize, dtype=torch.int64)
permutedIdxToTokenIdx = -torch.ones(permutedBufferSize, dtype=torch.int64)
for tokenIdx in range(numTokens):
for k in range(topK):
expandedIdx = tokenIdx * topK + k
expert = expandedTokenIdxToExpert[expandedIdx]
offsetWithinExpert = expandedTokenIdxToIdxInExpert[expandedIdx]
offsetForExpert = paddedTokensPerExpertPrefixSum[expert]
permutedIdx = offsetForExpert + offsetWithinExpert
expandedTokenIdxToPermutedIdx[expandedIdx] = permutedIdx
permutedIdxToExpandedIdx[permutedIdx] = expandedIdx
permutedIdxToTokenIdx[permutedIdx] = tokenIdx
return {
"paddedTokensPerExpertPrefixSum": paddedTokensPerExpertPrefixSum.to(
originalDevice
),
"permutedBufferSize": permutedBufferSize.item(),
"expandedTokenIdxToPermutedIdx": expandedTokenIdxToPermutedIdx.to(
originalDevice
),
"permutedIdxToExpandedIdx": permutedIdxToExpandedIdx.to(originalDevice),
"numTokensPerExpert": numTokensPerExpert.to(originalDevice),
"expandedTokenIdxToExpert": expandedTokenIdxToExpert.to(originalDevice),
"topKLogits": topKLogits.to(originalDevice),
"permutedIdxToTokenIdx": permutedIdxToTokenIdx.to(originalDevice),
"topKIndices": topKIndices.to(originalDevice),
}
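# Worked example (illustrative): with topK=2, padding=8, 4 experts and
# numTokensPerExpert = [2, 1, 1, 0], each expert's slot count is rounded up to a
# multiple of 8, giving paddedTokensPerExpertPrefixSum = [0, 8, 16, 24, 24] and
# permutedBufferSize = 24; expert 0's tokens land at permuted indices 0 and 1, expert
# 1's token at 8, expert 2's at 16, and unused padded slots keep -1 in
# permutedIdxToTokenIdx.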
def noaux_tc_ref(logits, bias, n_group, topk_group, top_k, routed_scaling_factor):
"""DeepSeek-style no-aux routing reference implementation."""
scores = F.sigmoid(logits)
scores_with_bias = scores + bias
if n_group > 1:
scores_shape = list(scores_with_bias.shape)
group_scores = torch.sum(
torch.topk(
scores_with_bias.view(
scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]
),
k=2,
dim=-1,
largest=True,
sorted=True,
)[0],
dim=-1,
)
_, group_idx = torch.topk(
group_scores, k=topk_group, dim=-1, largest=True, sorted=True
)
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(-1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(scores_shape[:-1] + [n_group, scores_shape[-1] // n_group])
.reshape(scores_shape)
)
scores_with_bias = scores_with_bias * score_mask
_, topk_idx = torch.topk(
scores_with_bias, k=top_k, dim=-1, largest=True, sorted=True
)
new_mask = torch.zeros_like(scores)
new_mask.scatter_(-1, topk_idx, 1)
scores = scores * new_mask
score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20
scores = scores / score_sum * routed_scaling_factor
return scores
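# For reference, the DSv3 routing config exercised below uses 256 experts split into
# n_group=8 groups of 32, topk_group=4 selected groups, top_k=8 experts, and
# routed_scaling_factor=2.5.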
def routing_reference_no_aux(
expert_logits,
routing_bias,
top_k,
n_groups,
top_k_groups,
routed_scaling,
padding,
use_routing_scales_on_input=False,
):
"""Tiered TopK routing used by DeepSeek."""
routing_logits = expert_logits.to(dtype=torch.float, device="cuda")
if use_routing_scales_on_input:
# if using routing scales on input, topK == 1 and the score is a plain sigmoid
scores = F.sigmoid(routing_logits)
else:
scores = noaux_tc_ref(
routing_logits, routing_bias, n_groups, top_k_groups, top_k, routed_scaling
)
permute_info = routing_reference(scores, top_k, padding)
return permute_info, scores
def routing_reference_renormalize(expert_logits, top_k, num_experts, padding):
"""TopK -> Softmax routing reference."""
topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1)
topk_values = torch.nn.functional.softmax(topk_values.float(), dim=-1)
new_mask = torch.zeros_like(expert_logits)
new_mask.scatter_(-1, topk_idx, 1)
scores = expert_logits * new_mask
for i in range(topk_idx.shape[0]):
for j in range(topk_idx.shape[1]):
scores[i, topk_idx[i, j]] = topk_values[i, j]
permute_info = routing_reference(scores, top_k, padding)
return permute_info, scores
def routing_reference_renormalize_naive(expert_logits, top_k, num_experts, padding):
"""Softmax->TopK -> Normalize routing reference."""
norm_topk_prob = True
scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1)
topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1)
if norm_topk_prob: # only diff with mixtral sparse moe block!
topk_values /= topk_values.sum(dim=-1, keepdim=True)
topk_values = topk_values.to(expert_logits.dtype)
scores = scores.to(expert_logits.dtype)
new_mask = torch.zeros_like(expert_logits)
new_mask.scatter_(-1, topk_idx, 1)
scores = expert_logits * new_mask
for i in range(topk_idx.shape[0]):
for j in range(topk_idx.shape[1]):
scores[i, topk_idx[i, j]] = topk_values[i, j]
permute_info = routing_reference(scores, top_k, padding)
return permute_info, scores
def routing_reference_topk(expert_logits, top_k, num_experts, padding):
"""TopK only (no softmax) routing reference."""
topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1)
new_mask = torch.zeros_like(expert_logits)
new_mask.scatter_(-1, topk_idx, 1)
scores = expert_logits * new_mask
for i in range(topk_idx.shape[0]):
for j in range(topk_idx.shape[1]):
scores[i, topk_idx[i, j]] = topk_values[i, j]
permute_info = routing_reference(scores, top_k, padding)
return permute_info, scores
def check_accuracy(a, b, atol, rtol, percent):
"""Unified accuracy checking function with detailed error reporting."""
if torch.any(torch.isnan(a)):
raise Exception("NaN in reference output")
if torch.any(torch.isnan(b)):
raise Exception("NaN in actual output")
if torch.any(torch.isinf(a)):
raise Exception("Inf in reference output")
if torch.any(torch.isinf(b)):
raise Exception("Inf in actual output")
assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"
left = torch.abs(a - b)
right = atol + rtol * torch.abs(b)
count = torch.sum(left > right)
mismatch_percent = count / a.numel()
if mismatch_percent > 1 - percent:
raise Exception(
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
f"(threshold: {1 - percent:.4f})"
)
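# Example (illustrative): with the tolerances used below (atol=0.1, rtol=0.85,
# percent=0.925), element i is a mismatch when |a[i] - b[i]| > 0.1 + 0.85 * |b[i]|, and
# the check fails only if more than 7.5% (= 1 - 0.925) of elements mismatch.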
# ====================================================================================
# FP4 Quantization Functions
# ====================================================================================
def calculate_fp4_global_scale_factor(tensor, use_ue8m0=False):
"""
Calculate FP4 global scale factor for a tensor.
NOTE: In production, global scale factors are typically obtained offline during:
- Post-Training Quantization (PTQ) calibration process
- Quantization-Aware Training (QAT) process
This function is used here for testing/reference purposes.
    Formula: 448 * 6, where 448 is the max of the FP8 E4M3 block scale and 6 is the max of FP4 E2M1.
"""
if use_ue8m0:
return torch.tensor(1.0, dtype=torch.float32)
else:
return (448 * 6) / tensor.float().abs().nan_to_num().max()
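# Worked example (illustrative): for a tensor whose largest |value| is 3.0 (and
# use_ue8m0=False), the factor is (448 * 6) / 3.0 = 896.0; the reference dequantization
# divides it back out via `1 / a_global_sf`.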
def e2m1_and_ufp8_scale_batches(
mat_fp4: torch.Tensor,
scale_tensor: torch.Tensor,
global_scale_tensor: torch.Tensor,
sf_vec_size: int,
ufp8_type: int = 1,
):
"""Batch FP4 dequantization helper."""
num_batches = mat_fp4.size(0)
scale_tensor = scale_tensor.view(num_batches, -1)
tensors = [
e2m1_and_ufp8sf_scale_to_float(
mat_fp4[b, :, :].cpu(),
scale_tensor[b, :].cpu().reshape(-1),
global_scale_tensor[b].cpu(),
sf_vec_size,
ufp8_type,
True, # is_sf_swizzled_layout
)
for b in range(num_batches)
]
result = torch.stack(tensors)
return result
def quant_fp4(a, a_global_sf, use_ue8m0=False, is_sf_swizzled_layout=True):
"""
    Quantize a tensor to FP4 using a pre-computed global scale factor.
This function expects global scale factors that have been pre-computed offline
during PTQ/QAT calibration process. The global scale factor should NOT be
computed at runtime to avoid performance overhead.
Pure function - same inputs always produce same outputs.
"""
sf_vec_size = 32 if use_ue8m0 else 16
a_fp4, a_sf = fp4_quantize(
a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout
)
return a_fp4, a_sf, a_global_sf
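# Shape note (illustrative): for a (num_tokens, hidden_size) bf16 input with
# sf_vec_size=16, a_fp4 packs two e2m1 values per byte into (num_tokens, hidden_size // 2)
# uint8, and a_sf holds one FP8 scale per 16-element block, i.e. num_tokens *
# hidden_size // 16 entries (see the reshape in FP4Moe.quantize_inputs).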
def quant_fp4_batches(a, num_experts, use_ue8m0=False, is_sf_swizzled_layout=True):
"""FP4 batch quantization function with centralized global scale factor calculation."""
quant_a = []
sfs = []
global_sfs = []
for i in range(num_experts):
# Use centralized global scale factor calculation
a_global_sf = calculate_fp4_global_scale_factor(a[i], use_ue8m0)
a_fp4, a_sf, _ = quant_fp4(a[i], a_global_sf, use_ue8m0, is_sf_swizzled_layout)
quant_a.append(a_fp4)
sfs.append(a_sf)
global_sfs.append(a_global_sf)
result_quant_a = torch.stack(quant_a)
result_sfs = torch.stack(sfs)
result_global_sfs = torch.stack(global_sfs)
return result_quant_a, result_sfs, result_global_sfs
def quant_dequant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True):
"""FP4 quantize-dequantize roundtrip function with centralized global scale factor calculation."""
# Use centralized global scale factor calculation
a_global_sf = calculate_fp4_global_scale_factor(a, use_ue8m0)
sf_vec_size = 32 if use_ue8m0 else 16
a_fp4, a_sf = fp4_quantize(
a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout
)
a_pt = e2m1_and_ufp8sf_scale_to_float(
a_fp4.cpu(),
a_sf.cpu().reshape(-1),
(1 / a_global_sf).cpu(),
sf_vec_size,
1 if not use_ue8m0 else 0, # ufp8_type
is_sf_swizzled_layout,
)
return a_pt.cuda(), a_global_sf
# ====================================================================================
# FP8 Quantization Functions
# ====================================================================================
def calculate_fp8_global_scale_factor(tensor):
"""
Calculate FP8 global scale factor for a tensor.
NOTE: In production, global scale factors are typically obtained offline during:
- Post-Training Quantization (PTQ) calibration process
- Quantization-Aware Training (QAT) process
This function is used here for testing/reference purposes.
    Formula: 448 is the max representable value in FP8 E4M3 format.
"""
return 448 / tensor.float().abs().nan_to_num().max()
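# Worked example (illustrative): for a tensor whose largest |value| is 2.0, the factor
# is 448 / 2.0 = 224.0; quant_fp8_per_tensor multiplies by it so values span the E4M3
# range, and the reference dequantization divides it back out.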
def quant_fp8_per_tensor(a, a_global_sf):
"""
    Quantize a tensor to FP8 per-tensor using a pre-computed global scale factor.
This function expects global scale factors that have been pre-computed offline
during PTQ/QAT calibration process. The global scale factor should NOT be
computed at runtime to avoid performance overhead.
Pure function - same inputs always produce same outputs.
"""
a_fp8 = (a * a_global_sf).to(torch.float8_e4m3fn)
return a_fp8, a_global_sf
def quant_fp8_per_tensor_batches(a):
"""FP8 per-tensor batch quantization function with centralized global scale factor calculation."""
num_batches = a.size(0)
a_quant = []
a_scales = []
for i in range(num_batches):
# Use centralized global scale factor calculation
a_global_sf = calculate_fp8_global_scale_factor(a[i])
a_fp8, _ = quant_fp8_per_tensor(a[i], a_global_sf)
a_quant.append(a_fp8)
a_scales.append(a_global_sf)
result_a_quant = torch.stack(a_quant)
result_a_scales = torch.stack(a_scales)
return result_a_quant, result_a_scales
def quant_dequant_per_tensor_fp8(a):
"""FP8 per-tensor quantize-dequantize roundtrip function with centralized global scale factor calculation."""
# Use centralized global scale factor calculation
a_global_sf = calculate_fp8_global_scale_factor(a)
a_fp8, _ = quant_fp8_per_tensor(a, a_global_sf)
a_pt = a_fp8.to(torch.float) / a_global_sf
return a_pt.cuda(), a_global_sf
def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n):
"""Reference FP8 block-scale dequantization."""
input = input.to(torch.float)
scale = scale.to(torch.float)
if transpose_scale:
scale = scale.t()
m, n = input.shape
m_tile = 128 if block_m else 1
n_tile = 128 if block_n else 1
assert m % m_tile == 0
assert n % n_tile == 0
assert scale.shape == (m // m_tile, n // n_tile)
# Expand scale to match input dimensions using tensor operations
if m_tile > 1:
scale = torch.repeat_interleave(scale, m_tile, dim=0)
if n_tile > 1:
scale = torch.repeat_interleave(scale, n_tile, dim=1)
    # Element-wise multiplication with the block scales broadcast to the input shape
output = input * scale
return output
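# Illustrative shapes: for an input of shape (256, 384) with block_m=block_n=True, the
# scale must have shape (2, 3); repeat_interleave expands it to (256, 384) so every
# 128x128 block of the input is multiplied by its own scalar scale.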
# ====================================================================================
# Common MoE Reference Implementation
# ====================================================================================
def run_moe_dequant(args, quant_mode: QuantMode):
"""Common dequantized MoE reference implementation."""
# Permute
total_num_padded_tokens = args.permute_info["permutedBufferSize"]
expanded_idx_to_permuted_idx = args.permute_info[
"expandedTokenIdxToPermutedIdx"
].cpu()
num_tokens_per_expert = args.permute_info["numTokensPerExpert"].cpu()
permute_output = torch.full(
(total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda"
).to(torch.float)
for i in range(args.num_tokens):
for j in range(args.top_k):
permuted_idx = expanded_idx_to_permuted_idx[i * args.top_k + j]
permute_output[permuted_idx] = args.hidden_states[i]
# Gemm1
gemm1_output = torch.full(
(total_num_padded_tokens, 2 * args.intermediate_size),
float("nan"),
device="cuda",
).to(torch.float)
i = 0
for expert_idx in range(args.num_experts):
my_num_tokens = num_tokens_per_expert[expert_idx]
if my_num_tokens == 0:
continue
my_a = permute_output[i : i + my_num_tokens]
my_b = args.gemm1_weights[expert_idx]
my_c = my_a @ my_b.t()
gemm1_output[i : i + my_num_tokens] = my_c
i += my_num_tokens
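        # Advance i to the next multiple of `padding` so the next expert's rows start on
        # a padded boundary (matching paddedTokensPerExpertPrefixSum).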
i = (i + args.padding - 1) // args.padding * args.padding
if args.use_routing_scales_on_input:
assert args.top_k == 1
# For each token and its top_k experts
for token_idx in range(args.num_tokens):
for k in range(args.top_k):
# Get the permuted index for this token's k-th expert
expanded_idx = token_idx * args.top_k + k
permuted_idx = expanded_idx_to_permuted_idx[expanded_idx]
expert_weight = args.permute_info["topKLogits"].to(torch.float)
# Get the expert weight for this token and expert
weight = expert_weight[token_idx, k]
# Scale the corresponding row in gemm1_output
gemm1_output[permuted_idx] *= weight
# Activation
activation_output = torch.full(
(total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda"
).to(torch.float)
gated_act_type = args.gated_act_type
gated_act_type_to_func = {
0: F.silu,
1: F.gelu,
}
gated_act_func = gated_act_type_to_func[gated_act_type]
i = 0
for expert_idx in range(args.num_experts):
my_num_tokens = num_tokens_per_expert[expert_idx]
if my_num_tokens == 0:
continue
my_a = gemm1_output[i : i + my_num_tokens]
my_x1 = my_a[:, : args.intermediate_size]
my_x2 = my_a[:, args.intermediate_size :]
activation_output[i : i + my_num_tokens] = gated_act_func(my_x2) * my_x1
i += my_num_tokens
i = (i + args.padding - 1) // args.padding * args.padding
if quant_mode == QuantMode.FP4_NVFP4_NVFP4:
# Use centralized function for activation quantization
activation_output, c_global_sf = quant_dequant_fp4(
activation_output.to(torch.bfloat16), False, True
)
activation_output = activation_output.to(torch.float)
args.c_global_sf = c_global_sf
elif quant_mode == QuantMode.FP8_PER_TENSOR:
activation_output, c_global_sf = quant_dequant_per_tensor_fp8(
activation_output.to(torch.bfloat16)
)
activation_output = activation_output.to(torch.float)
args.c_global_sf = c_global_sf
elif quant_mode == QuantMode.FP4_MXFP4_MXFP8:
activation_output, scale_bytes = mxfp8_quantize(
activation_output.to(torch.bfloat16), True
)
scale_bytes = scale_bytes.view(torch.uint8).reshape(-1).cpu()
activation_output = (
mxfp8_dequantize_host(
activation_output.cpu().view(torch.uint8), scale_bytes
)
.cuda()
.to(torch.float)
)
args.c_global_sf = 1.0
else: # mxfp4Bf16
activation_output = activation_output.to(torch.bfloat16).to(torch.float)
args.c_global_sf = 1.0
# Gemm2
gemm2_output = torch.full(
(total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda"
).to(torch.float)
i = 0
for expert_idx in range(args.num_experts):
my_num_tokens = num_tokens_per_expert[expert_idx]
if my_num_tokens == 0:
continue
my_a = activation_output[i : i + my_num_tokens]
my_b = args.gemm2_weights[expert_idx]
my_c = my_a @ my_b.t()
gemm2_output[i : i + my_num_tokens] = my_c
i += my_num_tokens
i = (i + args.padding - 1) // args.padding * args.padding
# Finalize
expert_weight = args.permute_info["topKLogits"].to(torch.float)
finalize_output = torch.full(
(args.num_tokens, args.hidden_size), float("nan"), device="cuda"
).to(torch.float)
for i in range(args.num_tokens):
acc = torch.zeros(args.hidden_size, dtype=torch.float, device="cuda")
for top_k_idx in range(args.top_k):
expanded_idx = i * args.top_k + top_k_idx
permuted_idx = expanded_idx_to_permuted_idx[expanded_idx]
original_vector = gemm2_output[permuted_idx]
weight = (
expert_weight[i, top_k_idx]
if not args.use_routing_scales_on_input
else 1.0
)
acc += original_vector * weight
finalize_output[i] = acc
return finalize_output
# ====================================================================================
# Quantization-Specific Reference Implementations
# ====================================================================================
def run_moe_reference_fp4(args, quant_mode: QuantMode):
sf_vec_size = 16 if quant_mode == QuantMode.FP4_NVFP4_NVFP4 else 32
ufp8_type_weights = 1 if quant_mode == QuantMode.FP4_NVFP4_NVFP4 else 0
if quant_mode == QuantMode.FP4_NVFP4_NVFP4:
hidden_states_dequant = e2m1_and_ufp8sf_scale_to_float(
args.hidden_states.cpu(),
args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1),
(1 / args.hidden_states_scale_global).cpu(),
sf_vec_size,
ufp8_type_weights,
True, # is_sf_swizzled_layout
).cuda()
elif quant_mode == QuantMode.FP4_MXFP4_MXFP8:
hidden_states_dequant = mxfp8_dequantize_host(
args.hidden_states.cpu().view(torch.uint8),
args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1),
True, # is_sf_swizzled_layout
).cuda()
else:
hidden_states_dequant = args.hidden_states.to(torch.bfloat16).to(torch.float)
gemm1_weights_dequant = e2m1_and_ufp8_scale_batches(
args.gemm1_weights,
args.gemm1_scales,
1 / args.gemm1_scales_global,
sf_vec_size,
ufp8_type_weights,
).cuda()
gemm2_weights_dequant = e2m1_and_ufp8_scale_batches(
args.gemm2_weights,
args.gemm2_scales,
1 / args.gemm2_scales_global,
sf_vec_size,
ufp8_type_weights,
).cuda()
args_dequant = moe_args_dequant(
args.num_tokens,
args.num_experts,
args.hidden_size,
args.intermediate_size,
args.top_k,
args.padding,
hidden_states_dequant,
args.expert_logits,
gemm1_weights_dequant,
gemm2_weights_dequant,
args.permute_info,
args.use_routing_scales_on_input,
args.gated_act_type,
)
return run_moe_dequant(args_dequant, quant_mode), args_dequant
def run_moe_reference_dsfp8(args):
"""FP8 block-scale reference implementation."""
# Generate block scales at runtime for FP8 block scaling
# Use deterministic scales for testing consistency
hidden_states_scale = 2.0 * torch.ones(
(args.hidden_size // 128, args.num_tokens), device="cuda", dtype=torch.float
)
hidden_states_dequant = dequant_reference_dsfp8(
args.hidden_states, hidden_states_scale, True, False, True
)
gemm1_weights_dequant = {}
for i in range(args.num_experts):
gemm1_weights_dequant[i] = dequant_reference_dsfp8(
args.gemm1_weights[i], args.gemm1_scales[i], False, True, True
)
gemm2_weights_dequant = {}
for i in range(args.num_experts):
gemm2_weights_dequant[i] = dequant_reference_dsfp8(
args.gemm2_weights[i], args.gemm2_scales[i], False, True, True
)
args_dequant = moe_args_dequant(
args.num_tokens,
args.num_experts,
args.hidden_size,
args.intermediate_size,
args.top_k,
args.padding,
hidden_states_dequant,
args.expert_logits,
gemm1_weights_dequant,
gemm2_weights_dequant,
args.permute_info,
args.use_routing_scales_on_input,
GatedActType.SwiGlu.value, # gated_act_type
)
return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant
def run_moe_reference_per_tensor_scale_fp8(args):
"""FP8 per-tensor reference implementation."""
hidden_states_dequant = (
args.hidden_states.to(torch.float) / args.hidden_states_scale_global
)
gemm1_weights_dequant = {}
for i in range(args.num_experts):
gemm1_weights_dequant[i] = (
args.gemm1_weights[i].to(torch.float) / args.gemm1_scales_global[i]
)
gemm2_weights_dequant = {}
for i in range(args.num_experts):
gemm2_weights_dequant[i] = (
args.gemm2_weights[i].to(torch.float) / args.gemm2_scales_global[i]
)
args_dequant = moe_args_dequant(
args.num_tokens,
args.num_experts,
args.hidden_size,
args.intermediate_size,
args.top_k,
args.padding,
hidden_states_dequant,
args.expert_logits,
gemm1_weights_dequant,
gemm2_weights_dequant,
args.permute_info,
args.use_routing_scales_on_input,
GatedActType.SwiGlu.value, # gated_act_type
)
return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant
def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs):
"""Unified actual computation that delegates to implementation-specific methods."""
# 1. Prepare static weights for the kernel (offline processing)
static_data = moe_impl.prepare_static_weights_for_kernel(
args_dequant,
args,
kwargs["gemm1_weights_orig"],
kwargs["gemm2_weights_orig"],
args.hidden_size,
args.intermediate_size,
args.num_experts,
kwargs["weight_processing"],
)
# 2. Call MoE with runtime input quantization + kernel execution
kernel_kwargs = {
"expert_logits": kwargs["expert_logits"],
"routing_bias": kwargs["routing_bias"],
"num_experts": args.num_experts,
"num_tokens": args.num_tokens,
"hidden_size": args.hidden_size,
"top_k": args.top_k,
"n_groups": kwargs["n_groups"],
"top_k_groups": kwargs["top_k_groups"],
"intermediate_size": args.intermediate_size,
"routed_scaling": kwargs["routed_scaling"],
"routing_method_type": kwargs["routing_method_type"],
"tile_tokens_dim": kwargs["tile_tokens_dim"],
"do_finalize": True,
"gated_act_type": args.gated_act_type,
}
return moe_impl.call_moe(
static_data,
kwargs["hidden_states_orig"],
args.hidden_states_scale_global,
**kernel_kwargs,
)
@pytest.fixture(scope="module")
def cache_permute_indices():
_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
return _cache_permute_indices
@pytest.mark.parametrize("num_tokens", [1, 8, 1024])
@pytest.mark.parametrize("hidden_size", [1024, 8192])
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 384])
@pytest.mark.parametrize(
"moe_impl",
[
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
pytest.param(FP8BlockScaleMoe(), id="FP8_Block"),
pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
],
)
@pytest.mark.parametrize(
"routing_config",
[
pytest.param(
{
"num_experts": 256,
"top_k": 8,
"padding": 8,
"n_groups": 8,
"top_k_groups": 4,
"routed_scaling": 2.5,
"has_routing_bias": True,
"routing_method_type": RoutingMethodType.DeepSeekV3,
"compatible_moe_impls": [
FP4Moe,
FP8BlockScaleMoe,
],
},
id="DSv3",
),
pytest.param(
{
"num_experts": 72,
"top_k": 6,
"padding": 8,
"n_groups": 1,
"top_k_groups": 1,
"routed_scaling": 2.5,
"has_routing_bias": True,
"routing_method_type": RoutingMethodType.DeepSeekV3,
"compatible_moe_impls": [
FP4Moe,
FP8BlockScaleMoe,
],
},
id="DSLite",
),
pytest.param(
{
"num_experts": 128,
"top_k": 8,
"padding": 8,
"n_groups": None,
"top_k_groups": None,
"routed_scaling": None,
"has_routing_bias": False,
"routing_method_type": RoutingMethodType.Renormalize,
"compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
},
id="Renorm",
marks=pytest.mark.skip(
reason="Disabled for testing speed - similar to RenormalizeNaive"
),
),
pytest.param(
{
"num_experts": 128,
"top_k": 8,
"padding": 8,
"n_groups": None,
"top_k_groups": None,
"routed_scaling": None,
"has_routing_bias": False,
"routing_method_type": RoutingMethodType.RenormalizeNaive,
"compatible_moe_impls": [FP4Moe],
},
id="RenormNaive",
),
pytest.param(
{
"num_experts": 16,
"top_k": 2,
"padding": 8,
"n_groups": None,
"top_k_groups": None,
"routed_scaling": None,
"has_routing_bias": False,
"routing_method_type": RoutingMethodType.TopK,
"compatible_moe_impls": [FP4Moe],
},
id="TopK",
),
pytest.param(
{
"num_experts": 128,
"top_k": 1,
"padding": 8,
"n_groups": 0,
"top_k_groups": 0,
"routed_scaling": 2.5,
"has_routing_bias": True,
"routing_method_type": RoutingMethodType.Llama4,
"compatible_moe_impls": [FP8PerTensorMoe],
},
id="Llama4",
),
],
)
@pytest.mark.parametrize(
"weight_processing",
[
pytest.param(
{
"use_shuffled_weight": False,
"layout": WeightLayout.MajorK,
"compatible_moe_impls": [FP8BlockScaleMoe],
},
id="NoShuffle_MajorK",
),
pytest.param(
{
"use_shuffled_weight": True,
"layout": WeightLayout.MajorK,
"compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
},
id="Shuffled_MajorK",
),
pytest.param(
{
"use_shuffled_weight": True,
"layout": WeightLayout.BlockMajorK,
"compatible_moe_impls": [FP8BlockScaleMoe],
},
id="Shuffled_BlockMajorK",
),
],
)
@pytest.mark.parametrize(
"gated_act_type",
[
pytest.param(GatedActType.SwiGlu, id="SwiGlu"),
pytest.param(GatedActType.GeGlu, id="GeGlu"),
],
)
def test_moe_quantization_classes(
num_tokens,
hidden_size,
intermediate_size,
moe_impl,
routing_config,
weight_processing,
gated_act_type,
cache_permute_indices,
):
"""
Test MoE implementations using separated quantization workflow.
This test demonstrates the clean separation between:
- Static weight quantization (done offline)
- Dynamic input quantization (done at runtime)
Each quantization class clearly shows which precision is being used.
"""
# Skip incompatible combinations
if gated_act_type == GatedActType.GeGlu and (
type(moe_impl) is not FP4Moe
or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4
or routing_config["routing_method_type"] != RoutingMethodType.TopK
or num_tokens > 128
):
        # GeGlu is only tested with FP4 NvFP4xNvFP4 weights, TopK routing, and num_tokens <= 128
pytest.skip(
f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}"
)
elif gated_act_type == GatedActType.SwiGlu and (
hidden_size > 1024 or intermediate_size > 1024
):
# Skip some tests for SwiGlu for testing speed
pytest.skip(
f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}"
)
if type(moe_impl) not in routing_config["compatible_moe_impls"]:
pytest.skip(
f"Incompatible: {moe_impl.name} + {routing_config['routing_method_type'].name}"
)
if type(moe_impl) not in weight_processing["compatible_moe_impls"]:
pytest.skip(
f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}"
)
moe_impl._cache_permute_indices = cache_permute_indices
seed = 0
torch.random.manual_seed(seed)
# Extract routing configuration
top_k = routing_config["top_k"]
padding = routing_config["padding"]
n_groups = routing_config["n_groups"]
top_k_groups = routing_config["top_k_groups"]
routed_scaling = routing_config["routed_scaling"]
num_experts = routing_config["num_experts"]
routing_method_type = routing_config["routing_method_type"]
tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
# Validation checks
assert top_k <= num_experts
assert top_k <= 8
if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0):
assert top_k_groups <= 4
assert num_experts > n_groups
assert num_experts % n_groups == 0
assert num_experts % 4 == 0
assert top_k < (top_k_groups * num_experts / n_groups)
# Create test data based on routing method and quantization mode
# Different kernels have different dtype requirements for routing logits
if routing_method_type == RoutingMethodType.DeepSeekV3:
# DeepSeekV3 uses float for routing logits
expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to(
torch.float
)
else:
# Other routing methods (Renormalize, RenormalizeNaive, Llama4) use bfloat16
expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to(
torch.bfloat16
)
if routing_config["has_routing_bias"]:
routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)
else:
routing_bias = None
hidden_states = 2 * torch.randn(
(num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16
)
gemm1_weights = torch.randn(
(num_experts, 2 * intermediate_size, hidden_size),
device="cuda",
dtype=torch.bfloat16,
)
gemm2_weights = torch.randn(
(num_experts, hidden_size, intermediate_size),
device="cuda",
dtype=torch.bfloat16,
)
# Generate routing info
use_routing_scales_on_input = routing_method_type == RoutingMethodType.Llama4
if routing_method_type == RoutingMethodType.DeepSeekV3:
permute_info, scores = routing_reference_no_aux(
expert_logits,
routing_bias,
top_k,
n_groups,
top_k_groups,
routed_scaling,
padding,
use_routing_scales_on_input,
)
elif routing_method_type == RoutingMethodType.Renormalize:
permute_info, scores = routing_reference_renormalize(
expert_logits, top_k, num_experts, padding
)
elif routing_method_type == RoutingMethodType.RenormalizeNaive:
permute_info, scores = routing_reference_renormalize_naive(
expert_logits, top_k, num_experts, padding
)
elif routing_method_type == RoutingMethodType.TopK:
permute_info, scores = routing_reference_topk(
expert_logits, top_k, num_experts, padding
)
elif routing_method_type == RoutingMethodType.Llama4:
permute_info, scores = routing_reference_no_aux(
expert_logits,
routing_bias,
top_k,
n_groups,
top_k_groups,
routed_scaling,
padding,
use_routing_scales_on_input=True,
)
else:
raise NotImplementedError(
f"Routing method {routing_method_type} not implemented"
)
# 1. Quantize weights offline (static, done once) + compute global scale factors
weights_data = moe_impl.quantize_weights(
gemm1_weights, gemm2_weights, hidden_states
)
# 2. Quantize inputs at runtime (dynamic, done per inference) using pre-computed scales
inputs_data = moe_impl.quantize_inputs(
hidden_states, weights_data["hidden_states_scale_global"]
)
# 3. Combine quantized data
quant_data = {**weights_data, **inputs_data}
# Create arguments for reference computation
args = moe_args(
num_tokens,
num_experts,
hidden_size,
intermediate_size,
top_k,
padding,
quant_data["hidden_states"],
quant_data["hidden_states_scale"],
quant_data["hidden_states_scale_global"],
scores,
quant_data["gemm1_weights"],
quant_data["gemm1_scales"],
quant_data["gemm1_scales_global"],
quant_data["gemm2_weights"],
quant_data["gemm2_scales"],
quant_data["gemm2_scales_global"],
permute_info,
use_routing_scales_on_input,
gated_act_type,
)
# Compute reference output using the moe_impl
output_dequant_reference, args_dequant = moe_impl.compute_reference(args)
# Validate that reference computation succeeded
if output_dequant_reference is None:
pytest.fail("Reference computation failed to produce output")
# Compute actual output using the moe_impl
output_dequant_actual = moe_impl.compute_production(
args_dequant,
args,
expert_logits=expert_logits,
routing_bias=routing_bias,
hidden_states_orig=hidden_states,
gemm1_weights_orig=gemm1_weights,
gemm2_weights_orig=gemm2_weights,
n_groups=n_groups,
top_k_groups=top_k_groups,
routed_scaling=routed_scaling,
routing_method_type=routing_method_type,
tile_tokens_dim=tile_tokens_dim,
weight_processing=weight_processing,
enable_pdl=True,
)
# Compare outputs using moe_impl-specific tolerances
tolerances = moe_impl.get_tolerances()
check_accuracy(
output_dequant_reference,
output_dequant_actual,
atol=tolerances["atol"],
rtol=tolerances["rtol"],
percent=tolerances["percent"],
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])