""" Copyright (c) 2025 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import pytest import torch from torch.nn import functional as F import flashinfer.fused_moe as fused_moe from flashinfer import ( fp4_quantize, mxfp4_dequantize, mxfp4_quantize, mxfp8_dequantize_host, mxfp8_quantize, mxfp4_dequantize_host, ) FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FP8_DTYPE = torch.float8_e4m3fn def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]: fp8_traits_max = FLOAT8_E4M3_MAX fp8_traits_min = -FLOAT8_E4M3_MAX fp8_max = torch.tensor(fp8_traits_max).float() one = torch.tensor(1.0).float() x_max = x.abs().max().float() scale = x_max / fp8_max iscale = one / scale out = (x.float() * iscale).clamp(fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) return out, scale.view((1,)) def gen_tensor(shape, dtype, stype=None, scale=1.0): x = torch.randn(*shape, dtype=dtype).cuda() * scale return x.to(stype) if stype else x def cast_to_representable(x): x_q, x_scale = dynamic_per_tensor_fp8_quant(x) x = x_q.to(x.dtype) * x_scale.to(x.dtype) return x def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): m_tiles = (m + 128 - 1) // 128 f = block_size * 4 k_tiles = (k + f - 1) // f tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) return out[0:m, 0:k] def dequantize_nvfp4_to_dtype( tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 ): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. 
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype=dtype)


def break_fp4_bytes(a, dtype):
    assert a.dtype == torch.uint8
    m, n = a.shape

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Combine nibbles for batch processing
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1ToFloat = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
    )
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n * 2).to(dtype=dtype)


def compute_routing(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute routing weights and selected experts from router logits.

    Args:
        router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts]
        top_k (int): Number of experts to route to per token

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - routing_weights: Expert weights of shape [batch_size, top_k]
            - selected_experts: Expert indices of shape [batch_size, top_k]
    """
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.float()
    return routing_weights, selected_experts


def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    # score = torch.softmax(score, dim=-1, dtype=torch.float32)
    # topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # w1 needs to be swapped in terms of gate and up_proj
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            m = w1[i].shape[0]
            assert m % 2 == 0
            w1_expert, w3_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
            inter_gs = torch.tensor(1.0).cuda()
            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
            inter = dequantize_nvfp4_to_dtype(
                inter_q,
                inter_blockscale,
                inter_gs,
                dtype=inter.dtype,
                device=inter.device,
                block_size=16,
            ).cuda()
            out[mask] = inter @ w2[i].transpose(0, 1)
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def compute_with_experts(
    num_experts,
    x,
    w31_weight,
    w2_weight,
    selected_experts,
    routing_weights,
    alpha=None,
    beta=None,
    limit=None,
):
    results = torch.zeros_like(x)
    for expert_id in range(num_experts):
        mask = selected_experts == expert_id
        if not mask.sum():
            continue
        batch_idx, nth_expert = torch.where(mask)
        w31_expert = w31_weight[expert_id]  # [2 * intermediate_size, hidden_size]
        w2_expert = w2_weight[expert_id]  # [hidden_size, intermediate_size]

        # Split w31 into w3 and w1
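        # Convention used throughout this file: the fused weight stacks the
        # up projection (w3) in the first half of dim 0 and the gate
        # projection (w1) in the second half, so chunking yields (w3, w1).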
        w3_expert, w1_expert = torch.chunk(w31_expert, 2, dim=0)
        expert_inputs = x[batch_idx]
        if alpha is not None and limit is not None and beta is not None:
            # SwiGLUBias
            x1 = expert_inputs @ w1_expert.t()
            x1 = x1.clamp_(min=None, max=limit)
            x1_scaled = x1 * torch.sigmoid(alpha * x1)
            x2 = expert_inputs @ w3_expert.t()
            x2 = x2.clamp_(min=-limit, max=limit) + beta
            inter = x1_scaled * x2
        else:
            inter = F.silu(expert_inputs @ w1_expert.t()) * (
                expert_inputs @ w3_expert.t()
            )
        output = inter @ w2_expert.t()
        results[batch_idx] += routing_weights[batch_idx, nth_expert, None] * output
    return results.view_as(x)


# Test configurations
BATCH_SIZES = [
    1,
]
HIDDEN_SIZES = [
    128,
]
NUM_EXPERTS = [2]
TOP_K_VALUES = [2]
INTERMEDIATE_SIZES = [
    128,
]
EP_NUM_EXPERTS = [8]
EP_TOP_K = [2]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size):
    # Skip invalid configurations
    if top_k > num_experts:
        pytest.skip(
            f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
        )

    torch.manual_seed(42)
    x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() / 5
    router_logits = torch.randn(batch_size, num_experts, dtype=torch.float32).cuda()
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
        ).cuda()
        / 5
    )
    w2_weight = (
        torch.randn(
            num_experts, hidden_size, intermediate_size, dtype=torch.float16
        ).cuda()
        / 5
    )

    routing_weights, selected_experts = compute_routing(router_logits, top_k)

    ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )
    flash_output = torch.empty_like(ref_output)
    flash_output = fused_moe.cutlass_fused_moe(
        x,
        selected_experts.to(torch.int),
        routing_weights,
        w31_weight,
        w2_weight,
        flash_output.dtype,
        output=flash_output,
        quant_scales=None,
    )
    torch.testing.assert_close(ref_output, flash_output[0], rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("otype, wtype", [(torch.float16, torch.float8_e4m3fn)])
def test_moe_fp8(
    batch_size, hidden_size, num_experts, top_k, intermediate_size, otype, wtype
):
    # Skip invalid configurations
    if top_k > num_experts:
        pytest.skip(
            f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
        )

    torch.manual_seed(42)
    input_shape = (batch_size, hidden_size)
    w31_shape = (num_experts, 2 * intermediate_size, hidden_size)
    w2_shape = (num_experts, hidden_size, intermediate_size)

    x = cast_to_representable(gen_tensor(input_shape, otype))
    router_logits = gen_tensor((batch_size, num_experts), otype)

    # Create weight tensors
    w31_weight = gen_tensor(w31_shape, otype, wtype)
    w2_weight = gen_tensor(w2_shape, otype, wtype)
    w31_scales = torch.empty(num_experts, 2, dtype=otype).cuda()
    w2_scales = torch.empty(num_experts, 1, dtype=otype).cuda()
    w31_dequantized = gen_tensor(w31_shape, otype)
    w2_dequantized = gen_tensor(w2_shape, otype)
    for expert_id in range(num_experts):
        w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1))
        w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09))
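        # Per-tensor FP8: each expert weight gets a single e4m3 scale, so
        # dequantization is just q.to(otype) * scale.  The dequantized copies
        # drive the reference path; the raw fp8 tensors go to the kernel.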
        w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31)
        w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2)
        w31_weight.data[expert_id].copy_(w31_quant)
        w2_weight.data[expert_id].copy_(w2_quant)
        w31_scales.data[expert_id].copy_(s31)
        w2_scales.data[expert_id].copy_(s2)
        w31_dequantized.data[expert_id].copy_(
            torch.mul(w31_quant.to(dtype=otype), s31)
        )
        w2_dequantized.data[expert_id].copy_(torch.mul(w2_quant.to(dtype=otype), s2))

    routing_weights, selected_experts = compute_routing(router_logits, top_k)

    ref_output = compute_with_experts(
        num_experts,
        x,
        w31_dequantized,
        w2_dequantized,
        selected_experts,
        routing_weights,
    )
    flash_output = torch.empty_like(ref_output)

    # For fp8, the kernel expects the hidden states to be pre-quantized.
    _, w1_scales = torch.chunk(w31_scales, 2, dim=-1)
    x_quant, hidden_states_scale = dynamic_per_tensor_fp8_quant(x)
    hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda()

    quant_scales = [
        torch.squeeze(w1_scales * hidden_states_scale).float(),
        torch.tensor(1.0).cuda(),
        torch.squeeze(1.0 * w2_scales).float(),
        hidden_states_scale,
    ]
    _ = fused_moe.cutlass_fused_moe(
        x_quant,
        selected_experts.to(torch.int),
        routing_weights,
        w31_weight,
        w2_weight,
        otype,
        quant_scales=quant_scales,
        output=flash_output,
    )
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize(
    "otype, wtype",
    [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)],
)
@pytest.mark.parametrize("quantized_input", [False, True])
@pytest.mark.skipif(
    torch.cuda.get_device_capability()[0] not in [10, 11, 12],
    reason="NVFP4 is only supported on SM100, SM110 and SM120",
)
def test_moe_nvfp4(
    batch_size,
    hidden_size,
    num_experts,
    top_k,
    intermediate_size,
    otype,
    wtype,
    quantized_input,
):
    # Skip invalid configurations
    if top_k > num_experts:
        pytest.skip(
            f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
        )

    torch.manual_seed(42)
    quant_blocksize = 16
    round_up = lambda x, y: (x + y - 1) // y * y

    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size

    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
    w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()

    sf_w1_2n = round_up(2 * n, 128)
    sf_w1_k = round_up(k // quant_blocksize, 4)
    w1_blockscale = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w1_blockscale_cutlass = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )

    w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10
    sf_w2_k = round_up(k, 128)
    sf_w2_n = round_up(n // quant_blocksize, 4)
    w2_blockscale = torch.empty(
        (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
    )

    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)

    for expert in range(e):
        w1_amax = torch.abs(w1).max().to(torch.float32)
        w2_amax = torch.abs(w2).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
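        # NVFP4 global scale: FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax maps
        # the largest value onto the top of the representable
        # (block_scale * e2m1) range.  Note that the amax is taken over the
        # whole w1/w2 tensor here, so every expert ends up with the same
        # scale.  fp4_quantize then returns packed e2m1 nibbles plus swizzled
        # e4m3 block scales.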
        w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert])
        w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize(
            w1_cutlass[expert], w1_gs[expert]
        )
        w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert])

    x = torch.randn(m, k, dtype=otype).cuda()
    a1_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(x).max().to(
        torch.float32
    ).cuda()
    # The amax-based scale above is overridden with a unit global scale.
    a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    router_logits = torch.randn(m, e, dtype=otype).cuda()

    routing_weights, selected_experts = compute_routing(router_logits, top_k)

    # quant_scales format
    # auto const fc1_act_global = quant_scales.value()[0];
    # auto const fc1_weight_block = quant_scales.value()[1];
    # auto const fc1_global = quant_scales.value()[2];
    # auto const fc2_act_global = quant_scales.value()[3];
    # auto const fc2_weight_block = quant_scales.value()[4];
    # auto const fc2_global = quant_scales.value()[5];
    flash_output = torch.zeros_like(x)

    quant_scales = [
        a1_gs,
        w1_blockscale.view(torch.int32),
        1.0 / (a1_gs * w1_gs),
        a2_gs,
        w2_blockscale.view(torch.int32),
        1.0 / (a2_gs * w2_gs),
    ]

    hidden_states = x
    input_sf = None
    if quantized_input:
        hidden_states, input_sf = fp4_quantize(x, a1_gs)

    _ = fused_moe.cutlass_fused_moe(
        hidden_states,
        selected_experts.to(torch.int),
        routing_weights,
        w1_q.contiguous().view(torch.long),
        w2_q.contiguous().view(torch.long),
        otype,
        quant_scales=quant_scales,
        input_sf=input_sf,
        output=flash_output,
    )

    # Ref check
    a_fp4, a_scale_interleaved = fp4_quantize(x, a1_gs)
    _, m_k = a_fp4.shape
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        a1_gs,
        dtype=otype,
        device=x.device,
        block_size=quant_blocksize,
    )
    w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=otype)
    w2_d = torch.empty((e, k, n), device="cuda", dtype=otype)

    for idx in range(0, e):
        w1_d[idx] = dequantize_nvfp4_to_dtype(
            w1_q[idx],
            w1_blockscale[idx],
            w1_gs[idx],
            dtype=w1.dtype,
            device=w1.device,
            block_size=quant_blocksize,
        )
        w2_d[idx] = dequantize_nvfp4_to_dtype(
            w2_q[idx],
            w2_blockscale[idx],
            w2_gs[idx],
            dtype=w2.dtype,
            device=w2.device,
            block_size=quant_blocksize,
        )

    w1_q_cutlass = torch.cat((w1_q[:, n:, :], w1_q[:, :n, :]), dim=1).contiguous()
    w1_blockscale_cutlass = torch.cat(
        (w1_blockscale[:, n:, :], w1_blockscale[:, :n, :]), dim=1
    ).contiguous()
    ref_output = torch_moe_nvfp4(
        a_in_dtype, w1_d, w2_d, top_k, routing_weights, selected_experts
    )
    torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS)
@pytest.mark.parametrize("top_k", EP_TOP_K)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_expert_parallel(
    batch_size, hidden_size, num_experts, top_k, intermediate_size
):
    """
    Test expert parallelism: the experts are partitioned across ep_size ranks
    and the per-rank partial results are summed.

    Args:
        batch_size: Batch size for the input
        hidden_size: Hidden dimension size
        num_experts: Number of experts
        top_k: Number of experts to route to per token
        intermediate_size: Intermediate dimension size
    """
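    # The test drives cutlass_fused_moe once per simulated expert-parallel
    # rank; no multi-GPU communication actually happens.  Each call only sees
    # the weights of the experts owned by that rank, and the rank-local output
    # buffer is zero-initialized, so experts owned by other ranks contribute
    # nothing to it.  Summing the per-rank outputs below therefore stands in
    # for the all-reduce of a real deployment.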
    # Each simulated rank owns num_experts // ep_size consecutive experts.
    ep_size = num_experts // 2

    torch.manual_seed(42)

    # Create input tensors
    x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()

    # Create weight tensors - each rank gets a contiguous slice of the experts
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    w2_weight = (
        torch.randn(
            num_experts, hidden_size, intermediate_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    selected_experts = torch.stack(
        [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
    ).cuda()
    routing_weights = torch.randn((batch_size, top_k)).cuda()
    routing_weights = F.softmax(routing_weights, dim=1)

    ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )

    outputs = []
    flash_output = torch.zeros_like(ref_output)
    for ep_rank in range(ep_size):
        # Create output tensor for this rank
        out_hidden_states_local = torch.zeros_like(x)

        # Compute expert start and end positions for this rank
        experts_per_rank = num_experts // ep_size
        expert_start = ep_rank * experts_per_rank
        expert_end = expert_start + experts_per_rank  # if ep_rank < 1 else num_experts

        # Get only the experts for this rank
        w31_weight_local = w31_weight[expert_start:expert_end, :]
        w2_weight_local = w2_weight[expert_start:expert_end, :]

        _ = fused_moe.cutlass_fused_moe(
            x.contiguous(),
            selected_experts.to(torch.int),
            routing_weights,
            w31_weight_local.contiguous(),
            w2_weight_local.contiguous(),
            x.dtype,
            ep_size=ep_size,
            ep_rank=ep_rank,
            quant_scales=None,
            output=out_hidden_states_local,
        )
        outputs.append(out_hidden_states_local)

    # Reduce results from all ranks
    for ep_rank in range(ep_size):
        flash_output += outputs[ep_rank]  # [batch_size, hidden_size]

    torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)


TP_SIZES = [2, 4]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("tp_size", TP_SIZES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_tensor_parallel(
    batch_size, hidden_size, num_experts, tp_size, intermediate_size
):
    """
    Test tensor parallelism with:
    - w31 sharded along second dimension (non-contracting)
    - w2 sharded along third dimension (contracting)
    - All-reduce to sum partial results

    Args:
        batch_size: Batch size for the input
        hidden_size: Hidden dimension size
        num_experts: Number of experts
        tp_size: Number of GPUs for tensor parallelism
        intermediate_size: Intermediate dimension size
    """
    # Set random seed for reproducibility
    torch.manual_seed(42)
    top_k = 2

    # Create input tensors
    x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()

    # Create weight tensors
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    w2_weight = (
        torch.randn(
            num_experts, hidden_size, intermediate_size, dtype=torch.float16
        ).cuda()
        / 10
    )

    # Generate unique random expert indices for each token
    selected_experts = torch.stack(
        [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
    ).cuda()
    routing_weights = torch.randn((batch_size, top_k)).cuda()
    routing_weights = F.softmax(routing_weights, dim=1)

    # Run reference implementation (no parallelism)
    ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )

    # Simulate tensor parallelism across tp_size GPUs
    outputs = []
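    # Sharding scheme used below: w31 is split into its w3 and w1 halves and
    # each half is sharded along the intermediate (non-contracting) dimension,
    # while w2 is sharded along its last (contracting) dimension.  Each shard
    # therefore produces a partial output, and summing the tp_size partials
    # plays the role of the all-reduce.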
    for tp_rank in range(tp_size):
        # Create output tensor for this GPU
        out_hidden_states_local = torch.zeros_like(x)

        # Shard w31 along second dimension (intermediate_size)
        # First split w31 into w3 and w1
        w3_weight, w1_weight = torch.chunk(
            w31_weight, 2, dim=1
        )  # [num_experts, intermediate_size, hidden_size] each

        # Shard w3 and w1 separately
        w3_shard_size = intermediate_size // tp_size
        w3_start = tp_rank * w3_shard_size
        w3_end = w3_start + w3_shard_size
        w3_weight_local = w3_weight[:, w3_start:w3_end, :]

        w1_shard_size = intermediate_size // tp_size
        w1_start = tp_rank * w1_shard_size
        w1_end = w1_start + w1_shard_size
        w1_weight_local = w1_weight[:, w1_start:w1_end, :]

        # Stack the sharded weights back together
        w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1)

        # Shard w2 along third dimension (intermediate_size)
        w2_shard_size = intermediate_size // tp_size
        w2_start = tp_rank * w2_shard_size
        w2_end = w2_start + w2_shard_size
        w2_weight_local = w2_weight[:, :, w2_start:w2_end]

        _ = fused_moe.cutlass_fused_moe(
            x.contiguous(),
            selected_experts.to(torch.int),
            routing_weights,
            w31_weight_local.contiguous(),
            w2_weight_local.contiguous(),
            x.dtype,
            tp_size=tp_size,
            tp_rank=tp_rank,
            quant_scales=None,
            output=out_hidden_states_local,
        )
        outputs.append(out_hidden_states_local)

    # All-reduce to sum partial results from all GPUs
    flash_output = sum(outputs)
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS)
@pytest.mark.parametrize("top_k", EP_TOP_K)
@pytest.mark.parametrize("tp_size", TP_SIZES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_tensor_expert_parallel(
    batch_size, hidden_size, num_experts, top_k, tp_size, intermediate_size
):
    """
    Test combined tensor parallelism and expert parallelism:
    - Expert parallelism: Distribute experts across GPUs
    - Tensor parallelism: For each expert's weights:
      - w31 sharded along second dimension (non-contracting)
      - w2 sharded along third dimension (contracting)
    - All-reduce to sum partial results

    Args:
        batch_size: Batch size for the input
        hidden_size: Hidden dimension size
        num_experts: Number of experts
        top_k: Number of experts to route to per token
        tp_size: Number of GPUs for tensor parallelism
        intermediate_size: Intermediate dimension size
    """
    torch.manual_seed(42)

    x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    w2_weight = (
        torch.randn(
            num_experts, hidden_size, intermediate_size, dtype=torch.float16
        ).cuda()
        / 10
    )

    # Generate unique random expert indices for each token
    selected_experts = torch.stack(
        [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
    ).cuda()
    routing_weights = torch.randn((batch_size, top_k)).cuda()
    routing_weights = F.softmax(routing_weights, dim=1)

    # Run reference implementation (no parallelism)
    ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )

    # Simulate combined expert and tensor parallelism
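    # Every (ep_rank, tp_rank) pair below plays the role of one GPU: it sees
    # only its slice of experts (EP) and only its shard of each expert's
    # weights (TP), and produces a partial output.  Summing the
    # ep_size * tp_size partials reproduces the unsharded reference, which is
    # what the final all-reduce would do in a real deployment.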
    ep_size = num_experts // 2  # Number of GPUs for expert parallelism
    outputs = []

    # For each expert parallel rank
    for ep_rank in range(ep_size):
        # Get experts for this rank
        experts_per_rank = num_experts // ep_size
        expert_start = ep_rank * experts_per_rank
        expert_end = expert_start + experts_per_rank

        # Get expert weights for this rank
        w31_weight_ep = w31_weight[
            expert_start:expert_end, :
        ]  # [experts_per_rank, 2*intermediate_size, hidden_size]
        w2_weight_ep = w2_weight[
            expert_start:expert_end, :
        ]  # [experts_per_rank, hidden_size, intermediate_size]

        # For each tensor parallel rank
        for tp_rank in range(tp_size):
            # Create output tensor for this GPU
            out_hidden_states_local = torch.zeros_like(x)

            # Split w31 into w3 and w1
            w3_weight, w1_weight = torch.chunk(w31_weight_ep, 2, dim=1)

            # Shard w3 and w1 separately
            w3_shard_size = intermediate_size // tp_size
            w3_start = tp_rank * w3_shard_size
            w3_end = w3_start + w3_shard_size
            w3_weight_local = w3_weight[:, w3_start:w3_end, :]

            w1_shard_size = intermediate_size // tp_size
            w1_start = tp_rank * w1_shard_size
            w1_end = w1_start + w1_shard_size
            w1_weight_local = w1_weight[:, w1_start:w1_end, :]

            # Stack the sharded weights back together
            w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1)

            # Shard w2 along third dimension
            w2_shard_size = intermediate_size // tp_size
            w2_start = tp_rank * w2_shard_size
            w2_end = w2_start + w2_shard_size
            w2_weight_local = w2_weight_ep[:, :, w2_start:w2_end]

            # Call flashinfer implementation with both parallelisms
            out_hidden_states_local = fused_moe.cutlass_fused_moe(
                x.contiguous(),
                selected_experts.to(torch.int),
                routing_weights,
                w31_weight_local.contiguous(),
                w2_weight_local.contiguous(),
                x.dtype,
                tp_size=tp_size,
                tp_rank=tp_rank,
                ep_size=ep_size,
                ep_rank=ep_rank,
                quant_scales=None,
            )
            outputs.append(out_hidden_states_local[0])

    # All-reduce to sum partial results from all GPUs
    flash_output = sum(outputs)
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2)


def ceil_div(a: int, b: int) -> int:
    return -(a // -b)


def per_block_cast_to_fp8(
    x: torch.Tensor, block_size_n: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (ceil_div(m, 128) * 128, ceil_div(n, block_size_n) * block_size_n),
        dtype=x.dtype,
        device=x.device,
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


def per_token_group_quant_fp8(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn):
    """Perform per-token-group quantization on an input tensor `x` using
    native torch."""
    assert x.shape[-1] % group_size == 0, (
        "the last dimension of `x` must be divisible by `group_size`"
    )
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / fp8_max
    x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
    return x_q, x_s
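

# The two helpers above implement DeepSeek-style FP8 block quantization:
# activations are scaled per token in groups of 128 along the last dimension
# (per_token_group_quant_fp8), while weights get one scale per 128x128 tile
# with scale = block_amax / 448 (per_block_cast_to_fp8).  dequantize_block
# below undoes either layout by broadcasting the scales back over their
# blocks.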
def dequantize_block(
    x_quant: torch.Tensor,
    scales: torch.Tensor,
    dtype: torch.dtype,
    original_shape: tuple,
) -> torch.Tensor:
    """
    Dequantize a block-quantized tensor.

    Args:
        x_quant: Quantized tensor
        scales: Block scaling factors
        dtype: Target dtype for dequantization
        original_shape: Original shape of the tensor before padding

    Returns:
        torch.Tensor: Dequantized tensor
    """

    # Reshape scales to match block structure
    def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # Move target dim to last position if not already last
        if dim != -1:
            a = a.transpose(dim, -1)
        # Broadcast and reshape
        a_broadcasted = a.unsqueeze(-1).expand(*a.shape, 128)
        a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * 128)
        # Move back if needed
        if dim != -1:
            a_reshaped = a_reshaped.transpose(dim, -1)
        return a_reshaped

    if x_quant.dim() == 2:
        # For activation tensors [batch_size, hidden_size]
        batch_size, hidden_size = x_quant.shape
        num_blocks = (hidden_size + 127) // 128
        scales = scales.view(batch_size, num_blocks, 1).expand(-1, -1, 128)
        if hidden_size % 128 != 0:
            scales = scales[:, :, : hidden_size % 128]
    else:
        # For weight tensors [..., in_dim, out_dim]
        *_dims, in_dim, out_dim = x_quant.shape
        # Transform both dimensions
        scales = transform_dim(scales, -1)  # Last dim
        scales = transform_dim(scales, -2)  # Second-to-last dim
        # Handle padding
        if in_dim % 128 != 0:
            scales = scales[..., : in_dim % 128, :]
        if out_dim % 128 != 0:
            scales = scales[..., :, : out_dim % 128]

    x_dequant = x_quant.to(dtype) * scales.to(dtype)
    return x_dequant.view(original_shape)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.skipif(
    torch.cuda.get_device_capability()[0] not in [10, 11, 12],
    reason="FP8 block scaling is only supported on SM100, SM110 and SM120",
)
def test_moe_fp8_block_scaling(
    batch_size, hidden_size, num_experts, top_k, intermediate_size
):
    """
    Test MoE with FP8 block scaling (DeepSeek style):
    - Activation: 128x1 blocks
    - Weights: 128x128 blocks
    - Each block has its own scaling factor

    Args:
        batch_size: Batch size for the input
        hidden_size: Hidden dimension size
        num_experts: Number of experts
        top_k: Number of experts to route to per token
        intermediate_size: Intermediate dimension size

    Only bf16 hidden_states are supported.
    """
    torch.manual_seed(42)
    otype = torch.bfloat16

    x = torch.randn(batch_size, hidden_size, dtype=otype).cuda()
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=otype
        ).cuda()
        / 10
    )
    w2_weight = (
        torch.randn(num_experts, hidden_size, intermediate_size, dtype=otype).cuda()
        / 10
    )

    # Generate unique random expert indices for each token
    selected_experts = torch.stack(
        [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
    ).cuda()
    routing_weights = torch.randn((batch_size, top_k)).cuda()
    routing_weights = F.softmax(routing_weights, dim=1)

    # Run reference implementation (no quantization)
    _ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )

    # Quantize input and weights
    x_quant, x_scales = per_token_group_quant_fp8(x, group_size=128)
    w31_dequant = torch.empty_like(w31_weight)
    w2_dequant = torch.empty_like(w2_weight)
    w31_quant = torch.empty_like(w31_weight).to(torch.float8_e4m3fn)
    w2_quant = torch.empty_like(w2_weight).to(torch.float8_e4m3fn)
    w31_scales = torch.randn(
        num_experts,
        ceil_div(2 * intermediate_size, 128),
        ceil_div(hidden_size, 128),
        dtype=torch.float32,
    ).cuda()
    w2_scales = torch.randn(
        num_experts,
        ceil_div(hidden_size, 128),
        ceil_div(intermediate_size, 128),
        dtype=torch.float32,
    ).cuda()
    for expert_id in range(num_experts):
        w31, w31_s = per_block_cast_to_fp8(w31_weight[expert_id, :])
        w2, w2_s = per_block_cast_to_fp8(w2_weight[expert_id, :])
        w31_quant.data[expert_id].copy_(w31)
        w31_scales.data[expert_id].copy_(w31_s)
        w2_quant.data[expert_id].copy_(w2)
        w2_scales.data[expert_id].copy_(w2_s)

    # Dequantize for verification
    x_dequant = dequantize_block(x_quant, x_scales, x.dtype, x.shape)
    w31_dequant = dequantize_block(
        w31_quant, w31_scales, w31_weight.dtype, w31_weight.shape
    )
    w2_dequant = dequantize_block(w2_quant, w2_scales, w2_weight.dtype, w2_weight.shape)

    # Run reference implementation with dequantized tensors
    _ref_output = compute_with_experts(
        num_experts,
        x_dequant,
        w31_dequant,
        w2_dequant,
        selected_experts,
        routing_weights,
    )

    quant_scales = [
        w31_scales,  # .view(-1), # W31 scales
        w2_scales,  # .view(-1), # W2 scales
    ]

    # Call flashinfer implementation with block scaling and expect NotImplementedError
    with pytest.raises(
        NotImplementedError,
        match="DeepSeek FP8 Block Scaling is not yet implemented in CUTLASS for Blackwell",
    ):
        _ = fused_moe.cutlass_fused_moe(
            x.contiguous(),
            selected_experts.to(torch.int),
            routing_weights,
            w31_quant.contiguous(),
            w2_quant.contiguous(),
            otype,
            tp_size=1,
            tp_rank=0,
            use_deepseek_fp8_block_scale=True,
            quant_scales=quant_scales,
        )


def quant_mxfp4_batches(a, num_experts):
    quant_a = []
    sfs = []
    for i in range(num_experts):
        a_fp4, a_sf = mxfp4_quantize(a[i].cuda())
        quant_a.append(a_fp4)
        sfs.append(a_sf)
    result_quant_a = torch.stack(quant_a)
    result_sfs = torch.stack(sfs)
    return result_quant_a, result_sfs


def dequant_mxfp4_batches(
    mat_fp4: torch.Tensor,
    scale_tensor: torch.Tensor,
):
    num_batches = mat_fp4.size(0)
    scale_tensor = scale_tensor.view(num_batches, -1)
    return torch.stack(
        [
            mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
            for b in range(num_batches)
        ]
    )


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("otype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
    ("alpha", "beta", "limit"),
    [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)],
)
@pytest.mark.skipif(
    torch.cuda.get_device_capability()[0] not in [10, 11, 12],
    reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120",
)
def test_moe_mxfp8_mxfp4(
    batch_size,
    hidden_size,
    num_experts,
    top_k,
    intermediate_size,
    otype,
    alpha,
    beta,
    limit,
):
    """
    Test MoE with MXFP8 activations and MXFP4 weights.
    Uses mxfp8_quantize for activations and mxfp4_quantize for weights.
    """
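    # Recipe exercised below: activations are quantized to MXFP8 with
    # 32-element scale groups (mxfp8_quantize(x, True, 32)), and weights to
    # packed e2m1 nibbles with per-block scales (mxfp4_quantize).  The
    # quant_scales list carries the two weight block-scale tensors plus
    # per-expert dummy input scales of 1.0.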
""" # Skip invalid configurations if top_k > num_experts: pytest.skip( f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" ) torch.manual_seed(42) e = num_experts m = batch_size n = intermediate_size k = hidden_size x = torch.randn(m, k, dtype=otype).cuda() w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 mxfp8_x, mxfp8_x_sf = mxfp8_quantize(x, True, 32) mxfp4_w1, mxfp4_w1_scale = quant_mxfp4_batches(w1, e) mxfp4_w2, mxfp4_w2_scale = quant_mxfp4_batches(w2, e) router_logits = torch.randn(m, e, dtype=otype).cuda() routing_weights, selected_experts = compute_routing(router_logits, top_k) fake_input_scale = torch.ones(e, device=x.device) quant_scales = [ mxfp4_w1_scale.view(torch.int32), fake_input_scale, mxfp4_w2_scale.view(torch.int32), fake_input_scale, ] flash_output = torch.zeros_like(x) if alpha is not None and limit is not None and beta is not None: alpha_t = torch.ones(e, device=x.device) * alpha limit_t = torch.ones(e, device=x.device) * limit beta_t = torch.ones(e, device=x.device) * beta else: alpha_t = None limit_t = None beta_t = None # Call cutlass_fused_moe with MXFP8 activations and MXFP4 weights _ = fused_moe.cutlass_fused_moe( mxfp8_x, selected_experts.to(torch.int), routing_weights, mxfp4_w1.contiguous().view(torch.long), mxfp4_w2.contiguous().view(torch.long), otype, swiglu_alpha=alpha_t, swiglu_limit=limit_t, swiglu_beta=beta_t, quant_scales=quant_scales, input_sf=mxfp8_x_sf, use_mxfp8_act_scaling=True, output=flash_output, ) dq_mxfp8_x = ( mxfp8_dequantize_host( mxfp8_x.cpu().view(torch.uint8), mxfp8_x_sf.cpu().view(torch.uint8).reshape(-1), True, ) .cuda() .to(otype) ) dq_mfxp4_w1 = ( dequant_mxfp4_batches( mxfp4_w1.cpu().view(torch.uint8), mxfp4_w1_scale.cpu().view(torch.uint8).reshape(-1), ) .cuda() .to(otype) ) dq_mfxp4_w2 = ( dequant_mxfp4_batches( mxfp4_w2.cpu().view(torch.uint8), mxfp4_w2_scale.cpu().view(torch.uint8).reshape(-1), ) .cuda() .to(otype) ) # Use original weights for reference computation ref_output = compute_with_experts( e, dq_mxfp8_x, dq_mfxp4_w1, dq_mfxp4_w2, selected_experts, routing_weights, alpha, beta, limit, ) torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) def dequant_mxfp4_batches_host( mat_fp4: torch.Tensor, scale_tensor: torch.Tensor, ): return torch.stack( [ mxfp4_dequantize_host(mat_fp4[b, :, :], scale_tensor[b, :, :]) for b in range(mat_fp4.size(0)) ] ) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize( ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] ) @pytest.mark.skipif( torch.cuda.get_device_capability()[0] != 9, reason="BF16xMXFP4 is only supported on SM90", ) def test_moe_bf16_mxfp4( batch_size, hidden_size, num_experts, top_k, intermediate_size, alpha, beta, limit, ): """ Test MoE with bf16 activations and MXFP4 weights. Uses bf16 for activations and fp4_quantize for weights. 
""" # Skip invalid configurations if top_k > num_experts: pytest.skip( f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" ) torch.manual_seed(42) e = num_experts m = batch_size n = intermediate_size k = hidden_size x = torch.randn(m, k, dtype=torch.bfloat16).cuda() w1 = torch.randint(0, 256, (e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) w2 = torch.randint(0, 256, (e, k, n // 2), device="cuda", dtype=torch.uint8) w1_scale = torch.randint( 118, 123, (e, 2 * n, k // 32), device="cuda", dtype=torch.uint8 ) w2_scale = torch.randint( 118, 123, (e, k, n // 32), device="cuda", dtype=torch.uint8 ) router_logits = torch.randn(m, e, dtype=torch.bfloat16).cuda() routing_weights, selected_experts = compute_routing(router_logits, top_k) flash_output = torch.zeros_like(x) if alpha is not None and limit is not None and beta is not None: alpha_t = torch.ones(e, device=x.device) * alpha limit_t = torch.ones(e, device=x.device) * limit beta_t = torch.ones(e, device=x.device) * beta else: alpha_t = None limit_t = None beta_t = None pad_size = hidden_size - x.shape[1] x_pad = torch.nn.functional.pad(x, (0, pad_size)) quant_scales = [ w1_scale.view(torch.int32), w2_scale.view(torch.int32), ] # Call cutlass_fused_moe with BF16 activations and MXFP4 weights _ = fused_moe.cutlass_fused_moe( x_pad, selected_experts.to(torch.int), routing_weights, w1.contiguous().view(torch.uint8), w2.contiguous().view(torch.uint8), torch.bfloat16, swiglu_alpha=alpha_t, swiglu_limit=limit_t, swiglu_beta=beta_t, quant_scales=quant_scales, use_w4_group_scaling=True, output=flash_output, ) dq_mfxp4_w1 = ( dequant_mxfp4_batches_host( w1.cpu(), w1_scale.cpu(), ) .cuda() .to(torch.bfloat16) ) dq_mfxp4_w2 = ( dequant_mxfp4_batches_host( w2.cpu(), w2_scale.cpu(), ) .cuda() .to(torch.bfloat16) ) # Use original weights for reference computation ref_output = compute_with_experts( e, x, dq_mfxp4_w1, dq_mfxp4_w2, selected_experts, routing_weights, alpha, beta, limit, ) torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) if __name__ == "__main__": pytest.main([__file__, "-v"])