"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pytest
import torch
from torch.nn import functional as F
import flashinfer.fused_moe as fused_moe
from flashinfer import (
fp4_quantize,
mxfp4_dequantize,
mxfp4_quantize,
mxfp8_dequantize_host,
mxfp8_quantize,
mxfp4_dequantize_host,
)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
FP8_DTYPE = torch.float8_e4m3fn
def dynamic_per_tensor_fp8_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
fp8_traits_max = FLOAT8_E4M3_MAX
fp8_traits_min = -FLOAT8_E4M3_MAX
fp8_max = torch.tensor(fp8_traits_max).float()
one = torch.tensor(1.0).float()
x_max = x.abs().max().float()
scale = x_max / fp8_max
iscale = one / scale
out = (x.float() * iscale).clamp(fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
return out, scale.view((1,))
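# Illustrative sketch (hypothetical helper, not part of the original tests): a tiny
# worked example of the per-tensor FP8 quantization above.  With an absolute
# maximum of 8.96 the scale is 8.96 / 448 ~= 0.02, and the scaled values map onto
# float8_e4m3fn representable points up to fp8 rounding.
def _example_dynamic_per_tensor_fp8_quant():
    x = torch.tensor([[-8.96, 4.48, 0.0, 2.24]])
    x_q, scale = dynamic_per_tensor_fp8_quant(x)
    # scale ~= tensor([0.02]); x_q.float() * scale recovers x up to fp8 rounding.
    return x_q, scale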
def gen_tensor(shape, dtype, stype=None, scale=1.0):
x = torch.randn(*shape, dtype=dtype).cuda() * scale
return x.to(stype) if stype else x
def cast_to_representable(x):
x_q, x_scale = dynamic_per_tensor_fp8_quant(x)
x = x_q.to(x.dtype) * x_scale.to(x.dtype)
return x
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
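# Layout note for convert_swizzled_to_linear: the swizzled scale factors are stored
# in 128-row x 4-column tiles of fp8 values, with each tile viewed as (32, 4, 4).
# For m = 128, k = 128, block_size = 16 this gives m_tiles = 1 and k_tiles = 2, so
# the (1, 1, 2, 32, 4, 4) view is permuted and reshaped back to a (128, 8) linear
# scale matrix, i.e. one scale per 16-element block of each row.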
def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype=dtype)
def break_fp4_bytes(a, dtype):
assert a.dtype == torch.uint8
m, n = a.shape
# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles
# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()
# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
# Device-aware lookup and sign application
kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)
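# Worked example for break_fp4_bytes: the packed byte 0x2E holds low nibble 0xE
# (sign bit set, magnitude index 6 -> -4.0) and high nibble 0x2 (magnitude index 2
# -> +1.0), so it decodes to the value pair (-4.0, 1.0) via the E2M1 table above.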
def compute_routing(
router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute routing weights and selected experts from router logits.
Args:
router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts]
top_k (int): Number of experts to route to per token
Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- routing_weights: Expert weights of shape [batch_size, top_k]
- selected_experts: Expert indices of shape [batch_size, top_k]
"""
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.float()
return routing_weights, selected_experts
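# Illustrative sketch (hypothetical helper, not part of the original tests): with
# logits [2.0, 1.0, 0.0] the softmax is ~[0.665, 0.245, 0.090]; the top-2 weights
# are renormalized to ~[0.731, 0.269] for experts [0, 1].
def _example_compute_routing():
    router_logits = torch.tensor([[2.0, 1.0, 0.0]])
    weights, experts = compute_routing(router_logits, top_k=2)
    # weights ~= [[0.731, 0.269]], experts == [[0, 1]]
    return weights, experts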
def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# score = torch.softmax(score, dim=-1, dtype=torch.float32)
# topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# w1 needs to be swapped in terms of gate and up_proj
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
m = w1[i].shape[0]
assert m % 2 == 0
w1_expert, w3_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
inter_gs = torch.tensor(1.0).cuda()
inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
inter = dequantize_nvfp4_to_dtype(
inter_q,
inter_blockscale,
inter_gs,
dtype=inter.dtype,
device=inter.device,
block_size=16,
).cuda()
out[mask] = inter @ w2[i].transpose(0, 1)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def compute_with_experts(
num_experts,
x,
w31_weight,
w2_weight,
selected_experts,
routing_weights,
alpha=None,
beta=None,
limit=None,
):
results = torch.zeros_like(x)
for expert_id in range(num_experts):
mask = selected_experts == expert_id
if not mask.sum():
continue
batch_idx, nth_expert = torch.where(mask)
w31_expert = w31_weight[expert_id] # [2 * intermediate_size, hidden_size]
w2_expert = w2_weight[expert_id] # [hidden_size, intermediate_size]
        # Split w31 into w3 (first half) and w1 (second half)
w3_expert, w1_expert = torch.chunk(w31_expert, 2, dim=0)
expert_inputs = x[batch_idx]
if alpha is not None and limit is not None and beta is not None:
# SwiGLUBias
x1 = expert_inputs @ w1_expert.t()
x1 = x1.clamp_(min=None, max=limit)
x1_scaled = x1 * torch.sigmoid(alpha * x1)
x2 = expert_inputs @ w3_expert.t()
x2 = x2.clamp_(min=-limit, max=limit) + beta
inter = x1_scaled * x2
else:
inter = F.silu(expert_inputs @ w1_expert.t()) * (
expert_inputs @ w3_expert.t()
)
output = inter @ w2_expert.t()
results[batch_idx] += routing_weights[batch_idx, nth_expert, None] * output
return results.view_as(x)
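# Note on the gated activation in compute_with_experts: when alpha, beta and limit
# are all set (the "SwiGLUBias" branch), the intermediate activation is
#     clamp(x @ w1.T, max=limit) * sigmoid(alpha * clamp(x @ w1.T, max=limit))
#         * (clamp(x @ w3.T, -limit, limit) + beta),
# otherwise it is the plain SwiGLU silu(x @ w1.T) * (x @ w3.T).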
# Test configurations
BATCH_SIZES = [
1,
]
HIDDEN_SIZES = [
128,
]
NUM_EXPERTS = [2]
TOP_K_VALUES = [2]
INTERMEDIATE_SIZES = [
128,
]
EP_NUM_EXPERTS = [8]
EP_TOP_K = [2]
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size):
# Skip invalid configurations
if top_k > num_experts:
pytest.skip(
f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
)
torch.manual_seed(42)
x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() / 5
router_logits = torch.randn(batch_size, num_experts, dtype=torch.float32).cuda()
w31_weight = (
torch.randn(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
).cuda()
/ 5
)
w2_weight = (
torch.randn(
num_experts, hidden_size, intermediate_size, dtype=torch.float16
).cuda()
/ 5
)
routing_weights, selected_experts = compute_routing(router_logits, top_k)
ref_output = compute_with_experts(
num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
)
flash_output = torch.empty_like(ref_output)
flash_output = fused_moe.cutlass_fused_moe(
x,
selected_experts.to(torch.int),
routing_weights,
w31_weight,
w2_weight,
flash_output.dtype,
output=flash_output,
quant_scales=None,
)
torch.testing.assert_close(ref_output, flash_output[0], rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("otype, wtype", [(torch.float16, torch.float8_e4m3fn)])
def test_moe_fp8(
batch_size, hidden_size, num_experts, top_k, intermediate_size, otype, wtype
):
# Skip invalid configurations
if top_k > num_experts:
pytest.skip(
f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
)
torch.manual_seed(42)
input_shape = (batch_size, hidden_size)
w31_shape = (num_experts, 2 * intermediate_size, hidden_size)
w2_shape = (num_experts, hidden_size, intermediate_size)
x = cast_to_representable(gen_tensor(input_shape, otype))
router_logits = gen_tensor((batch_size, num_experts), otype)
# Create weight tensors
w31_weight = gen_tensor(w31_shape, otype, wtype)
w2_weight = gen_tensor(w2_shape, otype, wtype)
w31_scales = torch.empty(num_experts, 2, dtype=otype).cuda()
w2_scales = torch.empty(num_experts, 1, dtype=otype).cuda()
w31_dequantized = gen_tensor(w31_shape, otype)
w2_dequantized = gen_tensor(w2_shape, otype)
for expert_id in range(num_experts):
w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1))
w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09))
w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31)
w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2)
w31_weight.data[expert_id].copy_(w31_quant)
w2_weight.data[expert_id].copy_(w2_quant)
w31_scales.data[expert_id].copy_(s31)
w2_scales.data[expert_id].copy_(s2)
w31_dequantized.data[expert_id].copy_(torch.mul(w31_quant.to(dtype=otype), s31))
w2_dequantized.data[expert_id].copy_(torch.mul(w2_quant.to(dtype=otype), s2))
routing_weights, selected_experts = compute_routing(router_logits, top_k)
ref_output = compute_with_experts(
num_experts,
x,
w31_dequantized,
w2_dequantized,
selected_experts,
routing_weights,
)
flash_output = torch.empty_like(ref_output)
    # For fp8, the hidden states are expected to be pre-quantized.
_, w1_scales = torch.chunk(w31_scales, 2, dim=-1)
x_quant, hidden_states_scale = dynamic_per_tensor_fp8_quant(x)
    hidden_states_scale = hidden_states_scale[0].clone().cuda()
quant_scales = [
torch.squeeze(w1_scales * hidden_states_scale).float(),
torch.tensor(1.0).cuda(),
torch.squeeze(1.0 * w2_scales).float(),
hidden_states_scale,
]
_ = fused_moe.cutlass_fused_moe(
x_quant,
selected_experts.to(torch.int),
routing_weights,
w31_weight,
w2_weight,
otype,
quant_scales=quant_scales,
output=flash_output,
)
torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize(
"otype, wtype",
[(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)],
)
@pytest.mark.parametrize("quantized_input", [False, True])
@pytest.mark.skipif(
torch.cuda.get_device_capability()[0] not in [10, 11, 12],
reason="NVFP4 is only supported on SM100, SM110 and SM120",
)
def test_moe_nvfp4(
batch_size,
hidden_size,
num_experts,
top_k,
intermediate_size,
otype,
wtype,
quantized_input,
):
# Skip invalid configurations
if top_k > num_experts:
pytest.skip(
f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
)
torch.manual_seed(42)
quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
e = num_experts
m = batch_size
n = intermediate_size
k = hidden_size
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()
sf_w1_2n = round_up(2 * n, 128)
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty(
(e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
)
w1_blockscale_cutlass = torch.empty(
(e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
)
w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10
sf_w2_k = round_up(k, 128)
sf_w2_n = round_up(n // quant_blocksize, 4)
w2_blockscale = torch.empty(
(e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
)
w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_amax = torch.abs(w1).max().to(torch.float32)
w2_amax = torch.abs(w2).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert])
w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize(
w1_cutlass[expert], w1_gs[expert]
)
w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert])
x = torch.randn(m, k, dtype=otype).cuda()
a1_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(x).max().to(
torch.float32
).cuda()
    # The global activation scales are fixed to 1.0 below; the a1_gs value
    # computed above from the input amax is left unused.
    a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
router_logits = torch.randn(m, e, dtype=otype).cuda()
routing_weights, selected_experts = compute_routing(router_logits, top_k)
# quant_scales format
# auto const fc1_act_global = quant_scales.value()[0];
# auto const fc1_weight_block = quant_scales.value()[1];
# auto const fc1_global = quant_scales.value()[2];
# auto const fc2_act_global = quant_scales.value()[3];
# auto const fc2_weight_block = quant_scales.value()[4];
# auto const fc2_global = quant_scales.value()[5];
flash_output = torch.zeros_like(x)
quant_scales = [
a1_gs,
w1_blockscale.view(torch.int32),
1.0 / (a1_gs * w1_gs),
a2_gs,
w2_blockscale.view(torch.int32),
1.0 / (a2_gs * w2_gs),
]
hidden_states = x
input_sf = None
if quantized_input:
hidden_states, input_sf = fp4_quantize(x, a1_gs)
_ = fused_moe.cutlass_fused_moe(
hidden_states,
selected_experts.to(torch.int),
routing_weights,
w1_q.contiguous().view(torch.long),
w2_q.contiguous().view(torch.long),
otype,
quant_scales=quant_scales,
input_sf=input_sf,
output=flash_output,
)
# Ref check
a_fp4, a_scale_interleaved = fp4_quantize(x, a1_gs)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a1_gs,
dtype=otype,
device=x.device,
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=otype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=otype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=w1.dtype,
device=w1.device,
block_size=quant_blocksize,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=w2.dtype,
device=w2.device,
block_size=quant_blocksize,
)
w1_q_cutlass = torch.cat((w1_q[:, n:, :], w1_q[:, :n, :]), dim=1).contiguous()
w1_blockscale_cutlass = torch.cat(
(w1_blockscale[:, n:, :], w1_blockscale[:, :n, :]), dim=1
).contiguous()
ref_output = torch_moe_nvfp4(
a_in_dtype, w1_d, w2_d, top_k, routing_weights, selected_experts
)
torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS)
@pytest.mark.parametrize("top_k", EP_TOP_K)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_expert_parallel(
batch_size, hidden_size, num_experts, top_k, intermediate_size
):
"""
Test expert parallelism with X GPUs and Y experts.
Each GPU handles one expert and results are reduced.
Args:
batch_size: Batch size for the input
hidden_size: Hidden dimension size
num_experts: Number of experts (must be 2 for this test)
top_k: Number of experts to route to per token
intermediate_size: Intermediate dimension size
activation: Activation function type
"""
# This test is specifically for 2 GPUs and 2 experts
# GPU 0 (ep_rank=0) handles expert 0
# GPU 1 (ep_rank=1) handles expert 1
ep_size = num_experts // 2
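    # With the default EP_NUM_EXPERTS = [8] this gives ep_size = 4 and each
    # simulated rank owns num_experts // ep_size = 2 experts:
    # rank 0 -> experts 0-1, rank 1 -> 2-3, rank 2 -> 4-5, rank 3 -> 6-7.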
torch.manual_seed(42)
# Create input tensors
x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()
# Create weight tensors - each GPU will have one expert
w31_weight = (
torch.randn(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
).cuda()
/ 10
)
w2_weight = (
torch.randn(
num_experts, hidden_size, intermediate_size, dtype=torch.float16
).cuda()
/ 10
)
selected_experts = torch.stack(
[torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
).cuda()
routing_weights = torch.randn((batch_size, top_k)).cuda()
routing_weights = F.softmax(routing_weights, dim=1)
ref_output = compute_with_experts(
num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
)
outputs = []
flash_output = torch.zeros_like(ref_output)
for ep_rank in range(ep_size):
# Create output tensor for this GPU
out_hidden_states_local = torch.zeros_like(x)
# Compute expert start and end positions for this rank
        # Each rank owns a contiguous block of num_experts // ep_size experts
        experts_per_rank = num_experts // ep_size
        expert_start = ep_rank * experts_per_rank
        expert_end = expert_start + experts_per_rank
w31_weight_local = w31_weight[
expert_start:expert_end, :
] # Get only the experts for this rank
w2_weight_local = w2_weight[
expert_start:expert_end, :
] # Get only the experts for this rank
_ = fused_moe.cutlass_fused_moe(
x.contiguous(),
selected_experts.to(torch.int),
routing_weights,
w31_weight_local.contiguous(),
w2_weight_local.contiguous(),
x.dtype,
ep_size=ep_size,
ep_rank=ep_rank,
quant_scales=None,
output=out_hidden_states_local,
)
outputs.append(out_hidden_states_local)
# Reduce results from all GPUs
for ep_rank in range(ep_size):
        flash_output += outputs[ep_rank]  # [batch_size, hidden_size]
torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)
TP_SIZES = [2, 4]
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("tp_size", TP_SIZES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_tensor_parallel(
batch_size, hidden_size, num_experts, tp_size, intermediate_size
):
"""
Test tensor parallelism with:
- w31 sharded along second dimension (non-contracting)
- w2 sharded along third dimension (contracting)
- All-reduce to sum partial results
Args:
batch_size: Batch size for the input
hidden_size: Hidden dimension size
num_experts: Number of experts
top_k: Number of experts to route to per token
intermediate_size: Intermediate dimension size
activation: Activation function type
"""
# Set random seed for reproducibility
torch.manual_seed(42)
top_k = 2
# Create input tensors
x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()
# Create weight tensors
w31_weight = (
torch.randn(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
).cuda()
/ 10
)
w2_weight = (
torch.randn(
num_experts, hidden_size, intermediate_size, dtype=torch.float16
).cuda()
/ 10
)
# Generate unique random expert indices for each token
selected_experts = torch.stack(
[torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
).cuda()
routing_weights = torch.randn((batch_size, top_k)).cuda()
routing_weights = F.softmax(routing_weights, dim=1)
# Run reference implementation (no parallelism)
ref_output = compute_with_experts(
num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
)
    # Simulate tensor parallelism across tp_size ranks
outputs = []
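    # Sharding example for the default INTERMEDIATE_SIZES = [128] with tp_size = 2:
    # each simulated rank keeps rows [0:64) or [64:128) of w3 and w1 (the
    # non-contracting dimension) and the matching columns of w2 (the contracting
    # dimension), so summing the per-rank outputs reproduces the full GEMMs.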
for tp_rank in range(tp_size):
# Create output tensor for this GPU
out_hidden_states_local = torch.zeros_like(x)
# Shard w31 along second dimension (intermediate_size)
# First split w31 into w3 and w1
w3_weight, w1_weight = torch.chunk(
w31_weight, 2, dim=1
) # [num_experts, intermediate_size, hidden_size] each
# Shard w3 and w1 separately
w3_shard_size = intermediate_size // tp_size
w3_start = tp_rank * w3_shard_size
w3_end = w3_start + w3_shard_size
w3_weight_local = w3_weight[:, w3_start:w3_end, :]
w1_shard_size = intermediate_size // tp_size
w1_start = tp_rank * w1_shard_size
w1_end = w1_start + w1_shard_size
w1_weight_local = w1_weight[:, w1_start:w1_end, :]
# Stack the sharded weights back together
w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1)
# Shard w2 along third dimension (intermediate_size)
w2_shard_size = intermediate_size // tp_size
w2_start = tp_rank * w2_shard_size
w2_end = w2_start + w2_shard_size
w2_weight_local = w2_weight[:, :, w2_start:w2_end]
_ = fused_moe.cutlass_fused_moe(
x.contiguous(),
selected_experts.to(torch.int),
routing_weights,
w31_weight_local.contiguous(),
w2_weight_local.contiguous(),
x.dtype,
tp_size=tp_size,
tp_rank=tp_rank,
quant_scales=None,
output=out_hidden_states_local,
)
outputs.append(out_hidden_states_local)
# All-reduce to sum partial results from all GPUs
flash_output = sum(outputs)
torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS)
@pytest.mark.parametrize("top_k", EP_TOP_K)
@pytest.mark.parametrize("tp_size", TP_SIZES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_tensor_expert_parallel(
batch_size, hidden_size, num_experts, top_k, tp_size, intermediate_size
):
"""
Test combined tensor parallelism and expert parallelism:
- Expert parallelism: Distribute experts across GPUs
- Tensor parallelism: For each expert's weights:
- w31 sharded along second dimension (non-contracting)
- w2 sharded along third dimension (contracting)
- All-reduce to sum partial results
Args:
batch_size: Batch size for the input
hidden_size: Hidden dimension size
num_experts: Number of experts
tp_size: Number of GPUs for tensor parallelism
intermediate_size: Intermediate dimension size
"""
torch.manual_seed(42)
x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()
w31_weight = (
torch.randn(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
).cuda()
/ 10
)
w2_weight = (
torch.randn(
num_experts, hidden_size, intermediate_size, dtype=torch.float16
).cuda()
/ 10
)
# Generate unique random expert indices for each token
selected_experts = torch.stack(
[torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
).cuda()
routing_weights = torch.randn((batch_size, top_k)).cuda()
routing_weights = F.softmax(routing_weights, dim=1)
# Run reference implementation (no parallelism)
ref_output = compute_with_experts(
num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
)
# Simulate combined parallelism
ep_size = num_experts // 2 # Number of GPUs for expert parallelism
outputs = []
# For each expert parallel rank
for ep_rank in range(ep_size):
# Get experts for this rank
experts_per_rank = num_experts // ep_size
expert_start = ep_rank * experts_per_rank
expert_end = expert_start + experts_per_rank
# Get expert weights for this rank
w31_weight_ep = w31_weight[
expert_start:expert_end, :
] # [experts_per_rank, 2*intermediate_size, hidden_size]
w2_weight_ep = w2_weight[
expert_start:expert_end, :
] # [experts_per_rank, hidden_size, intermediate_size]
# For each tensor parallel rank
for tp_rank in range(tp_size):
# Create output tensor for this GPU
out_hidden_states_local = torch.zeros_like(x)
# Split w31 into w3 and w1
w3_weight, w1_weight = torch.chunk(w31_weight_ep, 2, dim=1)
# Shard w3 and w1 separately
w3_shard_size = intermediate_size // tp_size
w3_start = tp_rank * w3_shard_size
w3_end = w3_start + w3_shard_size
w3_weight_local = w3_weight[:, w3_start:w3_end, :]
w1_shard_size = intermediate_size // tp_size
w1_start = tp_rank * w1_shard_size
w1_end = w1_start + w1_shard_size
w1_weight_local = w1_weight[:, w1_start:w1_end, :]
# Stack the sharded weights back together
w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1)
# Shard w2 along third dimension
w2_shard_size = intermediate_size // tp_size
w2_start = tp_rank * w2_shard_size
w2_end = w2_start + w2_shard_size
w2_weight_local = w2_weight_ep[:, :, w2_start:w2_end]
# Call flashinfer implementation with both parallelisms
out_hidden_states_local = fused_moe.cutlass_fused_moe(
x.contiguous(),
selected_experts.to(torch.int),
routing_weights,
w31_weight_local.contiguous(),
w2_weight_local.contiguous(),
x.dtype,
tp_size=tp_size,
tp_rank=tp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
quant_scales=None,
)
outputs.append(out_hidden_states_local[0])
# All-reduce to sum partial results from all GPUs
flash_output = sum(outputs)
torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2)
def ceil_div(a: int, b: int) -> int:
return -(a // -b)
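# e.g. ceil_div(300, 128) == 3, since -(300 // -128) == -(-3) == 3.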
def per_block_cast_to_fp8(
x: torch.Tensor, block_size_n: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(ceil_div(m, 128) * 128, ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device,
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
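# Shape note for per_block_cast_to_fp8: a (256, 128) weight slice (the default
# per-expert w31 shape below) is already 128-aligned, so it is quantized in
# 128x128 blocks and returns an fp8 tensor of the original shape plus a (2, 1)
# tensor of per-block scales.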
def per_token_group_quant_fp8(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, (
"the last dimension of `x` cannot be divisible by `group_size`"
)
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
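# Shape note for per_token_group_quant_fp8: for x of shape (batch, 256) and
# group_size = 128, x_q keeps the (batch, 256) shape in float8_e4m3fn and x_s has
# shape (batch, 2), one scale per 128-element group of each token.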
def dequantize_block(
x_quant: torch.Tensor,
scales: torch.Tensor,
dtype: torch.dtype,
original_shape: tuple,
) -> torch.Tensor:
"""
Dequantize a block-quantized tensor.
Args:
x_quant: Quantized tensor
scales: Block scaling factors
dtype: Target dtype for dequantization
original_shape: Original shape of the tensor before padding
Returns:
torch.Tensor: Dequantized tensor
"""
# Reshape scales to match block structure
def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor:
# Move target dim to last position if not already last
if dim != -1:
a = a.transpose(dim, -1)
# Broadcast and reshape
a_broadcasted = a.unsqueeze(-1).expand(*a.shape, 128)
a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * 128)
# Move back if needed
if dim != -1:
a_reshaped = a_reshaped.transpose(dim, -1)
return a_reshaped
if x_quant.dim() == 2: # For activation tensors [batch_size, hidden_size]
batch_size, hidden_size = x_quant.shape
num_blocks = (hidden_size + 127) // 128
scales = scales.view(batch_size, num_blocks, 1).expand(-1, -1, 128)
scales = scales[:, :, : hidden_size % 128] if hidden_size % 128 != 0 else scales
else: # For weight tensors [..., in_dim, out_dim]
*_dims, in_dim, out_dim = x_quant.shape
# Transform both dimensions
scales = transform_dim(scales, -1) # Last dim
scales = transform_dim(scales, -2) # Second-to-last dim
# Handle padding
if in_dim % 128 != 0:
scales = scales[..., : in_dim % 128, :]
if out_dim % 128 != 0:
scales = scales[..., :, : out_dim % 128]
x_dequant = x_quant.to(dtype) * scales.to(dtype)
return x_dequant.view(original_shape)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.skipif(
torch.cuda.get_device_capability()[0] not in [10, 11, 12],
reason="FP8 block scaling is only supported on SM100, SM110 and SM120",
)
def test_moe_fp8_block_scaling(
batch_size, hidden_size, num_experts, top_k, intermediate_size
):
"""
Test MoE with FP8 block scaling (Deepseek style):
- Activation: 128x1 blocks
- Weights: 128x128 blocks
- Each block has its own scaling factor
Args:
batch_size: Batch size for the input
hidden_size: Hidden dimension size
num_experts: Number of experts
top_k: Number of experts to route to per token
intermediate_size: Intermediate dimension size
Only support bf16 for hidden_states
"""
torch.manual_seed(42)
otype = torch.bfloat16
x = torch.randn(batch_size, hidden_size, dtype=otype).cuda()
w31_weight = (
torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=otype).cuda()
/ 10
)
w2_weight = (
torch.randn(num_experts, hidden_size, intermediate_size, dtype=otype).cuda()
/ 10
)
# Generate unique random expert indices for each token
selected_experts = torch.stack(
[torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
).cuda()
routing_weights = torch.randn((batch_size, top_k)).cuda()
routing_weights = F.softmax(routing_weights, dim=1)
# Run reference implementation (no quantization)
_ref_output = compute_with_experts(
num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
)
# Quantize input and weights
x_quant, x_scales = per_token_group_quant_fp8(x, group_size=128)
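    # With hidden_size = 128 this yields x_quant of shape (batch_size, 128) in fp8
    # and x_scales of shape (batch_size, 1), one scale per 128-wide group.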
w31_dequant = torch.empty_like(w31_weight)
w2_dequant = torch.empty_like(w2_weight)
w31_quant = torch.empty_like(w31_weight).to(torch.float8_e4m3fn)
w2_quant = torch.empty_like(w2_weight).to(torch.float8_e4m3fn)
w31_scales = torch.randn(
num_experts,
ceil_div(2 * intermediate_size, 128),
ceil_div(hidden_size, 128),
dtype=torch.float32,
).cuda()
w2_scales = torch.randn(
num_experts,
ceil_div(hidden_size, 128),
ceil_div(intermediate_size, 128),
dtype=torch.float32,
).cuda()
for expert_id in range(num_experts):
w31, w31_s = per_block_cast_to_fp8(w31_weight[expert_id, :])
w2, w2_s = per_block_cast_to_fp8(w2_weight[expert_id, :])
w31_quant.data[expert_id].copy_(w31)
w31_scales.data[expert_id].copy_(w31_s)
w2_quant.data[expert_id].copy_(w2)
w2_scales.data[expert_id].copy_(w2_s)
    # Dequantize for verification
x_dequant = dequantize_block(x_quant, x_scales, x.dtype, x.shape)
w31_dequant = dequantize_block(
w31_quant, w31_scales, w31_weight.dtype, w31_weight.shape
)
w2_dequant = dequantize_block(w2_quant, w2_scales, w2_weight.dtype, w2_weight.shape)
# Run reference implementation with dequantized tensors
_ref_output = compute_with_experts(
num_experts,
x_dequant,
w31_dequant,
w2_dequant,
selected_experts,
routing_weights,
)
quant_scales = [
w31_scales, # .view(-1), # W31 scales
w2_scales, # .view(-1), # W2 scales
]
# Call flashinfer implementation with block scaling and expect NotImplementedError
with pytest.raises(
NotImplementedError,
match="DeepSeek FP8 Block Scaling is not yet implemented in CUTLASS for Blackwell",
):
_ = fused_moe.cutlass_fused_moe(
x.contiguous(),
selected_experts.to(torch.int),
routing_weights,
w31_quant.contiguous(),
w2_quant.contiguous(),
otype,
tp_size=1,
tp_rank=0,
use_deepseek_fp8_block_scale=True,
quant_scales=quant_scales,
)
def quant_mxfp4_batches(a, num_experts):
quant_a = []
sfs = []
for i in range(num_experts):
a_fp4, a_sf = mxfp4_quantize(a[i].cuda())
quant_a.append(a_fp4)
sfs.append(a_sf)
result_quant_a = torch.stack(quant_a)
result_sfs = torch.stack(sfs)
return result_quant_a, result_sfs
def dequant_mxfp4_batches(
mat_fp4: torch.Tensor,
scale_tensor: torch.Tensor,
):
num_batches = mat_fp4.size(0)
scale_tensor = scale_tensor.view(num_batches, -1)
return torch.stack(
[
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
for b in range(num_batches)
]
)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("otype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)]
)
@pytest.mark.skipif(
torch.cuda.get_device_capability()[0] not in [10, 11, 12],
reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120",
)
def test_moe_mxfp8_mxfp4(
batch_size,
hidden_size,
num_experts,
top_k,
intermediate_size,
otype,
alpha,
beta,
limit,
):
"""
Test MoE with MXFP8 activations and MXFP4 weights.
Uses mxfp8_quantize for activations and fp4_quantize for weights.
"""
# Skip invalid configurations
if top_k > num_experts:
pytest.skip(
f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
)
torch.manual_seed(42)
e = num_experts
m = batch_size
n = intermediate_size
k = hidden_size
x = torch.randn(m, k, dtype=otype).cuda()
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10
mxfp8_x, mxfp8_x_sf = mxfp8_quantize(x, True, 32)
mxfp4_w1, mxfp4_w1_scale = quant_mxfp4_batches(w1, e)
mxfp4_w2, mxfp4_w2_scale = quant_mxfp4_batches(w2, e)
router_logits = torch.randn(m, e, dtype=otype).cuda()
routing_weights, selected_experts = compute_routing(router_logits, top_k)
fake_input_scale = torch.ones(e, device=x.device)
quant_scales = [
mxfp4_w1_scale.view(torch.int32),
fake_input_scale,
mxfp4_w2_scale.view(torch.int32),
fake_input_scale,
]
flash_output = torch.zeros_like(x)
if alpha is not None and limit is not None and beta is not None:
alpha_t = torch.ones(e, device=x.device) * alpha
limit_t = torch.ones(e, device=x.device) * limit
beta_t = torch.ones(e, device=x.device) * beta
else:
alpha_t = None
limit_t = None
beta_t = None
# Call cutlass_fused_moe with MXFP8 activations and MXFP4 weights
_ = fused_moe.cutlass_fused_moe(
mxfp8_x,
selected_experts.to(torch.int),
routing_weights,
mxfp4_w1.contiguous().view(torch.long),
mxfp4_w2.contiguous().view(torch.long),
otype,
swiglu_alpha=alpha_t,
swiglu_limit=limit_t,
swiglu_beta=beta_t,
quant_scales=quant_scales,
input_sf=mxfp8_x_sf,
use_mxfp8_act_scaling=True,
output=flash_output,
)
dq_mxfp8_x = (
mxfp8_dequantize_host(
mxfp8_x.cpu().view(torch.uint8),
mxfp8_x_sf.cpu().view(torch.uint8).reshape(-1),
True,
)
.cuda()
.to(otype)
)
dq_mfxp4_w1 = (
dequant_mxfp4_batches(
mxfp4_w1.cpu().view(torch.uint8),
mxfp4_w1_scale.cpu().view(torch.uint8).reshape(-1),
)
.cuda()
.to(otype)
)
dq_mfxp4_w2 = (
dequant_mxfp4_batches(
mxfp4_w2.cpu().view(torch.uint8),
mxfp4_w2_scale.cpu().view(torch.uint8).reshape(-1),
)
.cuda()
.to(otype)
)
    # Use the dequantized activations and weights for the reference computation
ref_output = compute_with_experts(
e,
dq_mxfp8_x,
dq_mfxp4_w1,
dq_mfxp4_w2,
selected_experts,
routing_weights,
alpha,
beta,
limit,
)
torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)
def dequant_mxfp4_batches_host(
mat_fp4: torch.Tensor,
scale_tensor: torch.Tensor,
):
return torch.stack(
[
mxfp4_dequantize_host(mat_fp4[b, :, :], scale_tensor[b, :, :])
for b in range(mat_fp4.size(0))
]
)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize(
("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)]
)
@pytest.mark.skipif(
torch.cuda.get_device_capability()[0] != 9,
reason="BF16xMXFP4 is only supported on SM90",
)
def test_moe_bf16_mxfp4(
batch_size,
hidden_size,
num_experts,
top_k,
intermediate_size,
alpha,
beta,
limit,
):
"""
Test MoE with bf16 activations and MXFP4 weights.
Uses bf16 for activations and fp4_quantize for weights.
"""
# Skip invalid configurations
if top_k > num_experts:
pytest.skip(
f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
)
torch.manual_seed(42)
e = num_experts
m = batch_size
n = intermediate_size
k = hidden_size
x = torch.randn(m, k, dtype=torch.bfloat16).cuda()
w1 = torch.randint(0, 256, (e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
w2 = torch.randint(0, 256, (e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_scale = torch.randint(
118, 123, (e, 2 * n, k // 32), device="cuda", dtype=torch.uint8
)
w2_scale = torch.randint(
118, 123, (e, k, n // 32), device="cuda", dtype=torch.uint8
)
router_logits = torch.randn(m, e, dtype=torch.bfloat16).cuda()
routing_weights, selected_experts = compute_routing(router_logits, top_k)
flash_output = torch.zeros_like(x)
if alpha is not None and limit is not None and beta is not None:
alpha_t = torch.ones(e, device=x.device) * alpha
limit_t = torch.ones(e, device=x.device) * limit
beta_t = torch.ones(e, device=x.device) * beta
else:
alpha_t = None
limit_t = None
beta_t = None
pad_size = hidden_size - x.shape[1]
x_pad = torch.nn.functional.pad(x, (0, pad_size))
quant_scales = [
w1_scale.view(torch.int32),
w2_scale.view(torch.int32),
]
# Call cutlass_fused_moe with BF16 activations and MXFP4 weights
_ = fused_moe.cutlass_fused_moe(
x_pad,
selected_experts.to(torch.int),
routing_weights,
w1.contiguous().view(torch.uint8),
w2.contiguous().view(torch.uint8),
torch.bfloat16,
swiglu_alpha=alpha_t,
swiglu_limit=limit_t,
swiglu_beta=beta_t,
quant_scales=quant_scales,
use_w4_group_scaling=True,
output=flash_output,
)
dq_mfxp4_w1 = (
dequant_mxfp4_batches_host(
w1.cpu(),
w1_scale.cpu(),
)
.cuda()
.to(torch.bfloat16)
)
dq_mfxp4_w2 = (
dequant_mxfp4_batches_host(
w2.cpu(),
w2_scale.cpu(),
)
.cuda()
.to(torch.bfloat16)
)
    # Use the dequantized weights for the reference computation
ref_output = compute_with_experts(
e,
x,
dq_mfxp4_w1,
dq_mfxp4_w2,
selected_experts,
routing_weights,
alpha,
beta,
limit,
)
torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)
if __name__ == "__main__":
pytest.main([__file__, "-v"])