"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""

from abc import ABC, abstractmethod
from enum import IntEnum
from typing import Dict

import pytest
import torch
from cuda.bindings import runtime
from torch.nn import functional as F

from flashinfer import (
    RoutingMethodType,
    GatedActType,
    e2m1_and_ufp8sf_scale_to_float,
    fp4_quantize,
    mxfp8_dequantize_host,
    mxfp8_quantize,
    reorder_rows_for_gated_act_gemm,
    shuffle_matrix_a,
)
from flashinfer.fp4_quantization import block_scale_interleave
from flashinfer.fused_moe import (
    WeightLayout,
    convert_to_block_layout,
    trtllm_fp4_block_scale_moe,
    trtllm_fp8_block_scale_moe,
    trtllm_fp8_per_tensor_scale_moe,
)
from flashinfer.fused_moe.core import (
    _maybe_get_cached_w2_permute_indices,
    _maybe_get_cached_w3_w1_permute_indices,
)
from flashinfer.utils import calculate_tile_tokens_dim


def check_cuda(err):
    """Unified CUDA error checking function used throughout the file."""
    if err != runtime.cudaError_t.cudaSuccess:
        error_name = runtime.cudaGetErrorName(err)
        error_string = runtime.cudaGetErrorString(err)
        raise RuntimeError(f"CUDA error: {error_name[1]}: {error_string[1]}")
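
# Illustrative usage sketch (hedged; mirrors how the cuda-python runtime bindings
# are called elsewhere in this file): the bindings return (err, value) tuples, and
# check_cuda raises a RuntimeError with the CUDA error name/string on failure.
#     err, stream = runtime.cudaStreamCreate()
#     check_cuda(err)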
|
|
|
|
|
|
class CUDAGraphMoE:
|
|
"""
|
|
Simple CUDA Graph wrapper for MoE operations.
|
|
|
|
The graph captures tensor references and automatically updates them during execution.
|
|
|
|
Three core methods: capture(), launch(), cleanup()
|
|
|
|
Usage:
|
|
cuda_graph = CUDAGraphMoE(moe_impl, static_data, **config)
|
|
cuda_graph.capture(hidden_states_sample, expert_logits=logits, routing_bias=bias)
|
|
output = cuda_graph.launch(new_hidden_states) # Repeat as needed
|
|
cuda_graph.cleanup()
|
|
"""
|
|
|
|
def __init__(self, moe_impl, static_data, **config):
|
|
self.moe_impl = moe_impl
|
|
self.static_data = static_data
|
|
self.config = config
|
|
self.graph = None
|
|
self.graph_exec = None
|
|
self.stream = None
|
|
self.input_tensor = None
|
|
self.output_tensor = None
|
|
self.is_captured = False
|
|
|
|
def capture(self, hidden_states_sample, **runtime_args):
|
|
"""Capture CUDA graph with the given sample input."""
|
|
if self.is_captured:
|
|
raise RuntimeError(
|
|
"Graph already captured. Call cleanup() first to re-capture."
|
|
)
|
|
if not isinstance(self.moe_impl, FP4Moe):
|
|
raise NotImplementedError(
|
|
f"CUDA graph capture not yet implemented for {type(self.moe_impl)}"
|
|
)
|
|
|
|
# Create stream
|
|
err, self.stream = runtime.cudaStreamCreate()
|
|
check_cuda(err)
|
|
|
|
# Get the raw stream pointer for PyTorch
|
|
stream_ptr = int(self.stream)
|
|
torch_stream = torch.cuda.ExternalStream(stream_ptr)
|
|
|
|
# Store input tensor reference (will be updated in place during launch)
|
|
self.input_tensor = hidden_states_sample.clone()
|
|
|
|
# Warmup
|
|
with torch.cuda.stream(torch_stream):
|
|
for _ in range(1):
|
|
self._run_moe_computation(runtime_args)
|
|
|
|
# Synchronize our stream after warmup
|
|
err = runtime.cudaStreamSynchronize(self.stream)[0]
|
|
check_cuda(err)
|
|
|
|
# Begin capture
|
|
err, self.graph = runtime.cudaGraphCreate(0)
|
|
check_cuda(err)
|
|
err = runtime.cudaStreamBeginCapture(
|
|
self.stream, runtime.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal
|
|
)[0]
|
|
check_cuda(err)
|
|
|
|
try:
|
|
# Capture computation on our stream
|
|
with torch.cuda.stream(torch_stream):
|
|
self.output_tensor = self._run_moe_computation(runtime_args)
|
|
err, self.graph = runtime.cudaStreamEndCapture(self.stream)
|
|
check_cuda(err)
|
|
err, self.graph_exec = runtime.cudaGraphInstantiate(self.graph, 0)
|
|
check_cuda(err)
|
|
self.is_captured = True
|
|
except Exception as e:
|
|
self.cleanup()
|
|
raise RuntimeError(f"CUDA graph capture failed: {e}") from e
|
|
|
|
def launch(self, hidden_states_new):
|
|
"""Launch captured CUDA graph with new input."""
|
|
if not self.is_captured:
|
|
raise RuntimeError("Graph not captured. Call capture() first.")
|
|
|
|
# Update input tensor in place
|
|
self.input_tensor.copy_(hidden_states_new)
|
|
|
|
# Launch graph
|
|
err = runtime.cudaGraphLaunch(self.graph_exec, self.stream)[0]
|
|
check_cuda(err)
|
|
err = runtime.cudaStreamSynchronize(self.stream)[0]
|
|
check_cuda(err)
|
|
|
|
# Return output tensor (automatically updated by graph execution)
|
|
return self.output_tensor
|
|
|
|
def cleanup(self):
|
|
"""Clean up all CUDA graph resources."""
|
|
if self.graph_exec is not None:
|
|
err = runtime.cudaGraphExecDestroy(self.graph_exec)[0]
|
|
check_cuda(err)
|
|
self.graph_exec = None
|
|
if self.graph is not None:
|
|
err = runtime.cudaGraphDestroy(self.graph)[0]
|
|
check_cuda(err)
|
|
self.graph = None
|
|
if self.stream is not None:
|
|
err = runtime.cudaStreamDestroy(self.stream)[0]
|
|
check_cuda(err)
|
|
self.stream = None
|
|
self.input_tensor = None
|
|
self.output_tensor = None
|
|
self.is_captured = False
|
|
|
|
def _run_moe_computation(self, runtime_args):
|
|
"""Run the MoE computation."""
|
|
input_quantized = self.moe_impl.quantize_inputs(
|
|
self.input_tensor,
|
|
self.config["hidden_states_scale_global"],
|
|
is_swizzling=False,
|
|
)
|
|
|
|
output = trtllm_fp4_block_scale_moe(
|
|
routing_logits=runtime_args["expert_logits"],
|
|
routing_bias=runtime_args["routing_bias"],
|
|
hidden_states=input_quantized["hidden_states"],
|
|
hidden_states_scale=input_quantized["hidden_states_scale"],
|
|
gemm1_weights=self.static_data["gemm1_weights_fp4_shuffled"],
|
|
gemm1_weights_scale=self.static_data["gemm1_scales_fp4_shuffled"],
|
|
gemm1_bias=None,
|
|
gemm1_alpha=None,
|
|
gemm1_beta=None,
|
|
gemm1_clamp_limit=None,
|
|
gemm2_weights=self.static_data["gemm2_weights_fp4_shuffled"],
|
|
gemm2_weights_scale=self.static_data["gemm2_scales_fp4_shuffled"],
|
|
gemm2_bias=None,
|
|
output1_scale_scalar=self.static_data["scale_c_fc1"],
|
|
output1_scale_gate_scalar=self.static_data["scale_gate_fc1"],
|
|
output2_scale_scalar=self.static_data["scale_c_fc2"],
|
|
num_experts=self.config["num_experts"],
|
|
top_k=self.config["top_k"],
|
|
n_group=self.config["n_groups"],
|
|
topk_group=self.config["top_k_groups"],
|
|
intermediate_size=self.config["intermediate_size"],
|
|
local_expert_offset=0,
|
|
local_num_experts=self.config["num_experts"],
|
|
routed_scaling_factor=self.config["routed_scaling"],
|
|
tile_tokens_dim=self.config["tile_tokens_dim"],
|
|
routing_method_type=self.config["routing_method_type"],
|
|
gated_act_type=self.config["gated_act_type"],
|
|
do_finalize=True,
|
|
)
|
|
        return output  # Tuple from trtllm_fp4_block_scale_moe; caller extracts the output tensor
|
|
|
|
|
|
class QuantMode(IntEnum):
    """Supported quantization modes for MoE testing."""

    FP4_NVFP4_NVFP4 = 1
    FP4_MXFP4_MXFP8 = 2
    FP4_MXFP4_Bf16 = 3
    FP8_BLOCK_SCALE = 4
    FP8_PER_TENSOR = 5
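
# Summary of the modes above (weight format x activation format), as used by the
# implementations in this file:
#   FP4_NVFP4_NVFP4: NvFP4 weights, NvFP4 activations (sf_vec_size = 16)
#   FP4_MXFP4_MXFP8: MxFP4 weights, MxFP8 activations (sf_vec_size = 32)
#   FP4_MXFP4_Bf16:  MxFP4 weights, BF16 activations (sf_vec_size = 32)
#   FP8_BLOCK_SCALE: FP8 with 128-wide block scales (DeepSeek style, see FP8BlockScaleMoe)
#   FP8_PER_TENSOR:  FP8 with per-tensor scales (Llama4 style, see FP8PerTensorMoe)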
|
|
|
|
|
|
# ====================================================================================
|
|
# Abstract Base Class for MoE Implementations
|
|
# ====================================================================================
|
|
|
|
|
|
class Moe(ABC):
|
|
"""Abstract base class for MoE implementations."""
|
|
|
|
def __init__(self):
|
|
self.name = self.__class__.__name__
|
|
|
|
@abstractmethod
|
|
def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
|
|
"""Quantize static weights and compute global scale factors (done offline)."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def quantize_inputs(self, hidden_states, hidden_states_scale_global):
|
|
"""Quantize dynamic inputs/hidden states using pre-computed global scale (done at runtime)."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def prepare_static_weights_for_kernel(
|
|
self,
|
|
args_dequant,
|
|
args,
|
|
gemm1_weights_orig,
|
|
gemm2_weights_orig,
|
|
hidden_size,
|
|
intermediate_size,
|
|
num_experts,
|
|
weight_processing,
|
|
):
|
|
"""
|
|
Prepare quantized weights for kernel (done offline with weights).
|
|
|
|
Args:
|
|
args_dequant: Contains c_global_sf and other dequantization parameters
|
|
args: Contains already quantized weights (gemm1_weights, gemm2_weights) and scales
|
|
gemm1_weights_orig: Original unquantized FC1 weights (used by FP4 for re-quantization)
|
|
gemm2_weights_orig: Original unquantized FC2 weights (used by FP4 for re-quantization)
|
|
|
|
Note:
|
|
- FP4 implementations use both original weights (for linear layout quantization)
|
|
and args.gemm*_weights (for swizzled layout)
|
|
- FP8 implementations typically only use args.gemm*_weights (already quantized)
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def call_moe(
|
|
self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
|
|
):
|
|
"""Call MoE with runtime input quantization + kernel execution (done at runtime)."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def compute_reference(self, args):
|
|
"""Compute reference output using dequantized operations."""
|
|
pass
|
|
|
|
def compute_production(self, args_dequant, args, **kwargs):
|
|
"""Unified actual computation that delegates to implementation-specific methods."""
|
|
return _compute_moe_actual_unified(self, args_dequant, args, **kwargs)
|
|
|
|
@abstractmethod
|
|
def get_tolerances(self):
|
|
"""Get accuracy tolerances for this quantization mode."""
|
|
pass
|
|
|
|
def __str__(self):
|
|
return self.name
|
|
|
|
|
|
# ====================================================================================
|
|
# FP4 Quantization Implementation
|
|
# ====================================================================================
|
|
|
|
|
|
class FP4Moe(Moe):
|
|
"""
|
|
FP4 NvFP4 / MxFP4 MoE implementation with block scaling.
|
|
Args:
|
|
is_mxfp4: Whether to use MxFP4 or NvFP4 weight quantization
|
|
If True, the activation is quantized to MxFP8, else the activation is quantized to NvFP4
|
|
"""
|
|
|
|
def __init__(self, quant_mode: QuantMode):
|
|
super().__init__()
|
|
self.quant_mode = quant_mode
|
|
self.is_mxfp4 = (
|
|
quant_mode == QuantMode.FP4_MXFP4_MXFP8
|
|
or quant_mode == QuantMode.FP4_MXFP4_Bf16
|
|
)
|
|
self.sf_vec_size = 32 if self.is_mxfp4 else 16
|
|
|
|
def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
|
|
"""Quantize weights to FP4 format and compute global scale factors."""
|
|
num_experts = gemm1_weights.shape[0]
|
|
# Compute global scale factor for hidden states (offline calibration)
|
|
if self.quant_mode == QuantMode.FP4_NVFP4_NVFP4:
|
|
# nvfp4 hidden states
|
|
hidden_states_scale_global = calculate_fp4_global_scale_factor(
|
|
hidden_states_sample,
|
|
False,
|
|
)
|
|
else:
|
|
# mxfp8 / bf16 hidden states
|
|
hidden_states_scale_global = 1.0
|
|
|
|
# Quantize the weights for FC1
|
|
gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
|
|
quant_fp4_batches(gemm1_weights, num_experts, self.is_mxfp4, True)
|
|
)
|
|
|
|
# Quantize the weights for FC2
|
|
gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
|
|
quant_fp4_batches(gemm2_weights, num_experts, self.is_mxfp4, True)
|
|
)
|
|
|
|
return {
|
|
"hidden_states_scale_global": hidden_states_scale_global,
|
|
"gemm1_weights": gemm1_weights_fp4_bytes,
|
|
"gemm1_scales": gemm1_scales_fp4_bytes,
|
|
"gemm1_scales_global": gemm1_scales_global,
|
|
"gemm2_weights": gemm2_weights_fp4_bytes,
|
|
"gemm2_scales": gemm2_scales_fp4_bytes,
|
|
"gemm2_scales_global": gemm2_scales_global,
|
|
}
|
|
|
|
def quantize_inputs(
|
|
self, hidden_states, hidden_states_scale_global, is_swizzling=True
|
|
):
|
|
if self.quant_mode == QuantMode.FP4_MXFP4_MXFP8:
|
|
"""Quantize hidden states to MxFP8 format."""
|
|
hidden_states_quant, hidden_states_scale = mxfp8_quantize(
|
|
hidden_states, is_swizzling
|
|
)
|
|
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
|
|
*hidden_states.shape[:-1], -1
|
|
)
|
|
return {
|
|
"hidden_states": hidden_states_quant,
|
|
"hidden_states_scale": hidden_states_scale,
|
|
}
|
|
elif self.quant_mode == QuantMode.FP4_NVFP4_NVFP4:
|
|
"""Quantize hidden states to NvFP4 format using pre-computed global scale."""
|
|
(
|
|
hidden_states_fp4_bytes,
|
|
hidden_states_scale_fp4_bytes,
|
|
_,
|
|
) = quant_fp4(
|
|
hidden_states, hidden_states_scale_global, False, is_swizzling
|
|
)
|
|
hidden_states_scale_fp4_bytes = hidden_states_scale_fp4_bytes.view(
|
|
torch.float8_e4m3fn
|
|
).reshape(*hidden_states.shape[:-1], -1)
|
|
|
|
return {
|
|
"hidden_states": hidden_states_fp4_bytes,
|
|
"hidden_states_scale": hidden_states_scale_fp4_bytes,
|
|
}
|
|
else: # bf16
|
|
return {
|
|
"hidden_states": hidden_states.to(torch.bfloat16),
|
|
"hidden_states_scale": None,
|
|
}
|
|
|
|
def prepare_static_weights_for_kernel(
|
|
self,
|
|
args_dequant,
|
|
args,
|
|
gemm1_weights_orig,
|
|
gemm2_weights_orig,
|
|
hidden_size,
|
|
intermediate_size,
|
|
num_experts,
|
|
weight_processing,
|
|
):
|
|
"""Prepare quantized weights for kernel (done offline with weights)."""
|
|
use_ue8m0 = self.is_mxfp4
|
|
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
|
|
|
# Quantize weights with linear layout for kernels
|
|
_, gemm1_scales_linear_fp4_bytes, _ = quant_fp4_batches(
|
|
gemm1_weights_orig, num_experts, use_ue8m0, False
|
|
)
|
|
_, gemm2_scales_linear_fp4_bytes, _ = quant_fp4_batches(
|
|
gemm2_weights_orig, num_experts, use_ue8m0, False
|
|
)
|
|
|
|
# Convert quantized weights to proper formats
|
|
gemm1_weights_fp4 = args.gemm1_weights.view(torch.float8_e4m3fn).reshape(
|
|
num_experts, 2 * intermediate_size, hidden_size // 2
|
|
) # packed fp4
|
|
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
|
|
torch.float8_e4m3fn
|
|
).reshape(
|
|
num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size
|
|
) # fp8 scaling factors
|
|
|
|
gemm2_weights_fp4 = args.gemm2_weights.view(torch.float8_e4m3fn).reshape(
|
|
num_experts, hidden_size, intermediate_size // 2
|
|
) # packed fp4
|
|
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
|
|
torch.float8_e4m3fn
|
|
).reshape(
|
|
num_experts, hidden_size, intermediate_size // self.sf_vec_size
|
|
) # fp8 scaling factors
|
|
|
|
# Using cached permute index calculation can speed up weights preprocessing
|
|
gemm1_weights_fp4_shuffled = []
|
|
gemm1_scales_fp4_shuffled = []
|
|
gemm2_weights_fp4_shuffled = []
|
|
gemm2_scales_fp4_shuffled = []
|
|
for i in range(num_experts):
|
|
# Calculate the permute indices for the following:
|
|
# 1. Reorder rows of W1 and scales for fused gated activation
|
|
# 2. Shuffle weights and scaling factors for transposed mma output
|
|
# for both w3_w1 and w2 weights and scale factors
|
|
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
|
|
self._cache_permute_indices,
|
|
gemm1_weights_fp4[i].view(torch.uint8),
|
|
epilogue_tile_m,
|
|
)
|
|
gemm1_weights_fp4_shuffled.append(
|
|
gemm1_weights_fp4[i]
|
|
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
|
|
.contiguous()
|
|
)
|
|
|
|
permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
|
|
self._cache_permute_indices,
|
|
gemm1_scales_linear_fp4[i].view(torch.uint8),
|
|
epilogue_tile_m,
|
|
num_elts_per_sf=16,
|
|
)
|
|
gemm1_scales_fp4_shuffled.append(
|
|
block_scale_interleave(
|
|
gemm1_scales_linear_fp4[i]
|
|
.view(torch.uint8)[
|
|
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
|
|
]
|
|
.contiguous()
|
|
)
|
|
)
|
|
|
|
permute_indices = _maybe_get_cached_w2_permute_indices(
|
|
self._cache_permute_indices,
|
|
gemm2_weights_fp4[i].view(torch.uint8),
|
|
epilogue_tile_m,
|
|
)
|
|
gemm2_weights_fp4_shuffled.append(
|
|
gemm2_weights_fp4[i]
|
|
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
|
|
.contiguous()
|
|
)
|
|
|
|
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
|
self._cache_permute_indices,
|
|
gemm2_scales_linear_fp4[i].view(torch.uint8),
|
|
epilogue_tile_m,
|
|
num_elts_per_sf=16,
|
|
)
|
|
gemm2_scales_fp4_shuffled.append(
|
|
block_scale_interleave(
|
|
gemm2_scales_linear_fp4[i]
|
|
.view(torch.uint8)[
|
|
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
|
|
]
|
|
.contiguous()
|
|
)
|
|
)
|
|
|
|
# Stack weights for all experts
|
|
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
|
|
gemm1_scales_fp4_shuffled = (
|
|
torch.stack(gemm1_scales_fp4_shuffled)
|
|
.view(torch.float8_e4m3fn)
|
|
.reshape(
|
|
num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size
|
|
)
|
|
)
|
|
|
|
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
|
|
gemm2_scales_fp4_shuffled = (
|
|
torch.stack(gemm2_scales_fp4_shuffled)
|
|
.view(torch.float8_e4m3fn)
|
|
.reshape(num_experts, hidden_size, intermediate_size // self.sf_vec_size)
|
|
)
|
|
|
|
# Calculate scaling factors that depend on weights
|
|
scale_c_fc1 = (
|
|
args_dequant.c_global_sf
|
|
* (1.0 / args.gemm1_scales_global)
|
|
* (1.0 / args.hidden_states_scale_global)
|
|
)
|
|
scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * (
|
|
1.0 / args.hidden_states_scale_global
|
|
)
|
|
scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * (
|
|
1.0 / args.gemm2_scales_global
|
|
)
|
|
|
|
return {
|
|
"gemm1_weights_fp4_shuffled": gemm1_weights_fp4_shuffled,
|
|
"gemm1_scales_fp4_shuffled": gemm1_scales_fp4_shuffled,
|
|
"gemm2_weights_fp4_shuffled": gemm2_weights_fp4_shuffled,
|
|
"gemm2_scales_fp4_shuffled": gemm2_scales_fp4_shuffled,
|
|
"scale_c_fc1": scale_c_fc1,
|
|
"scale_gate_fc1": scale_gate_fc1,
|
|
"scale_c_fc2": scale_c_fc2,
|
|
}
|
|
|
|
def call_moe(
|
|
self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
|
|
):
|
|
"""Call MoE using CUDA graph for maximum performance (create, capture, launch)."""
|
|
# Extract runtime arguments
|
|
expert_logits = kwargs["expert_logits"]
|
|
routing_bias = kwargs["routing_bias"]
|
|
num_experts = kwargs["num_experts"]
|
|
top_k = kwargs["top_k"]
|
|
n_groups = kwargs["n_groups"]
|
|
top_k_groups = kwargs["top_k_groups"]
|
|
intermediate_size = kwargs["intermediate_size"]
|
|
routed_scaling = kwargs["routed_scaling"]
|
|
gated_act_type = kwargs["gated_act_type"]
|
|
routing_method_type = kwargs["routing_method_type"]
|
|
tile_tokens_dim = kwargs["tile_tokens_dim"]
|
|
|
|
# Create CUDA graph configuration
|
|
config = {
|
|
"hidden_states_scale_global": hidden_states_scale_global,
|
|
"num_experts": num_experts,
|
|
"top_k": top_k,
|
|
"n_groups": n_groups,
|
|
"top_k_groups": top_k_groups,
|
|
"intermediate_size": intermediate_size,
|
|
"routed_scaling": routed_scaling,
|
|
"tile_tokens_dim": tile_tokens_dim,
|
|
"gated_act_type": gated_act_type,
|
|
"routing_method_type": routing_method_type,
|
|
}
|
|
|
|
runtime_args = {
|
|
"expert_logits": expert_logits,
|
|
"routing_bias": routing_bias,
|
|
}
|
|
|
|
# Create, capture and launch CUDA graph in one shot
|
|
cuda_graph = CUDAGraphMoE(self, static_data, **config)
|
|
try:
|
|
cuda_graph.capture(hidden_states_orig, **runtime_args)
|
|
output = cuda_graph.launch(hidden_states_orig)
|
|
return output[0].to(torch.float)
|
|
finally:
|
|
cuda_graph.cleanup()
|
|
|
|
def compute_reference(self, args):
|
|
return run_moe_reference_fp4(args, self.quant_mode)
|
|
|
|
def get_tolerances(self):
|
|
"""Get FP4-specific accuracy tolerances."""
|
|
return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}
|
|
|
|
|
|
# ====================================================================================
|
|
# FP8 Block Scale Quantization Implementation
|
|
# ====================================================================================
|
|
|
|
|
|
class FP8BlockScaleMoe(Moe):
|
|
"""FP8 MoE implementation with block scaling (DeepSeek style)."""
|
|
|
|
def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
|
|
"""Quantize weights to FP8 with block scaling."""
|
|
num_experts = gemm1_weights.shape[0]
|
|
intermediate_size = gemm1_weights.shape[1] // 2
|
|
hidden_size = gemm1_weights.shape[
|
|
2
|
|
] # [num_experts, 2*intermediate_size, hidden_size]
|
|
|
|
# Quantize weights to FP8
|
|
gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn)
|
|
gemm1_scales = 2 * torch.rand(
|
|
(num_experts, 2 * intermediate_size // 128, hidden_size // 128),
|
|
device="cuda",
|
|
).to(torch.float)
|
|
|
|
gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn)
|
|
gemm2_scales = 2 * torch.rand(
|
|
(num_experts, hidden_size // 128, intermediate_size // 128), device="cuda"
|
|
).to(torch.float)
|
|
|
|
return {
|
|
"hidden_states_scale_global": None, # Block scales computed at runtime
|
|
"gemm1_weights": gemm1_weights_fp8,
|
|
"gemm1_scales": gemm1_scales,
|
|
"gemm1_scales_global": None,
|
|
"gemm2_weights": gemm2_weights_fp8,
|
|
"gemm2_scales": gemm2_scales,
|
|
"gemm2_scales_global": None,
|
|
}
|
|
|
|
def quantize_inputs(self, hidden_states, hidden_states_scale_global):
|
|
"""For FP8 block scaling, no pre-quantization - everything happens at runtime."""
|
|
return {
|
|
"hidden_states": hidden_states, # Keep original
|
|
"hidden_states_scale": None, # No pre-computed scales
|
|
}
|
|
|
|
def prepare_static_weights_for_kernel(
|
|
self,
|
|
args_dequant,
|
|
args,
|
|
gemm1_weights_orig,
|
|
gemm2_weights_orig,
|
|
hidden_size,
|
|
intermediate_size,
|
|
num_experts,
|
|
weight_processing,
|
|
):
|
|
"""Prepare quantized weights for kernel (done offline with weights)."""
|
|
|
|
# Use shuffled weights with BlockMajorK layout for better performance
|
|
use_shuffled_weight = weight_processing["use_shuffled_weight"]
|
|
weight_layout = weight_processing["layout"]
|
|
|
|
if use_shuffled_weight:
|
|
# FIXME: this depends on the kernel internals
|
|
epilogue_tile_m = 64
|
|
|
|
gemm1_weights_fp8_shuffled = []
|
|
gemm2_weights_fp8_shuffled = []
|
|
for i in range(num_experts):
|
|
tmp_weights1 = shuffle_matrix_a(
|
|
args.gemm1_weights[i].view(torch.uint8), epilogue_tile_m
|
|
)
|
|
tmp_weights2 = shuffle_matrix_a(
|
|
args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m
|
|
)
|
|
|
|
if weight_layout == WeightLayout.BlockMajorK:
|
|
block_k = 128
|
|
tmp_weights1 = convert_to_block_layout(tmp_weights1, block_k)
|
|
tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k)
|
|
|
|
gemm1_weights_fp8_shuffled.append(tmp_weights1)
|
|
|
|
gemm2_weights_fp8_shuffled.append(tmp_weights2)
|
|
kernel_gemm1_weights = torch.stack(gemm1_weights_fp8_shuffled).view(
|
|
torch.float8_e4m3fn
|
|
)
|
|
kernel_gemm2_weights = torch.stack(gemm2_weights_fp8_shuffled).view(
|
|
torch.float8_e4m3fn
|
|
)
|
|
else:
|
|
kernel_gemm1_weights = args.gemm1_weights
|
|
kernel_gemm2_weights = args.gemm2_weights
|
|
|
|
return {
|
|
"gemm1_weights": kernel_gemm1_weights,
|
|
"gemm1_scales": args.gemm1_scales,
|
|
"gemm2_weights": kernel_gemm2_weights,
|
|
"gemm2_scales": args.gemm2_scales,
|
|
"use_shuffled_weight": use_shuffled_weight,
|
|
"weight_layout": weight_layout,
|
|
}
|
|
|
|
def call_moe(
|
|
self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
|
|
):
|
|
"""Call MoE with runtime block scale generation + kernel execution."""
|
|
expert_logits = kwargs["expert_logits"]
|
|
routing_bias = kwargs["routing_bias"]
|
|
num_experts = kwargs["num_experts"]
|
|
num_tokens = kwargs["num_tokens"]
|
|
hidden_size = kwargs["hidden_size"]
|
|
top_k = kwargs["top_k"]
|
|
n_groups = kwargs["n_groups"]
|
|
top_k_groups = kwargs["top_k_groups"]
|
|
intermediate_size = kwargs["intermediate_size"]
|
|
routed_scaling = kwargs["routed_scaling"]
|
|
routing_method_type = kwargs["routing_method_type"]
|
|
tile_tokens_dim = kwargs["tile_tokens_dim"]
|
|
enable_pdl = kwargs.get("enable_pdl")
|
|
|
|
# Generate block scales and quantize hidden states at runtime
|
|
hidden_states_fp8 = hidden_states_orig.to(torch.float8_e4m3fn)
|
|
# Use deterministic scales for testing consistency
|
|
hidden_states_scale = 2.0 * torch.ones(
|
|
(hidden_size // 128, num_tokens), device="cuda", dtype=torch.float
|
|
)
|
|
|
|
output = trtllm_fp8_block_scale_moe(
|
|
expert_logits,
|
|
routing_bias,
|
|
hidden_states_fp8,
|
|
hidden_states_scale,
|
|
static_data["gemm1_weights"],
|
|
static_data["gemm1_scales"],
|
|
static_data["gemm2_weights"],
|
|
static_data["gemm2_scales"],
|
|
num_experts,
|
|
top_k,
|
|
n_groups,
|
|
top_k_groups,
|
|
intermediate_size,
|
|
0,
|
|
num_experts,
|
|
routed_scaling,
|
|
tile_tokens_dim,
|
|
routing_method_type,
|
|
use_shuffled_weight=static_data["use_shuffled_weight"],
|
|
weight_layout=static_data["weight_layout"],
|
|
enable_pdl=enable_pdl,
|
|
)
|
|
|
|
return output.to(torch.float)
|
|
|
|
def compute_reference(self, args):
|
|
"""FP8 block-scale reference implementation."""
|
|
return run_moe_reference_dsfp8(args)
|
|
|
|
def get_tolerances(self):
|
|
"""Get FP8 block-scale accuracy tolerances."""
|
|
return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}
|
|
|
|
|
|
# ====================================================================================
|
|
# FP8 Per-Tensor Quantization Implementation
|
|
# ====================================================================================
|
|
|
|
|
|
class FP8PerTensorMoe(Moe):
|
|
"""FP8 MoE implementation with per-tensor scaling (Llama4 style)."""
|
|
|
|
def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
|
|
"""Quantize weights to FP8 per-tensor and compute global scale factors."""
|
|
# Compute global scale factor for hidden states (offline calibration)
|
|
hidden_states_global_scale = calculate_fp8_global_scale_factor(
|
|
hidden_states_sample
|
|
)
|
|
|
|
# Quantize to FP8 per-tensor
|
|
gemm1_weights_quant, gemm1_global_scales = quant_fp8_per_tensor_batches(
|
|
gemm1_weights
|
|
)
|
|
gemm2_weights_quant, gemm2_global_scales = quant_fp8_per_tensor_batches(
|
|
gemm2_weights
|
|
)
|
|
|
|
return {
|
|
"hidden_states_scale_global": hidden_states_global_scale,
|
|
"gemm1_weights": gemm1_weights_quant,
|
|
"gemm1_scales": None,
|
|
"gemm1_scales_global": gemm1_global_scales,
|
|
"gemm2_weights": gemm2_weights_quant,
|
|
"gemm2_scales": None,
|
|
"gemm2_scales_global": gemm2_global_scales,
|
|
}
|
|
|
|
def quantize_inputs(self, hidden_states, hidden_states_scale_global):
|
|
"""Quantize hidden states to FP8 per-tensor using pre-computed global scale."""
|
|
# Quantize to FP8 per-tensor using pre-computed global scale factor
|
|
hidden_states_quant, _ = quant_fp8_per_tensor(
|
|
hidden_states, hidden_states_scale_global
|
|
)
|
|
|
|
return {
|
|
"hidden_states": hidden_states_quant,
|
|
"hidden_states_scale": None,
|
|
}
|
|
|
|
def prepare_static_weights_for_kernel(
|
|
self,
|
|
args_dequant,
|
|
args,
|
|
gemm1_weights_orig,
|
|
gemm2_weights_orig,
|
|
hidden_size,
|
|
intermediate_size,
|
|
num_experts,
|
|
weight_processing,
|
|
):
|
|
"""Prepare quantized weights for kernel (done offline with weights)."""
|
|
# FIXME: this depends on the kernel internals
|
|
epilogue_tile_m = 128
|
|
|
|
# Reorder rows of W1 for fused gated activation
|
|
gemm1_weights_fp8_interleaved = []
|
|
for i in range(num_experts):
|
|
gemm1_weights_fp8_interleaved.append(
|
|
reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone())
|
|
)
|
|
|
|
# Stack weights and scales for all experts
|
|
gemm1_weights_fp8_interleaved = torch.stack(
|
|
gemm1_weights_fp8_interleaved
|
|
).reshape(num_experts, 2 * intermediate_size, hidden_size)
|
|
|
|
# Shuffle weights and scaling factors for transposed mma output
|
|
gemm1_weights_fp8_shuffled = []
|
|
gemm2_weights_fp8_shuffled = []
|
|
for i in range(num_experts):
|
|
gemm1_weights_fp8_shuffled.append(
|
|
shuffle_matrix_a(
|
|
gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m
|
|
)
|
|
)
|
|
|
|
gemm2_weights_fp8_shuffled.append(
|
|
shuffle_matrix_a(
|
|
args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m
|
|
)
|
|
)
|
|
|
|
# Stack weights for all experts
|
|
gemm1_weights_fp8_shuffled = torch.stack(gemm1_weights_fp8_shuffled).view(
|
|
torch.float8_e4m3fn
|
|
)
|
|
gemm2_weights_fp8_shuffled = torch.stack(gemm2_weights_fp8_shuffled).view(
|
|
torch.float8_e4m3fn
|
|
)
|
|
|
|
# Calculate scaling factors that depend on weights
|
|
scale_c_fc1 = (
|
|
args_dequant.c_global_sf
|
|
* (1.0 / args.gemm1_scales_global)
|
|
* (1.0 / args.hidden_states_scale_global)
|
|
)
|
|
scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * (
|
|
1.0 / args.hidden_states_scale_global
|
|
)
|
|
scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * (
|
|
1.0 / args.gemm2_scales_global
|
|
)
|
|
|
|
return {
|
|
"gemm1_weights": gemm1_weights_fp8_shuffled,
|
|
"gemm2_weights": gemm2_weights_fp8_shuffled,
|
|
"scale_c_fc1": scale_c_fc1,
|
|
"scale_gate_fc1": scale_gate_fc1,
|
|
"scale_c_fc2": scale_c_fc2,
|
|
}
|
|
|
|
def call_moe(
|
|
self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
|
|
):
|
|
"""Call MoE with runtime input quantization + kernel execution (done at runtime)."""
|
|
expert_logits = kwargs["expert_logits"]
|
|
routing_bias = kwargs["routing_bias"]
|
|
num_experts = kwargs["num_experts"]
|
|
top_k = kwargs["top_k"]
|
|
n_groups = kwargs["n_groups"]
|
|
top_k_groups = kwargs["top_k_groups"]
|
|
intermediate_size = kwargs["intermediate_size"]
|
|
routed_scaling = kwargs["routed_scaling"]
|
|
routing_method_type = kwargs["routing_method_type"]
|
|
tile_tokens_dim = kwargs["tile_tokens_dim"]
|
|
|
|
# Quantize to FP8 per-tensor using pre-computed global scale factor
|
|
hidden_states_fp8, _ = quant_fp8_per_tensor(
|
|
hidden_states_orig, hidden_states_scale_global
|
|
)
|
|
|
|
output = trtllm_fp8_per_tensor_scale_moe(
|
|
(
|
|
expert_logits.to(torch.bfloat16)
|
|
if routing_method_type == RoutingMethodType.Llama4
|
|
else expert_logits
|
|
),
|
|
routing_bias,
|
|
hidden_states_fp8,
|
|
static_data["gemm1_weights"],
|
|
static_data["scale_c_fc1"],
|
|
static_data["scale_gate_fc1"],
|
|
static_data["gemm2_weights"],
|
|
static_data["scale_c_fc2"],
|
|
num_experts,
|
|
top_k,
|
|
n_groups,
|
|
top_k_groups,
|
|
intermediate_size,
|
|
0,
|
|
num_experts,
|
|
routed_scaling,
|
|
routing_method_type
|
|
            == RoutingMethodType.Llama4,  # use_routing_scales_on_input
|
|
tile_tokens_dim,
|
|
routing_method_type,
|
|
)
|
|
|
|
return output.to(torch.float)
|
|
|
|
def compute_reference(self, args):
|
|
"""FP8 per-tensor reference implementation."""
|
|
return run_moe_reference_per_tensor_scale_fp8(args)
|
|
|
|
def get_tolerances(self):
|
|
"""Get FP8 per-tensor accuracy tolerances."""
|
|
return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}
|
|
|
|
|
|
# ====================================================================================
|
|
# Quantizer Factory
|
|
# ====================================================================================
|
|
|
|
|
|
def get_moe_impl(quant_mode: QuantMode):
    """Factory function to get the appropriate MoE implementation."""
    if quant_mode == QuantMode.FP8_BLOCK_SCALE:
        return FP8BlockScaleMoe()
    elif quant_mode == QuantMode.FP8_PER_TENSOR:
        return FP8PerTensorMoe()
    else:
        return FP4Moe(quant_mode)
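
# Illustrative usage sketch (hedged, values taken from the classes above):
#   moe_impl = get_moe_impl(QuantMode.FP4_NVFP4_NVFP4)  # -> FP4Moe (NvFP4, sf_vec_size 16)
#   moe_impl = get_moe_impl(QuantMode.FP8_BLOCK_SCALE)  # -> FP8BlockScaleMoe()
#   tolerances = moe_impl.get_tolerances()  # {"atol": 0.1, "rtol": 0.85, "percent": 0.925}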
|
|
|
|
|
|
class moe_args:
|
|
"""Arguments container for MoE operations."""
|
|
|
|
def __init__(
|
|
self,
|
|
num_tokens,
|
|
num_experts,
|
|
hidden_size,
|
|
intermediate_size,
|
|
top_k,
|
|
padding,
|
|
hidden_states,
|
|
hidden_states_scale,
|
|
hidden_states_scale_global,
|
|
expert_logits,
|
|
gemm1_weights,
|
|
gemm1_scales,
|
|
gemm1_scales_global,
|
|
gemm2_weights,
|
|
gemm2_scales,
|
|
gemm2_scales_global,
|
|
permute_info,
|
|
use_routing_scales_on_input,
|
|
gated_act_type,
|
|
):
|
|
self.num_tokens = num_tokens
|
|
self.num_experts = num_experts
|
|
self.hidden_size = hidden_size
|
|
self.intermediate_size = intermediate_size
|
|
self.top_k = top_k
|
|
self.padding = padding
|
|
self.hidden_states = hidden_states
|
|
self.hidden_states_scale = hidden_states_scale
|
|
self.hidden_states_scale_global = hidden_states_scale_global
|
|
self.expert_logits = expert_logits
|
|
self.gemm1_weights = gemm1_weights
|
|
self.gemm1_scales = gemm1_scales
|
|
self.gemm1_scales_global = gemm1_scales_global
|
|
self.gemm2_weights = gemm2_weights
|
|
self.gemm2_scales = gemm2_scales
|
|
self.gemm2_scales_global = gemm2_scales_global
|
|
self.permute_info = permute_info
|
|
self.use_routing_scales_on_input = use_routing_scales_on_input
|
|
self.gated_act_type = gated_act_type
|
|
|
|
|
|
class moe_args_dequant:
|
|
"""Arguments container for dequantized MoE operations."""
|
|
|
|
def __init__(
|
|
self,
|
|
num_tokens,
|
|
num_experts,
|
|
hidden_size,
|
|
intermediate_size,
|
|
top_k,
|
|
padding,
|
|
hidden_states,
|
|
expert_logits,
|
|
gemm1_weights,
|
|
gemm2_weights,
|
|
permute_info,
|
|
use_routing_scales_on_input,
|
|
gated_act_type,
|
|
):
|
|
self.num_tokens = num_tokens
|
|
self.num_experts = num_experts
|
|
self.hidden_size = hidden_size
|
|
self.intermediate_size = intermediate_size
|
|
self.top_k = top_k
|
|
self.padding = padding
|
|
self.hidden_states = hidden_states
|
|
self.expert_logits = expert_logits
|
|
self.gemm1_weights = gemm1_weights
|
|
self.gemm2_weights = gemm2_weights
|
|
self.permute_info = permute_info
|
|
self.use_routing_scales_on_input = use_routing_scales_on_input
|
|
self.gated_act_type = gated_act_type
|
|
|
|
|
|
def routing_reference(expertLogits, topK, padding):
|
|
"""Reference routing implementation for permutation calculation."""
|
|
originalDevice = expertLogits.device
|
|
expertLogits = expertLogits.cpu()
|
|
numTokens, numExperts = expertLogits.shape
|
|
assert topK <= numExperts
|
|
|
|
numTokensPerExpert = torch.zeros(numExperts, dtype=torch.int64)
|
|
expandedTokenIdxToExpert = -torch.ones(numTokens * topK, dtype=torch.int64)
|
|
expandedTokenIdxToIdxInExpert = -torch.ones(numTokens * topK, dtype=torch.int64)
|
|
|
|
topKLogits, topKIndices = torch.topk(expertLogits, topK, dim=1)
|
|
for tokenIdx in range(numTokens):
|
|
for k in range(topK):
|
|
expandedIdx = tokenIdx * topK + k
|
|
expertIndex = topKIndices[tokenIdx, k]
|
|
expandedTokenIdxToExpert[expandedIdx] = expertIndex
|
|
expandedTokenIdxToIdxInExpert[expandedIdx] = numTokensPerExpert[expertIndex]
|
|
numTokensPerExpert[expertIndex] += 1
|
|
|
|
paddedTokensPerExpertPrefixSum = torch.zeros(numExperts + 1, dtype=torch.int64)
|
|
for ii in range(numExperts):
|
|
|
|
def divUpMul(a, b):
|
|
return (a + b - 1) // b * b
|
|
|
|
paddedTokensPerExpertPrefixSum[ii + 1] = paddedTokensPerExpertPrefixSum[
|
|
ii
|
|
] + divUpMul(numTokensPerExpert[ii], padding)
|
|
permutedBufferSize = paddedTokensPerExpertPrefixSum[numExperts]
|
|
|
|
expandedTokenIdxToPermutedIdx = -torch.ones(numTokens * topK, dtype=torch.int64)
|
|
permutedIdxToExpandedIdx = -torch.ones(permutedBufferSize, dtype=torch.int64)
|
|
permutedIdxToTokenIdx = -torch.ones(permutedBufferSize, dtype=torch.int64)
|
|
for tokenIdx in range(numTokens):
|
|
for k in range(topK):
|
|
expandedIdx = tokenIdx * topK + k
|
|
expert = expandedTokenIdxToExpert[expandedIdx]
|
|
offsetWithinExpert = expandedTokenIdxToIdxInExpert[expandedIdx]
|
|
offsetForExpert = paddedTokensPerExpertPrefixSum[expert]
|
|
permutedIdx = offsetForExpert + offsetWithinExpert
|
|
|
|
expandedTokenIdxToPermutedIdx[expandedIdx] = permutedIdx
|
|
permutedIdxToExpandedIdx[permutedIdx] = expandedIdx
|
|
permutedIdxToTokenIdx[permutedIdx] = tokenIdx
|
|
return {
|
|
"paddedTokensPerExpertPrefixSum": paddedTokensPerExpertPrefixSum.to(
|
|
originalDevice
|
|
),
|
|
"permutedBufferSize": permutedBufferSize.item(),
|
|
"expandedTokenIdxToPermutedIdx": expandedTokenIdxToPermutedIdx.to(
|
|
originalDevice
|
|
),
|
|
"permutedIdxToExpandedIdx": permutedIdxToExpandedIdx.to(originalDevice),
|
|
"numTokensPerExpert": numTokensPerExpert.to(originalDevice),
|
|
"expandedTokenIdxToExpert": expandedTokenIdxToExpert.to(originalDevice),
|
|
"topKLogits": topKLogits.to(originalDevice),
|
|
"permutedIdxToTokenIdx": permutedIdxToTokenIdx.to(originalDevice),
|
|
"topKIndices": topKIndices.to(originalDevice),
|
|
}
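
# Worked example of the padding logic above (assumed numbers, for illustration):
# with padding=8, an expert that receives 3 tokens reserves divUpMul(3, 8) = 8
# permuted slots, so each expert's segment of the permuted buffer starts at a
# multiple of the padding and unused slots keep the sentinel index -1.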
|
|
|
|
|
|
def noaux_tc_ref(logits, bias, n_group, topk_group, top_k, routed_scaling_factor):
|
|
"""DeepSeek-style no-aux routing reference implementation."""
|
|
scores = F.sigmoid(logits)
|
|
scores_with_bias = scores + bias
|
|
if n_group > 1:
|
|
scores_shape = list(scores_with_bias.shape)
|
|
group_scores = torch.sum(
|
|
torch.topk(
|
|
scores_with_bias.view(
|
|
scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]
|
|
),
|
|
k=2,
|
|
dim=-1,
|
|
largest=True,
|
|
sorted=True,
|
|
)[0],
|
|
dim=-1,
|
|
)
|
|
_, group_idx = torch.topk(
|
|
group_scores, k=topk_group, dim=-1, largest=True, sorted=True
|
|
)
|
|
group_mask = torch.zeros_like(group_scores)
|
|
group_mask.scatter_(-1, group_idx, 1)
|
|
score_mask = (
|
|
group_mask.unsqueeze(-1)
|
|
.expand(scores_shape[:-1] + [n_group, scores_shape[-1] // n_group])
|
|
.reshape(scores_shape)
|
|
)
|
|
scores_with_bias = scores_with_bias * score_mask
|
|
|
|
_, topk_idx = torch.topk(
|
|
scores_with_bias, k=top_k, dim=-1, largest=True, sorted=True
|
|
)
|
|
new_mask = torch.zeros_like(scores)
|
|
new_mask.scatter_(-1, topk_idx, 1)
|
|
scores = scores * new_mask
|
|
score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20
|
|
scores = scores / score_sum * routed_scaling_factor
|
|
return scores
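
# Worked example for the grouped routing above (DSv3-style numbers from the test
# configs below, used only for illustration): with 256 experts, n_group=8 and
# topk_group=4, the experts are split into 8 groups of 32; each group is ranked by
# the sum of its top-2 biased scores, only the best 4 groups survive, then
# top_k=8 experts are picked from the surviving groups and their bias-free sigmoid
# scores are renormalized and multiplied by routed_scaling_factor.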
|
|
|
|
|
|
def routing_reference_no_aux(
|
|
expert_logits,
|
|
routing_bias,
|
|
top_k,
|
|
n_groups,
|
|
top_k_groups,
|
|
routed_scaling,
|
|
padding,
|
|
use_routing_scales_on_input=False,
|
|
):
|
|
"""Tiered TopK routing used by DeepSeek."""
|
|
routing_logits = expert_logits.to(dtype=torch.float, device="cuda")
|
|
if use_routing_scales_on_input:
|
|
# if using routing scales on input, topK == 1 and the score is a plain sigmoid
|
|
scores = F.sigmoid(routing_logits)
|
|
else:
|
|
scores = noaux_tc_ref(
|
|
routing_logits, routing_bias, n_groups, top_k_groups, top_k, routed_scaling
|
|
)
|
|
permute_info = routing_reference(scores, top_k, padding)
|
|
return permute_info, scores
|
|
|
|
|
|
def routing_reference_renormalize(expert_logits, top_k, num_experts, padding):
|
|
"""TopK -> Softmax routing reference."""
|
|
topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1)
|
|
topk_values = torch.nn.functional.softmax(topk_values.float(), dim=-1)
|
|
|
|
new_mask = torch.zeros_like(expert_logits)
|
|
new_mask.scatter_(-1, topk_idx, 1)
|
|
scores = expert_logits * new_mask
|
|
|
|
for i in range(topk_idx.shape[0]):
|
|
for j in range(topk_idx.shape[1]):
|
|
scores[i, topk_idx[i, j]] = topk_values[i, j]
|
|
permute_info = routing_reference(scores, top_k, padding)
|
|
return permute_info, scores
|
|
|
|
|
|
def routing_reference_renormalize_naive(expert_logits, top_k, num_experts, padding):
|
|
"""Softmax->TopK -> Normalize routing reference."""
|
|
norm_topk_prob = True
|
|
scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1)
|
|
topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1)
|
|
|
|
if norm_topk_prob: # only diff with mixtral sparse moe block!
|
|
topk_values /= topk_values.sum(dim=-1, keepdim=True)
|
|
topk_values = topk_values.to(expert_logits.dtype)
|
|
scores = scores.to(expert_logits.dtype)
|
|
|
|
new_mask = torch.zeros_like(expert_logits)
|
|
new_mask.scatter_(-1, topk_idx, 1)
|
|
scores = expert_logits * new_mask
|
|
|
|
for i in range(topk_idx.shape[0]):
|
|
for j in range(topk_idx.shape[1]):
|
|
scores[i, topk_idx[i, j]] = topk_values[i, j]
|
|
permute_info = routing_reference(scores, top_k, padding)
|
|
return permute_info, scores
|
|
|
|
|
|
def routing_reference_topk(expert_logits, top_k, num_experts, padding):
|
|
"""TopK only (no softmax) routing reference."""
|
|
topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1)
|
|
|
|
new_mask = torch.zeros_like(expert_logits)
|
|
new_mask.scatter_(-1, topk_idx, 1)
|
|
scores = expert_logits * new_mask
|
|
|
|
for i in range(topk_idx.shape[0]):
|
|
for j in range(topk_idx.shape[1]):
|
|
scores[i, topk_idx[i, j]] = topk_values[i, j]
|
|
permute_info = routing_reference(scores, top_k, padding)
|
|
return permute_info, scores
|
|
|
|
|
|
def check_accuracy(a, b, atol, rtol, percent):
    """Unified accuracy checking function with detailed error reporting."""
    if torch.any(torch.isnan(a)):
        raise Exception("NaN in reference output")
    if torch.any(torch.isnan(b)):
        raise Exception("NaN in actual output")
    if torch.any(torch.isinf(a)):
        raise Exception("Inf in reference output")
    if torch.any(torch.isinf(b)):
        raise Exception("Inf in actual output")
    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

    left = torch.abs(a - b)
    right = atol + rtol * torch.abs(b)
    count = torch.sum(left > right)
    mismatch_percent = count / a.numel()
    if mismatch_percent > 1 - percent:
        raise Exception(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
            f"(threshold: {1 - percent:.4f})"
        )
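
# Example of how the tolerances in this file are interpreted (sketch): with
# atol=0.1, rtol=0.85 and percent=0.925, an element only counts as a mismatch
# when |a - b| > 0.1 + 0.85 * |b|, and the check fails only if more than 7.5%
# of all elements mismatch.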
|
|
|
|
|
|
# ====================================================================================
|
|
# FP4 Quantization Functions
|
|
# ====================================================================================
|
|
|
|
|
|
def calculate_fp4_global_scale_factor(tensor, use_ue8m0=False):
    """
    Calculate the FP4 global scale factor for a tensor.

    NOTE: In production, global scale factors are typically obtained offline during:
    - Post-Training Quantization (PTQ) calibration
    - Quantization-Aware Training (QAT)

    This function is used here for testing/reference purposes.
    Formula: (448 * 6) is the product of the maximum FP8 E4M3 scale value (448)
    and the maximum FP4 E2M1 magnitude (6).
    """
    if use_ue8m0:
        return torch.tensor(1.0, dtype=torch.float32)
    else:
        return (448 * 6) / tensor.float().abs().nan_to_num().max()
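
# Numeric sketch (assumed max value, for illustration only): if max|x| == 10.0,
# the NvFP4 global scale factor is (448 * 6) / 10.0 = 268.8, so after scaling the
# largest magnitude maps to the top of the FP8-scale x E2M1 representable range;
# in the UE8M0 (MxFP4) path the factor is fixed to 1.0.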
|
|
|
|
|
|
def e2m1_and_ufp8_scale_batches(
|
|
mat_fp4: torch.Tensor,
|
|
scale_tensor: torch.Tensor,
|
|
global_scale_tensor: torch.Tensor,
|
|
sf_vec_size: int,
|
|
ufp8_type: int = 1,
|
|
):
|
|
"""Batch FP4 dequantization helper."""
|
|
num_batches = mat_fp4.size(0)
|
|
scale_tensor = scale_tensor.view(num_batches, -1)
|
|
|
|
tensors = [
|
|
e2m1_and_ufp8sf_scale_to_float(
|
|
mat_fp4[b, :, :].cpu(),
|
|
scale_tensor[b, :].cpu().reshape(-1),
|
|
global_scale_tensor[b].cpu(),
|
|
sf_vec_size,
|
|
ufp8_type,
|
|
True, # is_sf_swizzled_layout
|
|
)
|
|
for b in range(num_batches)
|
|
]
|
|
|
|
result = torch.stack(tensors)
|
|
return result
|
|
|
|
|
|
def quant_fp4(a, a_global_sf, use_ue8m0=False, is_sf_swizzled_layout=True):
|
|
"""
|
|
Quantize FP4 with pre-computed global scale factor.
|
|
|
|
This function expects global scale factors that have been pre-computed offline
|
|
during PTQ/QAT calibration process. The global scale factor should NOT be
|
|
computed at runtime to avoid performance overhead.
|
|
|
|
Pure function - same inputs always produce same outputs.
|
|
"""
|
|
sf_vec_size = 32 if use_ue8m0 else 16
|
|
|
|
a_fp4, a_sf = fp4_quantize(
|
|
a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout
|
|
)
|
|
|
|
return a_fp4, a_sf, a_global_sf
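
# Layout note for the outputs above (as consumed later in this file): a_fp4 packs
# two E2M1 values per byte, so an (M, K) input yields an (M, K // 2) byte tensor,
# and a_sf holds one FP8 scale per sf_vec_size elements (16 for NvFP4, 32 for
# MxFP4), optionally in the swizzled layout expected by the kernels.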
|
|
|
|
|
|
def quant_fp4_batches(a, num_experts, use_ue8m0=False, is_sf_swizzled_layout=True):
|
|
"""FP4 batch quantization function with centralized global scale factor calculation."""
|
|
quant_a = []
|
|
sfs = []
|
|
global_sfs = []
|
|
for i in range(num_experts):
|
|
# Use centralized global scale factor calculation
|
|
a_global_sf = calculate_fp4_global_scale_factor(a[i], use_ue8m0)
|
|
a_fp4, a_sf, _ = quant_fp4(a[i], a_global_sf, use_ue8m0, is_sf_swizzled_layout)
|
|
quant_a.append(a_fp4)
|
|
sfs.append(a_sf)
|
|
global_sfs.append(a_global_sf)
|
|
|
|
result_quant_a = torch.stack(quant_a)
|
|
result_sfs = torch.stack(sfs)
|
|
result_global_sfs = torch.stack(global_sfs)
|
|
|
|
return result_quant_a, result_sfs, result_global_sfs
|
|
|
|
|
|
def quant_dequant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True):
|
|
"""FP4 quantize-dequantize roundtrip function with centralized global scale factor calculation."""
|
|
# Use centralized global scale factor calculation
|
|
a_global_sf = calculate_fp4_global_scale_factor(a, use_ue8m0)
|
|
sf_vec_size = 32 if use_ue8m0 else 16
|
|
|
|
a_fp4, a_sf = fp4_quantize(
|
|
a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout
|
|
)
|
|
|
|
a_pt = e2m1_and_ufp8sf_scale_to_float(
|
|
a_fp4.cpu(),
|
|
a_sf.cpu().reshape(-1),
|
|
(1 / a_global_sf).cpu(),
|
|
sf_vec_size,
|
|
1 if not use_ue8m0 else 0, # ufp8_type
|
|
is_sf_swizzled_layout,
|
|
)
|
|
|
|
return a_pt.cuda(), a_global_sf
|
|
|
|
|
|
# ====================================================================================
|
|
# FP8 Quantization Functions
|
|
# ====================================================================================
|
|
|
|
|
|
def calculate_fp8_global_scale_factor(tensor):
    """
    Calculate the FP8 global scale factor for a tensor.

    NOTE: In production, global scale factors are typically obtained offline during:
    - Post-Training Quantization (PTQ) calibration
    - Quantization-Aware Training (QAT)

    This function is used here for testing/reference purposes.
    Formula: 448 is the maximum representable value in FP8 E4M3 format.
    """
    return 448 / tensor.float().abs().nan_to_num().max()
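
# Numeric sketch (assumed max value, for illustration only): if max|x| == 10.0,
# the per-tensor scale factor is 448 / 10.0 = 44.8, so x * 44.8 places the largest
# magnitude at the FP8 E4M3 maximum of 448 before the cast in quant_fp8_per_tensor.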
|
|
|
|
|
|
def quant_fp8_per_tensor(a, a_global_sf):
|
|
"""
|
|
Quantize FP8 per-tensor with pre-computed global scale factor.
|
|
|
|
This function expects global scale factors that have been pre-computed offline
|
|
during PTQ/QAT calibration process. The global scale factor should NOT be
|
|
computed at runtime to avoid performance overhead.
|
|
|
|
Pure function - same inputs always produce same outputs.
|
|
"""
|
|
a_fp8 = (a * a_global_sf).to(torch.float8_e4m3fn)
|
|
return a_fp8, a_global_sf
|
|
|
|
|
|
def quant_fp8_per_tensor_batches(a):
|
|
"""FP8 per-tensor batch quantization function with centralized global scale factor calculation."""
|
|
num_batches = a.size(0)
|
|
a_quant = []
|
|
a_scales = []
|
|
|
|
for i in range(num_batches):
|
|
# Use centralized global scale factor calculation
|
|
a_global_sf = calculate_fp8_global_scale_factor(a[i])
|
|
a_fp8, _ = quant_fp8_per_tensor(a[i], a_global_sf)
|
|
a_quant.append(a_fp8)
|
|
a_scales.append(a_global_sf)
|
|
|
|
result_a_quant = torch.stack(a_quant)
|
|
result_a_scales = torch.stack(a_scales)
|
|
|
|
return result_a_quant, result_a_scales
|
|
|
|
|
|
def quant_dequant_per_tensor_fp8(a):
|
|
"""FP8 per-tensor quantize-dequantize roundtrip function with centralized global scale factor calculation."""
|
|
# Use centralized global scale factor calculation
|
|
a_global_sf = calculate_fp8_global_scale_factor(a)
|
|
a_fp8, _ = quant_fp8_per_tensor(a, a_global_sf)
|
|
a_pt = a_fp8.to(torch.float) / a_global_sf
|
|
return a_pt.cuda(), a_global_sf
|
|
|
|
|
|
def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n):
|
|
"""Reference FP8 block-scale dequantization."""
|
|
input = input.to(torch.float)
|
|
scale = scale.to(torch.float)
|
|
if transpose_scale:
|
|
scale = scale.t()
|
|
|
|
m, n = input.shape
|
|
m_tile = 128 if block_m else 1
|
|
n_tile = 128 if block_n else 1
|
|
|
|
assert m % m_tile == 0
|
|
assert n % n_tile == 0
|
|
assert scale.shape == (m // m_tile, n // n_tile)
|
|
|
|
# Expand scale to match input dimensions using tensor operations
|
|
if m_tile > 1:
|
|
scale = torch.repeat_interleave(scale, m_tile, dim=0)
|
|
if n_tile > 1:
|
|
scale = torch.repeat_interleave(scale, n_tile, dim=1)
|
|
|
|
# Element-wise multiplication (equivalent to the nested loop logic)
|
|
output = input * scale
|
|
return output
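
# Shape sketch for the block-scale layout above (assumed sizes, for illustration):
# a (256, 1024) FP8 tensor dequantized with block_m=True and block_n=True expects
# a (2, 8) scale grid, each scale covering one 128 x 128 tile; with block_m=False
# the scale instead has one row per input row.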
|
|
|
|
|
|
# ====================================================================================
|
|
# Common MoE Reference Implementation
|
|
# ====================================================================================
|
|
|
|
|
|
def run_moe_dequant(args, quant_mode: QuantMode):
|
|
"""Common dequantized MoE reference implementation."""
|
|
# Permute
|
|
total_num_padded_tokens = args.permute_info["permutedBufferSize"]
|
|
expanded_idx_to_permuted_idx = args.permute_info[
|
|
"expandedTokenIdxToPermutedIdx"
|
|
].cpu()
|
|
num_tokens_per_expert = args.permute_info["numTokensPerExpert"].cpu()
|
|
permute_output = torch.full(
|
|
(total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda"
|
|
).to(torch.float)
|
|
for i in range(args.num_tokens):
|
|
for j in range(args.top_k):
|
|
permuted_idx = expanded_idx_to_permuted_idx[i * args.top_k + j]
|
|
permute_output[permuted_idx] = args.hidden_states[i]
|
|
|
|
# Gemm1
|
|
gemm1_output = torch.full(
|
|
(total_num_padded_tokens, 2 * args.intermediate_size),
|
|
float("nan"),
|
|
device="cuda",
|
|
).to(torch.float)
|
|
i = 0
|
|
for expert_idx in range(args.num_experts):
|
|
my_num_tokens = num_tokens_per_expert[expert_idx]
|
|
if my_num_tokens == 0:
|
|
continue
|
|
my_a = permute_output[i : i + my_num_tokens]
|
|
my_b = args.gemm1_weights[expert_idx]
|
|
my_c = my_a @ my_b.t()
|
|
gemm1_output[i : i + my_num_tokens] = my_c
|
|
i += my_num_tokens
|
|
i = (i + args.padding - 1) // args.padding * args.padding
|
|
|
|
if args.use_routing_scales_on_input:
|
|
assert args.top_k == 1
|
|
# For each token and its top_k experts
|
|
for token_idx in range(args.num_tokens):
|
|
for k in range(args.top_k):
|
|
# Get the permuted index for this token's k-th expert
|
|
expanded_idx = token_idx * args.top_k + k
|
|
permuted_idx = expanded_idx_to_permuted_idx[expanded_idx]
|
|
expert_weight = args.permute_info["topKLogits"].to(torch.float)
|
|
# Get the expert weight for this token and expert
|
|
weight = expert_weight[token_idx, k]
|
|
# Scale the corresponding row in gemm1_output
|
|
gemm1_output[permuted_idx] *= weight
|
|
|
|
# Activation
|
|
activation_output = torch.full(
|
|
(total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda"
|
|
).to(torch.float)
|
|
|
|
gated_act_type = args.gated_act_type
|
|
gated_act_type_to_func = {
|
|
0: F.silu,
|
|
1: F.gelu,
|
|
}
|
|
gated_act_func = gated_act_type_to_func[gated_act_type]
|
|
|
|
i = 0
|
|
for expert_idx in range(args.num_experts):
|
|
my_num_tokens = num_tokens_per_expert[expert_idx]
|
|
if my_num_tokens == 0:
|
|
continue
|
|
my_a = gemm1_output[i : i + my_num_tokens]
|
|
my_x1 = my_a[:, : args.intermediate_size]
|
|
my_x2 = my_a[:, args.intermediate_size :]
|
|
activation_output[i : i + my_num_tokens] = gated_act_func(my_x2) * my_x1
|
|
i += my_num_tokens
|
|
i = (i + args.padding - 1) // args.padding * args.padding
|
|
|
|
if quant_mode == QuantMode.FP4_NVFP4_NVFP4:
|
|
# Use centralized function for activation quantization
|
|
activation_output, c_global_sf = quant_dequant_fp4(
|
|
activation_output.to(torch.bfloat16), False, True
|
|
)
|
|
activation_output = activation_output.to(torch.float)
|
|
args.c_global_sf = c_global_sf
|
|
elif quant_mode == QuantMode.FP8_PER_TENSOR:
|
|
activation_output, c_global_sf = quant_dequant_per_tensor_fp8(
|
|
activation_output.to(torch.bfloat16)
|
|
)
|
|
activation_output = activation_output.to(torch.float)
|
|
args.c_global_sf = c_global_sf
|
|
elif quant_mode == QuantMode.FP4_MXFP4_MXFP8:
|
|
activation_output, scale_bytes = mxfp8_quantize(
|
|
activation_output.to(torch.bfloat16), True
|
|
)
|
|
scale_bytes = scale_bytes.view(torch.uint8).reshape(-1).cpu()
|
|
activation_output = (
|
|
mxfp8_dequantize_host(
|
|
activation_output.cpu().view(torch.uint8), scale_bytes
|
|
)
|
|
.cuda()
|
|
.to(torch.float)
|
|
)
|
|
args.c_global_sf = 1.0
|
|
else: # mxfp4Bf16
|
|
activation_output = activation_output.to(torch.bfloat16).to(torch.float)
|
|
args.c_global_sf = 1.0
|
|
|
|
# Gemm2
|
|
gemm2_output = torch.full(
|
|
(total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda"
|
|
).to(torch.float)
|
|
i = 0
|
|
for expert_idx in range(args.num_experts):
|
|
my_num_tokens = num_tokens_per_expert[expert_idx]
|
|
if my_num_tokens == 0:
|
|
continue
|
|
my_a = activation_output[i : i + my_num_tokens]
|
|
my_b = args.gemm2_weights[expert_idx]
|
|
my_c = my_a @ my_b.t()
|
|
gemm2_output[i : i + my_num_tokens] = my_c
|
|
i += my_num_tokens
|
|
i = (i + args.padding - 1) // args.padding * args.padding
|
|
|
|
# Finalize
|
|
expert_weight = args.permute_info["topKLogits"].to(torch.float)
|
|
finalize_output = torch.full(
|
|
(args.num_tokens, args.hidden_size), float("nan"), device="cuda"
|
|
).to(torch.float)
|
|
for i in range(args.num_tokens):
|
|
acc = torch.zeros(args.hidden_size, dtype=torch.float, device="cuda")
|
|
for top_k_idx in range(args.top_k):
|
|
expanded_idx = i * args.top_k + top_k_idx
|
|
permuted_idx = expanded_idx_to_permuted_idx[expanded_idx]
|
|
original_vector = gemm2_output[permuted_idx]
|
|
weight = (
|
|
expert_weight[i, top_k_idx]
|
|
if not args.use_routing_scales_on_input
|
|
else 1.0
|
|
)
|
|
acc += original_vector * weight
|
|
finalize_output[i] = acc
|
|
return finalize_output
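
# Pipeline summary of the reference above (restating the steps for readability):
# permute tokens into padded per-expert segments -> GEMM1 -> gated activation
# (SiLU or GELU, with a mode-specific quantize/dequantize round trip of the
# activations) -> GEMM2 -> finalize by accumulating each token's top_k expert
# outputs, weighted by the routing scores (or 1.0 when the routing scales were
# already applied on the input).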
|
|
|
|
|
|
# ====================================================================================
|
|
# Quantization-Specific Reference Implementations
|
|
# ====================================================================================
|
|
|
|
|
|
def run_moe_reference_fp4(args, quant_mode: QuantMode):
|
|
sf_vec_size = 16 if quant_mode == QuantMode.FP4_NVFP4_NVFP4 else 32
|
|
ufp8_type_weights = 1 if quant_mode == QuantMode.FP4_NVFP4_NVFP4 else 0
|
|
|
|
if quant_mode == QuantMode.FP4_NVFP4_NVFP4:
|
|
hidden_states_dequant = e2m1_and_ufp8sf_scale_to_float(
|
|
args.hidden_states.cpu(),
|
|
args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1),
|
|
(1 / args.hidden_states_scale_global).cpu(),
|
|
sf_vec_size,
|
|
ufp8_type_weights,
|
|
True, # is_sf_swizzled_layout
|
|
).cuda()
|
|
elif quant_mode == QuantMode.FP4_MXFP4_MXFP8:
|
|
hidden_states_dequant = mxfp8_dequantize_host(
|
|
args.hidden_states.cpu().view(torch.uint8),
|
|
args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1),
|
|
True, # is_sf_swizzled_layout
|
|
).cuda()
|
|
else:
|
|
hidden_states_dequant = args.hidden_states.to(torch.bfloat16).to(torch.float)
|
|
|
|
gemm1_weights_dequant = e2m1_and_ufp8_scale_batches(
|
|
args.gemm1_weights,
|
|
args.gemm1_scales,
|
|
1 / args.gemm1_scales_global,
|
|
sf_vec_size,
|
|
ufp8_type_weights,
|
|
).cuda()
|
|
|
|
gemm2_weights_dequant = e2m1_and_ufp8_scale_batches(
|
|
args.gemm2_weights,
|
|
args.gemm2_scales,
|
|
1 / args.gemm2_scales_global,
|
|
sf_vec_size,
|
|
ufp8_type_weights,
|
|
).cuda()
|
|
|
|
args_dequant = moe_args_dequant(
|
|
args.num_tokens,
|
|
args.num_experts,
|
|
args.hidden_size,
|
|
args.intermediate_size,
|
|
args.top_k,
|
|
args.padding,
|
|
hidden_states_dequant,
|
|
args.expert_logits,
|
|
gemm1_weights_dequant,
|
|
gemm2_weights_dequant,
|
|
args.permute_info,
|
|
args.use_routing_scales_on_input,
|
|
args.gated_act_type,
|
|
)
|
|
|
|
return run_moe_dequant(args_dequant, quant_mode), args_dequant
|
|
|
|
|
|
def run_moe_reference_dsfp8(args):
|
|
"""FP8 block-scale reference implementation."""
|
|
# Generate block scales at runtime for FP8 block scaling
|
|
# Use deterministic scales for testing consistency
|
|
hidden_states_scale = 2.0 * torch.ones(
|
|
(args.hidden_size // 128, args.num_tokens), device="cuda", dtype=torch.float
|
|
)
|
|
|
|
hidden_states_dequant = dequant_reference_dsfp8(
|
|
args.hidden_states, hidden_states_scale, True, False, True
|
|
)
|
|
|
|
gemm1_weights_dequant = {}
|
|
for i in range(args.num_experts):
|
|
gemm1_weights_dequant[i] = dequant_reference_dsfp8(
|
|
args.gemm1_weights[i], args.gemm1_scales[i], False, True, True
|
|
)
|
|
|
|
gemm2_weights_dequant = {}
|
|
for i in range(args.num_experts):
|
|
gemm2_weights_dequant[i] = dequant_reference_dsfp8(
|
|
args.gemm2_weights[i], args.gemm2_scales[i], False, True, True
|
|
)
|
|
|
|
args_dequant = moe_args_dequant(
|
|
args.num_tokens,
|
|
args.num_experts,
|
|
args.hidden_size,
|
|
args.intermediate_size,
|
|
args.top_k,
|
|
args.padding,
|
|
hidden_states_dequant,
|
|
args.expert_logits,
|
|
gemm1_weights_dequant,
|
|
gemm2_weights_dequant,
|
|
args.permute_info,
|
|
args.use_routing_scales_on_input,
|
|
GatedActType.SwiGlu.value, # gated_act_type
|
|
)
|
|
|
|
return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant
|
|
|
|
|
|
def run_moe_reference_per_tensor_scale_fp8(args):
    """FP8 per-tensor reference implementation."""
    hidden_states_dequant = (
        args.hidden_states.to(torch.float) / args.hidden_states_scale_global
    )

    gemm1_weights_dequant = {}
    for i in range(args.num_experts):
        gemm1_weights_dequant[i] = (
            args.gemm1_weights[i].to(torch.float) / args.gemm1_scales_global[i]
        )

    gemm2_weights_dequant = {}
    for i in range(args.num_experts):
        gemm2_weights_dequant[i] = (
            args.gemm2_weights[i].to(torch.float) / args.gemm2_scales_global[i]
        )

    args_dequant = moe_args_dequant(
        args.num_tokens,
        args.num_experts,
        args.hidden_size,
        args.intermediate_size,
        args.top_k,
        args.padding,
        hidden_states_dequant,
        args.expert_logits,
        gemm1_weights_dequant,
        gemm2_weights_dequant,
        args.permute_info,
        args.use_routing_scales_on_input,
        GatedActType.SwiGlu.value,  # gated_act_type
    )

    return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant


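# Hedged sketch of the per-tensor convention assumed by the reference above: the global
# scale is the factor multiplied in at quantization time, so dequantization divides by
# it. Availability of torch.float8_e4m3fn (PyTorch >= 2.1) is an assumption; 448.0 is
# the e4m3 max-normal magnitude.
def _toy_per_tensor_fp8_roundtrip(x: torch.Tensor) -> tuple:
    scale_global = 448.0 / x.abs().max().clamp_min(1e-6).float()
    x_fp8 = (x.float() * scale_global).to(torch.float8_e4m3fn)
    x_dequant = x_fp8.to(torch.float) / scale_global
    return x_fp8, x_dequant

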
def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs):
    """Unified actual computation that delegates to implementation-specific methods."""
    # 1. Prepare static weights for the kernel (offline processing)
    static_data = moe_impl.prepare_static_weights_for_kernel(
        args_dequant,
        args,
        kwargs["gemm1_weights_orig"],
        kwargs["gemm2_weights_orig"],
        args.hidden_size,
        args.intermediate_size,
        args.num_experts,
        kwargs["weight_processing"],
    )

    # 2. Call MoE with runtime input quantization + kernel execution
    kernel_kwargs = {
        "expert_logits": kwargs["expert_logits"],
        "routing_bias": kwargs["routing_bias"],
        "num_experts": args.num_experts,
        "num_tokens": args.num_tokens,
        "hidden_size": args.hidden_size,
        "top_k": args.top_k,
        "n_groups": kwargs["n_groups"],
        "top_k_groups": kwargs["top_k_groups"],
        "intermediate_size": args.intermediate_size,
        "routed_scaling": kwargs["routed_scaling"],
        "routing_method_type": kwargs["routing_method_type"],
        "tile_tokens_dim": kwargs["tile_tokens_dim"],
        "do_finalize": True,
        "gated_act_type": args.gated_act_type,
    }

    return moe_impl.call_moe(
        static_data,
        kwargs["hidden_states_orig"],
        args.hidden_states_scale_global,
        **kernel_kwargs,
    )


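# Example call shape (illustrative only; the keyword names mirror what the test below
# passes to compute_production, and the `...` placeholders stand for real tensors and
# config values):
#
#     output = _compute_moe_actual_unified(
#         moe_impl, args_dequant, args,
#         expert_logits=..., routing_bias=..., hidden_states_orig=...,
#         gemm1_weights_orig=..., gemm2_weights_orig=...,
#         n_groups=..., top_k_groups=..., routed_scaling=...,
#         routing_method_type=..., tile_tokens_dim=..., weight_processing=...,
#     )

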
@pytest.fixture(scope="module")
|
|
def cache_permute_indices():
|
|
_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
|
|
return _cache_permute_indices
|
|
|
|
|
|
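# Note (interpretation): because this fixture is module-scoped, every parametrization of
# the test below shares the same dict, so permute indices computed for a given weight
# shape (keyed by torch.Size) are reused instead of being recomputed per test case.

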
@pytest.mark.parametrize("num_tokens", [1, 8, 1024])
|
|
@pytest.mark.parametrize("hidden_size", [1024, 8192])
|
|
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 384])
|
|
@pytest.mark.parametrize(
|
|
"moe_impl",
|
|
[
|
|
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
|
|
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
|
|
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
|
|
pytest.param(FP8BlockScaleMoe(), id="FP8_Block"),
|
|
pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
|
|
],
|
|
)
|
|
@pytest.mark.parametrize(
|
|
"routing_config",
|
|
[
|
|
pytest.param(
|
|
{
|
|
"num_experts": 256,
|
|
"top_k": 8,
|
|
"padding": 8,
|
|
"n_groups": 8,
|
|
"top_k_groups": 4,
|
|
"routed_scaling": 2.5,
|
|
"has_routing_bias": True,
|
|
"routing_method_type": RoutingMethodType.DeepSeekV3,
|
|
"compatible_moe_impls": [
|
|
FP4Moe,
|
|
FP8BlockScaleMoe,
|
|
],
|
|
},
|
|
id="DSv3",
|
|
),
|
|
pytest.param(
|
|
{
|
|
"num_experts": 72,
|
|
"top_k": 6,
|
|
"padding": 8,
|
|
"n_groups": 1,
|
|
"top_k_groups": 1,
|
|
"routed_scaling": 2.5,
|
|
"has_routing_bias": True,
|
|
"routing_method_type": RoutingMethodType.DeepSeekV3,
|
|
"compatible_moe_impls": [
|
|
FP4Moe,
|
|
FP8BlockScaleMoe,
|
|
],
|
|
},
|
|
id="DSLite",
|
|
),
|
|
pytest.param(
|
|
{
|
|
"num_experts": 128,
|
|
"top_k": 8,
|
|
"padding": 8,
|
|
"n_groups": None,
|
|
"top_k_groups": None,
|
|
"routed_scaling": None,
|
|
"has_routing_bias": False,
|
|
"routing_method_type": RoutingMethodType.Renormalize,
|
|
"compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
|
|
},
|
|
id="Renorm",
|
|
marks=pytest.mark.skip(
|
|
reason="Disabled for testing speed - similar to RenormalizeNaive"
|
|
),
|
|
),
|
|
pytest.param(
|
|
{
|
|
"num_experts": 128,
|
|
"top_k": 8,
|
|
"padding": 8,
|
|
"n_groups": None,
|
|
"top_k_groups": None,
|
|
"routed_scaling": None,
|
|
"has_routing_bias": False,
|
|
"routing_method_type": RoutingMethodType.RenormalizeNaive,
|
|
"compatible_moe_impls": [FP4Moe],
|
|
},
|
|
id="RenormNaive",
|
|
),
|
|
pytest.param(
|
|
{
|
|
"num_experts": 16,
|
|
"top_k": 2,
|
|
"padding": 8,
|
|
"n_groups": None,
|
|
"top_k_groups": None,
|
|
"routed_scaling": None,
|
|
"has_routing_bias": False,
|
|
"routing_method_type": RoutingMethodType.TopK,
|
|
"compatible_moe_impls": [FP4Moe],
|
|
},
|
|
id="TopK",
|
|
),
|
|
pytest.param(
|
|
{
|
|
"num_experts": 128,
|
|
"top_k": 1,
|
|
"padding": 8,
|
|
"n_groups": 0,
|
|
"top_k_groups": 0,
|
|
"routed_scaling": 2.5,
|
|
"has_routing_bias": True,
|
|
"routing_method_type": RoutingMethodType.Llama4,
|
|
"compatible_moe_impls": [FP8PerTensorMoe],
|
|
},
|
|
id="Llama4",
|
|
),
|
|
],
|
|
)
|
|
@pytest.mark.parametrize(
|
|
"weight_processing",
|
|
[
|
|
pytest.param(
|
|
{
|
|
"use_shuffled_weight": False,
|
|
"layout": WeightLayout.MajorK,
|
|
"compatible_moe_impls": [FP8BlockScaleMoe],
|
|
},
|
|
id="NoShuffle_MajorK",
|
|
),
|
|
pytest.param(
|
|
{
|
|
"use_shuffled_weight": True,
|
|
"layout": WeightLayout.MajorK,
|
|
"compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
|
|
},
|
|
id="Shuffled_MajorK",
|
|
),
|
|
pytest.param(
|
|
{
|
|
"use_shuffled_weight": True,
|
|
"layout": WeightLayout.BlockMajorK,
|
|
"compatible_moe_impls": [FP8BlockScaleMoe],
|
|
},
|
|
id="Shuffled_BlockMajorK",
|
|
),
|
|
],
|
|
)
|
|
@pytest.mark.parametrize(
|
|
"gated_act_type",
|
|
[
|
|
pytest.param(GatedActType.SwiGlu, id="SwiGlu"),
|
|
pytest.param(GatedActType.GeGlu, id="GeGlu"),
|
|
],
|
|
)
|
|
def test_moe_quantization_classes(
    num_tokens,
    hidden_size,
    intermediate_size,
    moe_impl,
    routing_config,
    weight_processing,
    gated_act_type,
    cache_permute_indices,
):
    """
    Test MoE implementations using the separated quantization workflow.

    This test demonstrates the clean separation between:
    - Static weight quantization (done offline)
    - Dynamic input quantization (done at runtime)

    Each quantization class clearly shows which precision is being used.
    """
    # Skip incompatible combinations
    if gated_act_type == GatedActType.GeGlu and (
        type(moe_impl) is not FP4Moe
        or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4
        or routing_config["routing_method_type"] != RoutingMethodType.TopK
        or num_tokens > 128
    ):
        # GeGlu is only supported for FP4Moe with FP4_NVFP4_NVFP4 and TopK routing
        pytest.skip(
            f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}"
        )
    elif gated_act_type == GatedActType.SwiGlu and (
        hidden_size > 1024 or intermediate_size > 1024
    ):
        # Skip some SwiGlu cases to keep test runtime manageable
        pytest.skip(
            f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}"
        )

    if type(moe_impl) not in routing_config["compatible_moe_impls"]:
        pytest.skip(
            f"Incompatible: {moe_impl.name} + {routing_config['routing_method_type'].name}"
        )
    if type(moe_impl) not in weight_processing["compatible_moe_impls"]:
        pytest.skip(
            f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}"
        )

    moe_impl._cache_permute_indices = cache_permute_indices

    seed = 0
    torch.random.manual_seed(seed)

    # Extract routing configuration
    top_k = routing_config["top_k"]
    padding = routing_config["padding"]
    n_groups = routing_config["n_groups"]
    top_k_groups = routing_config["top_k_groups"]
    routed_scaling = routing_config["routed_scaling"]
    num_experts = routing_config["num_experts"]
    routing_method_type = routing_config["routing_method_type"]
    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)

    # Validation checks
    assert top_k <= num_experts
    assert top_k <= 8
    if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0):
        assert top_k_groups <= 4
        assert num_experts > n_groups
        assert num_experts % n_groups == 0
        assert num_experts % 4 == 0
        assert top_k < (top_k_groups * num_experts / n_groups)

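    # Worked example for the group-routing constraints above, using the DSv3 config from
    # the parametrization: 256 experts / 8 groups = 32 experts per group; with
    # top_k_groups = 4 the candidate pool is 4 * 32 = 128 experts, which must exceed
    # top_k = 8 for the final selection to be well defined.
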
    # Create test data based on routing method and quantization mode
    # Different kernels have different dtype requirements for routing logits
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        # DeepSeekV3 uses float for routing logits
        expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to(
            torch.float
        )
    else:
        # Other routing methods (Renormalize, RenormalizeNaive, TopK, Llama4) use bfloat16
        expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to(
            torch.bfloat16
        )

    if routing_config["has_routing_bias"]:
        routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)
    else:
        routing_bias = None

    hidden_states = 2 * torch.randn(
        (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16
    )
    gemm1_weights = torch.randn(
        (num_experts, 2 * intermediate_size, hidden_size),
        device="cuda",
        dtype=torch.bfloat16,
    )
    gemm2_weights = torch.randn(
        (num_experts, hidden_size, intermediate_size),
        device="cuda",
        dtype=torch.bfloat16,
    )

    # Generate routing info
    use_routing_scales_on_input = routing_method_type == RoutingMethodType.Llama4

    if routing_method_type == RoutingMethodType.DeepSeekV3:
        permute_info, scores = routing_reference_no_aux(
            expert_logits,
            routing_bias,
            top_k,
            n_groups,
            top_k_groups,
            routed_scaling,
            padding,
            use_routing_scales_on_input,
        )
    elif routing_method_type == RoutingMethodType.Renormalize:
        permute_info, scores = routing_reference_renormalize(
            expert_logits, top_k, num_experts, padding
        )
    elif routing_method_type == RoutingMethodType.RenormalizeNaive:
        permute_info, scores = routing_reference_renormalize_naive(
            expert_logits, top_k, num_experts, padding
        )
    elif routing_method_type == RoutingMethodType.TopK:
        permute_info, scores = routing_reference_topk(
            expert_logits, top_k, num_experts, padding
        )
    elif routing_method_type == RoutingMethodType.Llama4:
        permute_info, scores = routing_reference_no_aux(
            expert_logits,
            routing_bias,
            top_k,
            n_groups,
            top_k_groups,
            routed_scaling,
            padding,
            use_routing_scales_on_input=True,
        )
    else:
        raise NotImplementedError(
            f"Routing method {routing_method_type} not implemented"
        )

    # 1. Quantize weights offline (static, done once) + compute global scale factors
    weights_data = moe_impl.quantize_weights(
        gemm1_weights, gemm2_weights, hidden_states
    )

    # 2. Quantize inputs at runtime (dynamic, done per inference) using pre-computed scales
    inputs_data = moe_impl.quantize_inputs(
        hidden_states, weights_data["hidden_states_scale_global"]
    )

    # 3. Combine quantized data
    quant_data = {**weights_data, **inputs_data}

    # Create arguments for reference computation
    args = moe_args(
        num_tokens,
        num_experts,
        hidden_size,
        intermediate_size,
        top_k,
        padding,
        quant_data["hidden_states"],
        quant_data["hidden_states_scale"],
        quant_data["hidden_states_scale_global"],
        scores,
        quant_data["gemm1_weights"],
        quant_data["gemm1_scales"],
        quant_data["gemm1_scales_global"],
        quant_data["gemm2_weights"],
        quant_data["gemm2_scales"],
        quant_data["gemm2_scales_global"],
        permute_info,
        use_routing_scales_on_input,
        gated_act_type,
    )

    # Compute the reference output using the moe_impl
    output_dequant_reference, args_dequant = moe_impl.compute_reference(args)

    # Validate that the reference computation succeeded
    if output_dequant_reference is None:
        pytest.fail("Reference computation failed to produce output")

    # Compute the actual output using the moe_impl
    output_dequant_actual = moe_impl.compute_production(
        args_dequant,
        args,
        expert_logits=expert_logits,
        routing_bias=routing_bias,
        hidden_states_orig=hidden_states,
        gemm1_weights_orig=gemm1_weights,
        gemm2_weights_orig=gemm2_weights,
        n_groups=n_groups,
        top_k_groups=top_k_groups,
        routed_scaling=routed_scaling,
        routing_method_type=routing_method_type,
        tile_tokens_dim=tile_tokens_dim,
        weight_processing=weight_processing,
        enable_pdl=True,
    )

    # Compare outputs using moe_impl-specific tolerances
    tolerances = moe_impl.get_tolerances()
    check_accuracy(
        output_dequant_reference,
        output_dequant_actual,
        atol=tolerances["atol"],
        rtol=tolerances["rtol"],
        percent=tolerances["percent"],
    )


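# Example invocation (illustrative; substitute the actual filename this module is saved
# as): run a single parametrized combination by matching its ids with -k, e.g.
#   pytest <this_file>.py -v -k "NvFP4xNvFP4 and DSv3 and Shuffled_MajorK and SwiGlu"

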
if __name__ == "__main__":
    pytest.main([__file__, "-v"])