sglang_v0.5.2/flashinfer_0.3.1/benchmarks/routines/moe.py

from collections import defaultdict
from typing import Optional
import numpy as np
import torch
import flashinfer
from flashinfer.autotuner import autotune
from flashinfer.fused_moe import (
WeightLayout,
trtllm_fp4_block_scale_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
cutlass_fused_moe,
convert_to_block_layout,
)
from flashinfer import fp4_quantize, shuffle_matrix_a
from flashinfer.testing.utils import (
bench_gpu_time,
bench_gpu_time_with_cudagraph,
)
from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
get_device,
print_perf_metrics,
)
def run_moe_test(args):
"""
Run a MOE test.
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.routine == "trtllm_fp4_block_scale_moe":
return testTrtllmFp4BlockScaleMoe(args)
elif args.routine == "trtllm_fp8_block_scale_moe":
return testTrtllmFp8BlockScaleMoe(args)
elif args.routine == "trtllm_fp8_per_tensor_scale_moe":
return testTrtllmFp8PerTensorScaleMoe(args)
elif args.routine == "cutlass_fused_moe":
return testCutlassFusedMoe(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")
def parse_moe_args(line, parser):
"""
Parse command line arguments for MOE test configuration.
Args:
line: Command line arguments
parser: ArgumentParser object already populated with shared arguments
Returns:
Parsed argument namespace
"""
parser.add_argument(
"--num_tokens", type=int, required=True, help="Number of input tokens."
)
parser.add_argument(
"--hidden_size", type=int, required=True, help="Hidden dimension size."
)
parser.add_argument(
"--intermediate_size",
type=int,
required=True,
help="Intermediate dimension size.",
)
parser.add_argument(
"--num_experts", type=int, required=True, help="Total number of experts."
)
parser.add_argument(
"--top_k",
type=int,
required=True,
help="Number of experts to route to per token.",
)
parser.add_argument(
"--n_group",
type=int,
required=False,
default=None,
help="Number of expert groups (for DeepSeek routing). Only used with DeepSeekV3 routing method.",
)
parser.add_argument(
"--topk_group",
type=int,
required=False,
default=None,
help="Number of groups to consider for top-k routing. Only used with DeepSeekV3 routing method.",
)
parser.add_argument(
"--routed_scaling_factor",
type=float,
required=False,
default=2.5,
help="Scaling factor for routing.",
)
parser.add_argument(
"--local_expert_offset",
type=int,
required=False,
default=0,
help="Offset of local experts in global expert space.",
)
parser.add_argument(
"--local_num_experts",
type=int,
required=False,
default=None,
help="Number of experts handled by this device. Defaults to num_experts.",
)
parser.add_argument(
"--tile_tokens_dim",
type=int,
required=False,
default=8,
help="Tile dimension for tokens.",
)
parser.add_argument(
"--routing_method",
type=str,
required=False,
default="deepseek_v3",
choices=[
"renormalize",
"deepseek_v3",
"llama4",
"renormalize_naive",
"topk",
],
help=(
"Routing method: renormalize | deepseek_v3 | llama4 | renormalize_naive | topk."
),
)
parser.add_argument(
"--use_shuffled_weight",
action="store_true",
default=False,
help="Whether to use shuffled weight layout.",
)
parser.add_argument(
"--weight_layout",
type=int,
required=False,
default=0,
choices=[0, 1, 2],
help="Weight layout: 0=MajorK, 1=MajorMn, 2=BlockMajorK.",
)
parser.add_argument(
"--use_routing_bias",
action="store_true",
default=False,
help="Whether to use routing bias.",
)
parser.add_argument(
"--use_routing_scales_on_input",
action="store_true",
default=False,
help="Whether to use routing scales on input (for Llama4 routing).",
)
parser.add_argument(
"--input_dtype",
type=str,
required=False,
default="bfloat16",
help="Data type of the input hidden states.",
)
parser.add_argument(
"--weight_dtype",
type=str,
required=False,
default="bfloat16",
help="Data type of the weights (before quantization).",
)
parser.add_argument(
"--gated_act",
type=str,
required=False,
default="swiglu",
choices=["swiglu", "geglu"],
help="Type of gated activation function: swiglu | geglu.",
)
parser.add_argument(
"--autotune",
action="store_true",
default=False,
help=(
"Enable autotuner warmup for supported routines (trtllm_fp4_block_scale_moe and cutlass_fused_moe)."
),
)
# CUTLASS fused MoE specific
parser.add_argument(
"--cutlass_variant",
type=str,
required=False,
default="base",
choices=["base", "fp8", "nvfp4"],
help="Variant for cutlass_fused_moe benchmark: base (no quant), fp8 (per-tensor), nvfp4 (fp4 blockscale)",
)
parser.add_argument(
"--quantized_input",
action="store_true",
default=False,
help="Quantize input activations (only used for nvfp4).",
)
parser.add_argument(
"--tp_size",
type=int,
required=False,
default=1,
help="Tensor parallel size for cutlass_fused_moe.",
)
parser.add_argument(
"--tp_rank",
type=int,
required=False,
default=0,
help="Tensor parallel rank for cutlass_fused_moe.",
)
parser.add_argument(
"--ep_size",
type=int,
required=False,
default=1,
help="Expert parallel size for cutlass_fused_moe.",
)
parser.add_argument(
"--ep_rank",
type=int,
required=False,
default=0,
help="Expert parallel rank for cutlass_fused_moe.",
)
args = parser.parse_args(line)
# Normalize routing method (map string to internal int expected by kernels)
routing_method_name_to_type = {
"renormalize": 1,
"deepseek_v3": 2,
"llama4": 3,
"renormalize_naive": 4,
"topk": 5,
}
args.routing_method_type = routing_method_name_to_type[args.routing_method]
# Normalize gated act type (map string to internal int expected by kernels)
gated_act_name_to_type = {
"swiglu": 0,
"geglu": 1,
}
args.gated_act_type = gated_act_name_to_type[args.gated_act]
if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
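# A hypothetical argument line for parse_moe_args, using only the flags defined above;
# the shared parser is assumed to supply --routine and the common benchmark flags
# (iteration counts, output path, etc.), which live outside this file:
#   --routine trtllm_fp4_block_scale_moe --num_tokens 128 --hidden_size 7168
#   --intermediate_size 2048 --num_experts 256 --top_k 8 --routing_method deepseek_v3
#   --n_group 8 --topk_group 4 --use_routing_bias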
def create_trtllm_moe_test_data(
num_tokens: int,
hidden_size: int,
intermediate_size: int,
num_experts: int,
routing_method_type: int,
use_routing_bias: bool,
input_dtype: torch.dtype,
weight_dtype: torch.dtype,
device: torch.device,
moe_kernel_type: str = "fp8_per_tensor",
):
"""
Create test data for TensorRT-LLM fused MoE benchmarking (trtllm_*_moe APIs).
This helper prepares inputs for the TensorRT-LLM fused MoE kernels exposed via
flashinfer.fused_moe (e.g., trtllm_fp4_block_scale_moe, trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe). It is NOT used for CUTLASS MoE benchmarks,
which construct their own inputs specific to the CUTLASS path.
Returns:
Tuple of tensors needed for trtllm fused MoE computation
"""
# Create routing logits - dtype depends on both routing method AND MOE kernel type
# Different MOE kernels have different routing_logits dtype requirements:
if moe_kernel_type == "fp8_block_scale":
# FP8 block scale MOE always expects float32 routing logits (line 333 in kernel_launcher.cu)
routing_logits = torch.randn(
(num_tokens, num_experts), device=device, dtype=torch.float32
)
elif moe_kernel_type == "fp8_per_tensor":
# FP8 per-tensor MOE dtype depends on use_routing_scales_on_input parameter
# For Llama4: use_routing_scales_on_input=True -> bfloat16
# For others: use_routing_scales_on_input=False -> float32
if routing_method_type == 3: # Llama4 uses routing scales on input
routing_logits = torch.randn(
(num_tokens, num_experts), device=device, dtype=torch.bfloat16
)
else:
routing_logits = torch.randn(
(num_tokens, num_experts), device=device, dtype=torch.float32
)
elif moe_kernel_type == "fp4_block_scale":
# FP4 block scale MOE follows the test pattern: float32 for DeepSeekV3, bfloat16 for others
if routing_method_type == 2: # DeepSeekV3 - uses float32
routing_logits = torch.randn(
(num_tokens, num_experts), device=device, dtype=torch.float32
)
else: # All other routing methods (Renormalize, RenormalizeNaive, Llama4) - use bfloat16
routing_logits = torch.randn(
(num_tokens, num_experts), device=device, dtype=torch.bfloat16
)
else:
raise ValueError(f"Unknown MOE kernel type: {moe_kernel_type}")
# Create routing bias if needed - always bfloat16
routing_bias = None
if use_routing_bias:
routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16)
# Create hidden states - always start with bfloat16 for proper quantization
hidden_states = 2 * torch.randn(
(num_tokens, hidden_size), device=device, dtype=torch.bfloat16
)
# Create weights - always start with bfloat16 for proper quantization
gemm1_weights = torch.randn(
(num_experts, 2 * intermediate_size, hidden_size),
device=device,
dtype=torch.bfloat16,
)
gemm2_weights = torch.randn(
(num_experts, hidden_size, intermediate_size),
device=device,
dtype=torch.bfloat16,
)
return routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights
def calculate_fp4_global_scale_factor(tensor):
"""Calculate global scale factor for FP4 quantization."""
# Calculate as a tensor on the same device
# Using the same formula as in test files: FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
tensor_amax = tensor.abs().max().to(torch.float32)
# FLOAT8_E4M3_MAX = 448, FLOAT4_E2M1_MAX = 6
global_scale = (448.0 * 6.0) / tensor_amax
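# Worked example (illustrative value): a tensor with amax == 3.0 gets
# global_scale == (448.0 * 6.0) / 3.0 == 896.0, stretching the FP4 representable
# range to cover the tensor's dynamic range.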
return global_scale
def quant_fp4_simple(a, a_global_sf, use_ue8m0=False, is_sf_swizzled_layout=True):
"""
Thin wrapper around flashinfer.fp4_quantize for benchmarking.
Quantizes `a` with a 16-element scale-factor vector and passes the global scale factor through.
"""
sf_vec_size = 16
# Use the actual fp4_quantize function from flashinfer
a_fp4, a_sf = fp4_quantize(
a, a_global_sf, sf_vec_size, use_ue8m0, is_sf_swizzled_layout
)
return a_fp4, a_sf, a_global_sf
def quant_fp4_batches_simple(
a, num_experts, use_ue8m0=False, is_sf_swizzled_layout=True
):
"""Simplified FP4 batch quantization for benchmarking."""
quant_a = []
sfs = []
global_sfs = []
for i in range(num_experts):
# Calculate global scale factor (returns tensor)
a_global_sf = calculate_fp4_global_scale_factor(a[i])
a_fp4, a_sf, _ = quant_fp4_simple(
a[i], a_global_sf, use_ue8m0, is_sf_swizzled_layout
)
quant_a.append(a_fp4)
sfs.append(a_sf)
global_sfs.append(a_global_sf)
result_quant_a = torch.stack(quant_a)
result_sfs = torch.stack(sfs)
result_global_sfs = torch.stack(global_sfs)
return result_quant_a, result_sfs, result_global_sfs
def calculate_moe_tflops(
num_tokens: int,
hidden_size: int,
intermediate_size: int,
num_experts: int,
top_k: int,
time_ms: float,
) -> float:
"""
Calculate TFLOPS for MOE operation.
MOE computation involves:
1. First GEMM: [num_tokens, hidden_size] x [num_experts, hidden_size, 2*intermediate_size]
2. Activation function (SwiGLU gate)
3. Second GEMM: [num_tokens, intermediate_size] x [num_experts, intermediate_size, hidden_size]
For each token, we only compute for top_k experts.
"""
# FLOPS per token per expert (base calculation)
flops_per_token_per_expert = (
2 * hidden_size * 2 * intermediate_size # First GEMM
+ 2 * intermediate_size * hidden_size # Second GEMM
)
total_flops = num_tokens * top_k * flops_per_token_per_expert
tflops = total_flops / (time_ms * 1e-3) / 1e12 # Convert to TFLOPS
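# Worked example (illustrative shapes): num_tokens=128, hidden_size=4096,
# intermediate_size=14336, top_k=8, time_ms=1.0 gives
# flops_per_token_per_expert = 234_881_024 + 117_440_512 = 352_321_536 and
# total_flops = 128 * 8 * 352_321_536 ≈ 3.61e11, i.e. ~360.8 TFLOPS.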
return tflops
def calculate_moe_bandwidth(
num_tokens: int,
hidden_size: int,
intermediate_size: int,
num_experts: int,
top_k: int,
time_ms: float,
input_dtype: torch.dtype,
weight_dtype: torch.dtype,
input_format: Optional[str] = None,
weight_format: Optional[str] = None,
routing_logits_dtype: Optional[torch.dtype] = torch.float32,
active_experts: Optional[int] = None,
) -> float:
"""
Calculate memory bandwidth for MOE operation in TB/sec.
Args:
input_format: Override for input representation ("fp8" or "fp4"); None uses dtype.itemsize
weight_format: Override for weight representation ("fp8" or "fp4"); None uses dtype.itemsize
routing_logits_dtype: Dtype for routing logits memory accounting (default float32)
"""
# Get effective byte sizes
def get_effective_bytes(dtype: torch.dtype, fmt: Optional[str]) -> float:
if fmt == "fp4":
return 0.5
if fmt == "fp8":
return 1.0
return dtype.itemsize
input_bytes_per_element = get_effective_bytes(input_dtype, input_format)
weight_bytes_per_element = get_effective_bytes(weight_dtype, weight_format)
# Input memory: hidden states + routing logits
# Note: routing-logits dtype depends on the kernel; pass it in when known (default float32). None excludes it from the accounting.
routing_logits_bytes = (
0 if routing_logits_dtype is None else routing_logits_dtype.itemsize
)
input_bytes = (
# Count hidden states once; kernels typically reuse inputs for multiple experts
num_tokens * hidden_size * input_bytes_per_element
+ num_tokens * num_experts * routing_logits_bytes
)
# Weight memory (reuse weights across tokens by grouping tokens per expert)
# Assume each active expert's weights are read once per run.
weight_bytes_per_expert = (
2 * intermediate_size * hidden_size * weight_bytes_per_element # gemm1
+ hidden_size * intermediate_size * weight_bytes_per_element # gemm2
)
if active_experts is not None:
num_active_experts = active_experts
else:
num_active_experts = min(num_experts, top_k * num_tokens)
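# With few tokens this caps the assumed expert-weight traffic: e.g. 4 tokens at
# top_k == 8 touch at most 32 experts even if num_experts == 256.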
weight_bytes = num_active_experts * weight_bytes_per_expert
# Output memory (typically full precision)
output_bytes = num_tokens * hidden_size * input_dtype.itemsize
total_bytes = input_bytes + weight_bytes + output_bytes
tb_per_sec = total_bytes / (time_ms * 1e-3) / 1e12 # Convert to TB/sec
return tb_per_sec
def _compute_routing(router_logits: torch.Tensor, top_k: int):
routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
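# The top-k probabilities are renormalized to sum to 1 per token,
# e.g. top-2 softmax weights (0.5, 0.3) become (0.625, 0.375).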
routing_weights = routing_weights.float()
return routing_weights, selected_experts
def _dynamic_per_tensor_fp8_quant(x: torch.Tensor):
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_max = x.abs().max().float().clamp(min=1e-6)
scale = x_max / fp8_max
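# e.g. torch.finfo(torch.float8_e4m3fn).max == 448: an input amax of 224 gives
# scale == 0.5, so values are multiplied by inv_scale == 2.0 before the cast below.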
inv_scale = 1.0 / scale
out = (x.float() * inv_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
return out, scale.view((1,))
def testTrtllmFp4BlockScaleMoe(args):
"""
Test trtllm_fp4_block_scale_moe API (TensorRT-LLM fused MoE).
This test:
1. Creates quantized FP4 weights and scales
2. Runs FP4 block scale MOE
3. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testTrtllmFp4BlockScaleMoe")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype)
# Parse configuration
num_tokens = args.num_tokens
hidden_size = args.hidden_size
intermediate_size = args.intermediate_size
num_experts = args.num_experts
top_k = args.top_k
n_group = (
args.n_group
if hasattr(args, "n_group") and args.n_group is not None and args.n_group > 0
else None
)
topk_group = (
args.topk_group
if hasattr(args, "topk_group")
and args.topk_group is not None
and args.topk_group > 0
else None
)
routed_scaling_factor = (
args.routed_scaling_factor
if hasattr(args, "routed_scaling_factor")
and args.routed_scaling_factor is not None
else None
)
local_expert_offset = args.local_expert_offset
local_num_experts = args.local_num_experts or num_experts
tile_tokens_dim = args.tile_tokens_dim
routing_method_type = args.routing_method_type
use_shuffled_weight = args.use_shuffled_weight
weight_layout = args.weight_layout
is_cuda_graph_compatible = not args.no_cuda_graph
gated_act_type = args.gated_act_type
if args.verbose >= 1:
print(
f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, "
f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}"
)
# Create test data
routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights = (
create_trtllm_moe_test_data(
num_tokens,
hidden_size,
intermediate_size,
num_experts,
routing_method_type,
args.use_routing_bias,
input_dtype,
weight_dtype,
device,
moe_kernel_type="fp4_block_scale",
)
)
# For FP4, we need to properly quantize weights and create scales
use_ue8m0 = False
# Calculate global scale factor for hidden states
hidden_states_scale_global = calculate_fp4_global_scale_factor(hidden_states)
# Quantize weights using proper FP4 quantization
gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
quant_fp4_batches_simple(gemm1_weights, num_experts, use_ue8m0, True)
)
gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
quant_fp4_batches_simple(gemm2_weights, num_experts, use_ue8m0, True)
)
# Quantize hidden states
hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quant_fp4_simple(
hidden_states, hidden_states_scale_global, use_ue8m0, True
)
# Reshape hidden states for the kernel: fp4_quantize already packed 2 FP4 values per byte,
# so reinterpret the bytes as uint8 with shape [num_tokens, hidden_size // 2]
hidden_states_fp4 = hidden_states_fp4_bytes.view(torch.uint8).reshape(
hidden_states.shape[0], hidden_states.shape[1] // 2
)
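# e.g. hidden_size == 7168 packs to 7168 // 2 == 3584 uint8 bytes per token.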
# Hidden-states scale for FP4 must be 2D: [num_tokens, hidden_size // 16]
hidden_states_scale_linear_fp4 = hidden_states_scale_fp4_bytes.view(
torch.float8_e4m3fn
)
# Ensure expected element count (one scale per 16 hidden values for NvFP4)
expected_scale_elems = (num_tokens * hidden_size) // 16
if hidden_states_scale_linear_fp4.numel() != expected_scale_elems:
if args.verbose >= 1:
print(
f"[INFO] Adjusting FP4 hidden_states_scale from {hidden_states_scale_linear_fp4.numel()} to {expected_scale_elems} elements"
)
hidden_states_scale_linear_fp4 = torch.ones(
expected_scale_elems, device=device, dtype=torch.float8_e4m3fn
)
hidden_states_scale_linear_fp4 = hidden_states_scale_linear_fp4.reshape(
num_tokens, hidden_size // 16
)
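# One float8_e4m3fn scale covers 16 hidden values, so e.g. hidden_size == 7168
# yields 7168 // 16 == 448 scales per token.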
# Prepare weights for kernel
# For FP4 weights, keep them as uint8 (packed format) - don't convert to float8_e4m3fn
gemm1_weights_fp4 = gemm1_weights_fp4_bytes.view(torch.uint8).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2
)
# Scale factors should be viewed as float8_e4m3fn
gemm1_weights_scale = gemm1_scales_fp4_bytes.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 16
)
gemm2_weights_fp4 = gemm2_weights_fp4_bytes.view(torch.uint8).reshape(
num_experts, hidden_size, intermediate_size // 2
)
gemm2_weights_scale = gemm2_scales_fp4_bytes.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 16
)
# Optional parameters for FP4 (using None for simplicity in benchmarking)
gemm1_bias = None
gemm1_alpha = None
gemm1_beta = None
gemm1_clamp_limit = None
gemm2_bias = None
# Create scale scalars (simplified; in practice these would be computed from activation/weight statistics)
output1_scale_scalar = torch.ones(
local_num_experts, device=device, dtype=torch.float32
)
output1_scale_gate_scalar = torch.ones(
local_num_experts, device=device, dtype=torch.float32
)
output2_scale_scalar = torch.ones(
local_num_experts, device=device, dtype=torch.float32
)
if args.verbose >= 2:
print(f"[VVERBOSE] routing_logits.shape = {routing_logits.shape}")
print(f"[VVERBOSE] hidden_states.shape = {hidden_states.shape}")
print(f"[VVERBOSE] gemm1_weights_fp4.shape = {gemm1_weights_fp4.shape}")
print(f"[VVERBOSE] gemm2_weights_fp4.shape = {gemm2_weights_fp4.shape}")
def run_fp4_moe():
return trtllm_fp4_block_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4,
gemm1_weights=gemm1_weights_fp4,
gemm1_weights_scale=gemm1_weights_scale,
gemm1_bias=gemm1_bias,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
gemm2_weights=gemm2_weights_fp4,
gemm2_weights_scale=gemm2_weights_scale,
gemm2_bias=gemm2_bias,
output1_scale_scalar=output1_scale_scalar,
output1_scale_gate_scalar=output1_scale_gate_scalar,
output2_scale_scalar=output2_scale_scalar,
num_experts=num_experts,
top_k=top_k,
n_group=n_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
tile_tokens_dim=tile_tokens_dim,
routing_method_type=routing_method_type,
gated_act_type=gated_act_type,
do_finalize=True,
)
backend = "trtllm"
# Optional autotune warmup (supported for the FP4 TRT-LLM fused MoE)
if getattr(args, "autotune", False):
warmup_iters = (
args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
)
backend = "trtllm_autotune"
if args.verbose >= 1:
print(
f"[INFO] Autotune warmup for FP4 block scale MoE: {warmup_iters} iters"
)
with autotune(True):
for _ in range(warmup_iters):
run_fp4_moe()
# Benchmark timing
if is_cuda_graph_compatible:
times = bench_gpu_time_with_cudagraph(
fn=run_fp4_moe,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
else:
times = bench_gpu_time(
fn=run_fp4_moe,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
# Compute performance metrics
median_time = np.median(times)
std_time = np.std(times)
tflops = calculate_moe_tflops(
num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time
)
tb_per_sec = calculate_moe_bandwidth(
num_tokens,
hidden_size,
intermediate_size,
num_experts,
top_k,
median_time,
input_dtype,
weight_dtype,
input_format="fp4",
weight_format="fp4",
routing_logits_dtype=routing_logits.dtype,
)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
res = []
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["num_tokens"] = num_tokens
cur_res["hidden_size"] = hidden_size
cur_res["intermediate_size"] = intermediate_size
cur_res["num_experts"] = num_experts
cur_res["top_k"] = top_k
cur_res["n_group"] = n_group
cur_res["topk_group"] = topk_group
cur_res["routed_scaling_factor"] = routed_scaling_factor
cur_res["local_expert_offset"] = local_expert_offset
cur_res["local_num_experts"] = local_num_experts
cur_res["tile_tokens_dim"] = tile_tokens_dim
cur_res["routing_method"] = args.routing_method
cur_res["use_shuffled_weight"] = use_shuffled_weight
cur_res["weight_layout"] = weight_layout
cur_res["use_routing_bias"] = args.use_routing_bias
cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
cur_res["gated_act"] = args.gated_act
res.append(cur_res)
return res
def testCutlassFusedMoe(args):
"""
Benchmark cutlass_fused_moe (CUTLASS MoE) with variants mirroring the tests in tests/test_trtllm_cutlass_fused_moe.py.
Variants:
- base: no quantization
- fp8: per-tensor fp8 for weights and activation scale
- nvfp4: FP4 block-scale weights, optional quantized input
Supports TP/EP via tp_size/tp_rank and ep_size/ep_rank.
"""
if args.verbose >= 1:
print("[INFO] Running testCutlassFusedMoe")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
# Shapes
num_tokens = args.num_tokens
hidden_size = args.hidden_size
intermediate_size = args.intermediate_size
num_experts = args.num_experts
top_k = args.top_k
tp_size = getattr(args, "tp_size", 1)
tp_rank = getattr(args, "tp_rank", 0)
ep_size = getattr(args, "ep_size", 1)
ep_rank = getattr(args, "ep_rank", 0)
is_cuda_graph_compatible = not args.no_cuda_graph
# Create base tensors
torch.manual_seed(args.random_seed)
x = torch.randn(num_tokens, hidden_size, dtype=input_dtype, device=device)
w31_weight = (
torch.randn(
num_experts,
2 * intermediate_size,
hidden_size,
dtype=input_dtype,
device=device,
)
/ 10
)
w2_weight = (
torch.randn(
num_experts,
hidden_size,
intermediate_size,
dtype=input_dtype,
device=device,
)
/ 10
)
# Routing
router_logits = torch.randn(
num_tokens, num_experts, dtype=input_dtype, device=device
)
routing_weights, selected_experts = _compute_routing(router_logits, top_k)
if args.verbose >= 2:
print(f"[VVERBOSE] x.shape = {x.shape}")
print(f"[VVERBOSE] w31_weight.shape = {w31_weight.shape}")
print(f"[VVERBOSE] w2_weight.shape = {w2_weight.shape}")
# Build local weight shards per EP/TP rank, as the tests do
experts_per_rank = num_experts // max(ep_size, 1)
expert_start = ep_rank * experts_per_rank
expert_end = expert_start + experts_per_rank
w31_ep = w31_weight[expert_start:expert_end, :]
w2_ep = w2_weight[expert_start:expert_end, :]
def build_tp_shards(w31_ep_tensor: torch.Tensor, w2_ep_tensor: torch.Tensor):
if tp_size <= 1:
return w31_ep_tensor, w2_ep_tensor
# Split w31 into w3 and w1 along intermediate dim
w3_weight, w1_weight = torch.chunk(w31_ep_tensor, 2, dim=1)
shard = intermediate_size // tp_size
start = tp_rank * shard
end = start + shard
w3_local = w3_weight[:, start:end, :]
w1_local = w1_weight[:, start:end, :]
w31_local = torch.cat([w3_local, w1_local], dim=1)
w2_local = w2_ep_tensor[:, :, start:end]
return w31_local.contiguous(), w2_local.contiguous()
w31_local, w2_local = build_tp_shards(w31_ep, w2_ep)
# Prepare variant-specific inputs (outside of the timed/captured region)
variant = getattr(args, "cutlass_variant", "base")
out = torch.empty_like(x)
if variant == "base":
def run_cutlass():
return cutlass_fused_moe(
x,
selected_experts.to(torch.int),
routing_weights,
w31_local,
w2_local,
input_dtype,
tp_size=tp_size,
tp_rank=tp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
quant_scales=None,
output=out,
)
elif variant == "fp8":
# Per-tensor FP8 for weights and activation scale
w31_weight_fp8 = torch.empty_like(w31_local, dtype=torch.float8_e4m3fn)
w2_weight_fp8 = torch.empty_like(w2_local, dtype=torch.float8_e4m3fn)
local_num_experts = w31_local.shape[0]
w31_scales = torch.empty(local_num_experts, 2, dtype=input_dtype, device=device)
w2_scales = torch.empty(local_num_experts, 1, dtype=input_dtype, device=device)
# Quantize weights per expert
for expert_id in range(local_num_experts):
w31_expert = w31_local[expert_id]
w2_expert = w2_local[expert_id]
w31_q, s31 = _dynamic_per_tensor_fp8_quant(w31_expert)
w2_q, s2 = _dynamic_per_tensor_fp8_quant(w2_expert)
w31_weight_fp8[expert_id].copy_(w31_q)
w2_weight_fp8[expert_id].copy_(w2_q)
# Store the same scalar twice to mimic test layout (avoid torch.tensor())
w31_scales[expert_id, 0] = s31.to(dtype=input_dtype, device=device)
w31_scales[expert_id, 1] = s31.to(dtype=input_dtype, device=device)
w2_scales[expert_id, 0] = s2.to(dtype=input_dtype, device=device)
x_quant, hidden_states_scale = _dynamic_per_tensor_fp8_quant(x)
hidden_states_scale_scalar = hidden_states_scale[0].to(device)
# Note: follow the tests' quant_scales format:
# [w1_scales * hidden_states_scale, 1.0, 1.0 * w2_scales, hidden_states_scale]
w1_scales = w31_scales[:, 1]
one_const = torch.ones((), device=device)
quant_scales = [
(w1_scales * hidden_states_scale_scalar).float().squeeze(),
one_const,
w2_scales.squeeze().float(),
hidden_states_scale_scalar,
]
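# Illustrative numbers: a per-expert w1 scale of 0.01 and a hidden_states_scale of 0.02
# fuse into a first-entry value of 0.01 * 0.02 == 2e-4 for that expert.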
def run_cutlass():
return cutlass_fused_moe(
x_quant,
selected_experts.to(torch.int),
routing_weights,
w31_weight_fp8,
w2_weight_fp8,
input_dtype,
tp_size=tp_size,
tp_rank=tp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
quant_scales=quant_scales,
output=out,
)
elif variant == "nvfp4":
# NVFP4: FP4 block-scale weights, optional quantized input
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def round_up(x_val, y):
return (x_val + y - 1) // y * y
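# round_up pads block-scale dims to hardware-friendly multiples,
# e.g. round_up(192, 128) == 256 and round_up(7168 // 16, 4) == 448.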
e = w31_local.shape[0]
n = w2_local.shape[2] # local intermediate size after TP
k = hidden_size
quant_blocksize = 16
# Weight quantization buffers
w1_q = torch.empty((e, 2 * n, k // 2), device=device, dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device=device, dtype=torch.uint8)
w1_blockscale = torch.empty(
(e, round_up(2 * n, 128), round_up(k // quant_blocksize, 4)),
device=device,
dtype=torch.float8_e4m3fn,
)
w2_blockscale = torch.empty(
(e, round_up(k, 128), round_up(n // quant_blocksize, 4)),
device=device,
dtype=torch.float8_e4m3fn,
)
w1_gs = torch.empty((e,), device=device, dtype=torch.float32)
w2_gs = torch.empty((e,), device=device, dtype=torch.float32)
# Quantize from local shards
for expert in range(e):
w1_src = w31_local[expert]
# w31 layout is [2n, k]; w2 layout is [k, n]
w2_src = w2_local[expert].contiguous() # [hidden_size, n]
w1_amax = torch.abs(w1_src).max().to(torch.float32)
w2_amax = torch.abs(w2_src).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1_src, w1_gs[expert])
w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2_src, w2_gs[expert])
a1_gs = torch.ones((), device=device, dtype=torch.float32)
a2_gs = torch.ones((), device=device, dtype=torch.float32)
hidden_states = x
input_sf = None
if getattr(args, "quantized_input", False):
hidden_states, input_sf = fp4_quantize(x, a1_gs)
quant_scales = [
a1_gs,
w1_blockscale.view(torch.int32),
1.0 / (a1_gs * w1_gs),
a2_gs,
w2_blockscale.view(torch.int32),
1.0 / (a2_gs * w2_gs),
]
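# The 1.0 / (a_gs * w_gs) entries are per-expert dequant factors meant to undo the
# global scales applied by fp4_quantize above; a1_gs/a2_gs are fixed to 1.0 here,
# so they reduce to 1.0 / w1_gs and 1.0 / w2_gs.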
def run_cutlass():
return cutlass_fused_moe(
hidden_states,
selected_experts.to(torch.int),
routing_weights,
w1_q.contiguous().view(torch.long),
w2_q.contiguous().view(torch.long),
input_dtype,
tp_size=tp_size,
tp_rank=tp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
quant_scales=quant_scales,
input_sf=input_sf,
output=out,
)
else:
raise ValueError(f"Unknown cutlass_variant: {variant}")
backend = "cutlass"
# Optional autotune warmup (supported for CUTLASS fused MoE)
if getattr(args, "autotune", False):
warmup_iters = (
args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
)
backend = "cutlass_autotune"
if args.verbose >= 1:
print(f"[INFO] Autotune warmup for CUTLASS fused MoE: {warmup_iters} iters")
with autotune(True):
for _ in range(warmup_iters):
run_cutlass()
# Measure
if is_cuda_graph_compatible:
times = bench_gpu_time_with_cudagraph(
fn=run_cutlass,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
else:
times = bench_gpu_time(
fn=run_cutlass,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
median_time = np.median(times)
std_time = np.std(times)
tflops = calculate_moe_tflops(
num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time
)
tb_per_sec = calculate_moe_bandwidth(
num_tokens,
hidden_size,
intermediate_size,
num_experts,
top_k,
median_time,
input_dtype,
input_dtype,
input_format=(
"fp8"
if variant == "fp8"
else (
"fp4"
if (variant == "nvfp4" and getattr(args, "quantized_input", False))
else None
)
),
weight_format=(
"fp8" if variant == "fp8" else ("fp4" if variant == "nvfp4" else None)
),
routing_logits_dtype=router_logits.dtype,
active_experts=int(selected_experts.unique().numel()),
)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
res = []
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["num_tokens"] = num_tokens
cur_res["hidden_size"] = hidden_size
cur_res["intermediate_size"] = intermediate_size
cur_res["num_experts"] = num_experts
cur_res["top_k"] = top_k
# Routing method/weight layout not applicable; leave defaults
cur_res["use_shuffled_weight"] = False
cur_res["weight_layout"] = 0
cur_res["use_routing_scales_on_input"] = False
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = input_dtype
# CUTLASS fused MoE specific
cur_res["cutlass_variant"] = variant
cur_res["quantized_input"] = args.quantized_input
cur_res["tp_size"] = tp_size
cur_res["tp_rank"] = tp_rank
cur_res["ep_size"] = ep_size
cur_res["ep_rank"] = ep_rank
res.append(cur_res)
return res
def testTrtllmFp8BlockScaleMoe(args):
"""
Test trtllm_fp8_block_scale_moe API (TensorRT-LLM fused MoE).
This test:
1. Creates quantized FP8 weights and block scales
2. Runs FP8 block scale MOE
3. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testTrtllmFp8BlockScaleMoe")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype)
# Parse configuration
num_tokens = args.num_tokens
hidden_size = args.hidden_size
intermediate_size = args.intermediate_size
num_experts = args.num_experts
top_k = args.top_k
n_group = (
args.n_group
if hasattr(args, "n_group") and args.n_group is not None and args.n_group > 0
else None
)
topk_group = (
args.topk_group
if hasattr(args, "topk_group")
and args.topk_group is not None
and args.topk_group > 0
else None
)
routed_scaling_factor = (
args.routed_scaling_factor
if hasattr(args, "routed_scaling_factor")
and args.routed_scaling_factor is not None
else None
)
local_expert_offset = args.local_expert_offset
local_num_experts = args.local_num_experts or num_experts
tile_tokens_dim = args.tile_tokens_dim
routing_method_type = args.routing_method_type
use_shuffled_weight = args.use_shuffled_weight
weight_layout = args.weight_layout
is_cuda_graph_compatible = not args.no_cuda_graph
if args.verbose >= 1:
print(
f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, "
f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}"
)
# Create test data
routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights = (
create_trtllm_moe_test_data(
num_tokens,
hidden_size,
intermediate_size,
num_experts,
routing_method_type,
args.use_routing_bias,
input_dtype,
weight_dtype,
device,
moe_kernel_type="fp8_block_scale",
)
)
# For FP8 block scale, create quantized weights and block scales
# Quantize to FP8
gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn)
gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn)
# Optionally shuffle weights and convert to BlockMajorK layout to match kernel expectation
if use_shuffled_weight:
# This tile size follows test implementations
epilogue_tile_m = 64
gemm1_weights_fp8_shuffled = []
gemm2_weights_fp8_shuffled = []
for i in range(num_experts):
tmp_w1 = shuffle_matrix_a(
gemm1_weights_fp8[i].view(torch.uint8), epilogue_tile_m
)
tmp_w2 = shuffle_matrix_a(
gemm2_weights_fp8[i].view(torch.uint8), epilogue_tile_m
)
if weight_layout == WeightLayout.BlockMajorK:
block_k = 128
tmp_w1 = convert_to_block_layout(tmp_w1, block_k)
tmp_w2 = convert_to_block_layout(tmp_w2, block_k)
gemm1_weights_fp8_shuffled.append(tmp_w1)
gemm2_weights_fp8_shuffled.append(tmp_w2)
kernel_gemm1_weights = torch.stack(gemm1_weights_fp8_shuffled).view(
torch.float8_e4m3fn
)
kernel_gemm2_weights = torch.stack(gemm2_weights_fp8_shuffled).view(
torch.float8_e4m3fn
)
else:
kernel_gemm1_weights = gemm1_weights_fp8
kernel_gemm2_weights = gemm2_weights_fp8
# Create block scale tensors for hidden states and weights (use float32 for scales)
# TensorRT-LLM FP8 block-scale expects hidden_states_scale shape [hidden_size // 128, num_tokens]
hidden_states_scale = 2.0 * torch.ones(
(hidden_size // 128, num_tokens), device=device, dtype=torch.float32
)
gemm1_weights_scale = 2.0 * torch.ones(
(num_experts, 2 * intermediate_size // 128, hidden_size // 128),
device=device,
dtype=torch.float32,
)
gemm2_weights_scale = 2.0 * torch.ones(
(num_experts, hidden_size // 128, intermediate_size // 128),
device=device,
dtype=torch.float32,
)
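# Example shapes with hidden_size == 7168 and intermediate_size == 2048 (128x128 blocks):
# hidden_states_scale: [56, num_tokens], gemm1_weights_scale: [num_experts, 32, 56],
# gemm2_weights_scale: [num_experts, 56, 16].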
if args.verbose >= 2:
print(f"[VVERBOSE] routing_logits.shape = {routing_logits.shape}")
print(f"[VVERBOSE] hidden_states.shape = {hidden_states.shape}")
print(f"[VVERBOSE] gemm1_weights_fp8.shape = {gemm1_weights_fp8.shape}")
print(f"[VVERBOSE] gemm2_weights_fp8.shape = {gemm2_weights_fp8.shape}")
# Match test heuristic for tile_tokens_dim when using BlockMajorK
if use_shuffled_weight and weight_layout == WeightLayout.BlockMajorK:
def _next_pow2(x: int) -> int:
x = max(1, x)
x -= 1
x |= x >> 1
x |= x >> 2
x |= x >> 4
x |= x >> 8
x |= x >> 16
return x + 1
tokens_per_expert = max(1, (num_tokens * top_k) // max(local_num_experts, 1))
suggested_tile = min(max(_next_pow2(tokens_per_expert), 8), 64)
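# e.g. 3072 tokens at top_k == 8 over 256 local experts give tokens_per_expert == 96,
# _next_pow2(96) == 128, clamped to the [8, 64] range -> suggested_tile == 64.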
if suggested_tile != tile_tokens_dim and args.verbose >= 1:
print(
f"[INFO] Overriding tile_tokens_dim {tile_tokens_dim} -> {suggested_tile} for BlockMajorK"
)
tile_tokens_dim = suggested_tile
def run_fp8_block_moe():
# Quantize hidden states to FP8 for the block-scale MoE (note: this cast runs inside the timed region)
hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn)
# Note: FP8 block scale MOE expects int64_t for n_group/topk_group, not Optional[int64_t]
# So we convert None to 0 to indicate "no groups" mode
return trtllm_fp8_block_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=hidden_states_fp8,
hidden_states_scale=hidden_states_scale,
gemm1_weights=kernel_gemm1_weights,
gemm1_weights_scale=gemm1_weights_scale,
gemm2_weights=kernel_gemm2_weights,
gemm2_weights_scale=gemm2_weights_scale,
num_experts=num_experts,
top_k=top_k,
n_group=n_group if n_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
tile_tokens_dim=tile_tokens_dim,
routing_method_type=routing_method_type,
use_shuffled_weight=use_shuffled_weight,
weight_layout=weight_layout,
enable_pdl=True,
)
# Benchmark timing
if is_cuda_graph_compatible:
times = bench_gpu_time_with_cudagraph(
fn=run_fp8_block_moe,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
else:
times = bench_gpu_time(
fn=run_fp8_block_moe,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
# Compute performance metrics
median_time = np.median(times)
std_time = np.std(times)
tflops = calculate_moe_tflops(
num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time
)
tb_per_sec = calculate_moe_bandwidth(
num_tokens,
hidden_size,
intermediate_size,
num_experts,
top_k,
median_time,
input_dtype,
weight_dtype,
input_format="fp8",
weight_format="fp8",
routing_logits_dtype=routing_logits.dtype,
)
backend = "trtllm"
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
res = []
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["num_tokens"] = num_tokens
cur_res["hidden_size"] = hidden_size
cur_res["intermediate_size"] = intermediate_size
cur_res["num_experts"] = num_experts
cur_res["top_k"] = top_k
cur_res["n_group"] = n_group
cur_res["topk_group"] = topk_group
cur_res["routed_scaling_factor"] = routed_scaling_factor
cur_res["local_expert_offset"] = local_expert_offset
cur_res["local_num_experts"] = local_num_experts
cur_res["tile_tokens_dim"] = tile_tokens_dim
cur_res["routing_method"] = args.routing_method
cur_res["use_shuffled_weight"] = use_shuffled_weight
cur_res["weight_layout"] = weight_layout
cur_res["use_routing_bias"] = args.use_routing_bias
cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
res.append(cur_res)
return res
def testTrtllmFp8PerTensorScaleMoe(args):
"""
Test trtllm_fp8_per_tensor_scale_moe API (TensorRT-LLM fused MoE).
This test:
1. Creates quantized FP8 weights and per-tensor scales
2. Runs FP8 per-tensor scale MOE
3. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testTrtllmFp8PerTensorScaleMoe")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype)
# Parse configuration
num_tokens = args.num_tokens
hidden_size = args.hidden_size
intermediate_size = args.intermediate_size
num_experts = args.num_experts
top_k = args.top_k
n_group = (
args.n_group
if hasattr(args, "n_group") and args.n_group is not None and args.n_group > 0
else None
)
topk_group = (
args.topk_group
if hasattr(args, "topk_group")
and args.topk_group is not None
and args.topk_group > 0
else None
)
routed_scaling_factor = (
args.routed_scaling_factor
if hasattr(args, "routed_scaling_factor")
and args.routed_scaling_factor is not None
else None
)
local_expert_offset = args.local_expert_offset
local_num_experts = args.local_num_experts or num_experts
tile_tokens_dim = args.tile_tokens_dim
routing_method_type = args.routing_method_type
use_routing_scales_on_input = args.use_routing_scales_on_input
is_cuda_graph_compatible = not args.no_cuda_graph
if args.verbose >= 1:
print(
f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, "
f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}"
)
# Create test data
routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights = (
create_trtllm_moe_test_data(
num_tokens,
hidden_size,
intermediate_size,
num_experts,
routing_method_type,
args.use_routing_bias,
input_dtype,
weight_dtype,
device,
moe_kernel_type="fp8_per_tensor",
)
)
# For FP8 per-tensor scale, create quantized weights and per-tensor scales
# Quantize to FP8
gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn)
gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn)
# Quantize hidden states to FP8 for per-tensor scale
hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn)
# Create per-tensor scale scalars
output1_scales_scalar = torch.ones(
local_num_experts, device=device, dtype=torch.float32
)
output1_scales_gate_scalar = torch.ones(
local_num_experts, device=device, dtype=torch.float32
)
output2_scales_scalar = torch.ones(
local_num_experts, device=device, dtype=torch.float32
)
if args.verbose >= 2:
print(f"[VVERBOSE] routing_logits.shape = {routing_logits.shape}")
print(f"[VVERBOSE] hidden_states.shape = {hidden_states.shape}")
print(f"[VVERBOSE] gemm1_weights_fp8.shape = {gemm1_weights_fp8.shape}")
print(f"[VVERBOSE] gemm2_weights_fp8.shape = {gemm2_weights_fp8.shape}")
def run_fp8_per_tensor_moe():
# Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t]
# So we convert None to 0 to indicate "no groups" mode
return trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=hidden_states_fp8,
gemm1_weights=gemm1_weights_fp8,
output1_scales_scalar=output1_scales_scalar,
output1_scales_gate_scalar=output1_scales_gate_scalar,
gemm2_weights=gemm2_weights_fp8,
output2_scales_scalar=output2_scales_scalar,
num_experts=num_experts,
top_k=top_k,
n_group=n_group if n_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
tile_tokens_dim=tile_tokens_dim,
routing_method_type=routing_method_type,
)
# Benchmark timing
if is_cuda_graph_compatible:
times = bench_gpu_time_with_cudagraph(
fn=run_fp8_per_tensor_moe,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
else:
times = bench_gpu_time(
fn=run_fp8_per_tensor_moe,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
# Compute performance metrics
median_time = np.median(times)
std_time = np.std(times)
tflops = calculate_moe_tflops(
num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time
)
tb_per_sec = calculate_moe_bandwidth(
num_tokens,
hidden_size,
intermediate_size,
num_experts,
top_k,
median_time,
input_dtype,
weight_dtype,
input_format="fp8",
weight_format="fp8",
routing_logits_dtype=routing_logits.dtype,
)
backend = "trtllm"
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
res = []
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["num_tokens"] = num_tokens
cur_res["hidden_size"] = hidden_size
cur_res["intermediate_size"] = intermediate_size
cur_res["num_experts"] = num_experts
cur_res["top_k"] = top_k
cur_res["n_group"] = n_group
cur_res["topk_group"] = topk_group
cur_res["routed_scaling_factor"] = routed_scaling_factor
cur_res["local_expert_offset"] = local_expert_offset
cur_res["local_num_experts"] = local_num_experts
cur_res["tile_tokens_dim"] = tile_tokens_dim
cur_res["routing_method"] = args.routing_method
cur_res["use_routing_bias"] = args.use_routing_bias
cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
res.append(cur_res)
return res