import torch


def get_cutlass_w4a8_moe_mm_data(
    topk_ids: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    input_permutation: torch.Tensor,
    output_permutation: torch.Tensor,
    num_experts: int,
    n: int,
    k: int,
):
    """
    Prepare data necessary to perform CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in topk_ids (token-expert mapping) and uses it to
    compute (all output tensors are filled in place by the kernel):
    - expert_offsets: Indices that mark at which token index each expert begins
      its computation after the input is sorted with input_permutation.
      The number of tokens computed with expert E is
      expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
      multiplication in two grouped MMs used in the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
      before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the output
      after executing the MMs.
    """
    torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )


def cutlass_w4a8_moe_mm(
    d: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    # NOTE: annotations below were `torch.tensor` (the factory function);
    # corrected to the type `torch.Tensor`.
    experts_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    d_strides: torch.Tensor,
    s_strides: torch.Tensor,
    chunk_size: int = 128,
    topk: int = 8,
):
    """
    Perform grouped matrix multiplication between int4 weights and fp8
    activations.

    This function executes multiple GEMM operations in parallel, which is
    useful for scenarios like Mixture of Experts (MoE) where different inputs
    go through different experts. The implementation leverages NVIDIA Hopper
    architecture features for optimal performance with quantized weights.

    Args:
        d: Output matrices of shape [total_m, total_n], written in place.
        a: Activation matrices in FP8 (float_e4m3_t) format.
            Each tensor should be of shape [total_m, K] in row-major layout.
        b: Weight matrices in packed int4 format.
            Each tensor should be of shape [E, N, K//2] in column-major layout
            where each byte contains two 4-bit integers.
        a_scales: Scale factors for the inputs.
        b_scales: Scale factors for the quantized weights.
            Each tensor should be of shape [E, K//512, N*8].
        experts_offsets: Tensor containing expert offsets for determining
            group boundaries.
        problem_sizes: Shape [num_experts, 3] (M, N, K for each group) (int32).
        a_strides: Strides information for A matrices.
        b_strides: Strides information for B matrices.
        d_strides: Strides information for D matrices.
        s_strides: Strides information for b_scales matrices.
        chunk_size: Number of elements each scale value applies to (K//512),
            defaults to 128.
        topk: Number of experts selected per token, defaults to 8.

    Requirements:
        - All tensors must be on a CUDA device
        - Requires an NVIDIA Hopper GPU (H100)
        - A tensors must be in float8_e4m3fn format
        - B tensors must contain packed int4 values (stored as int8)

    Note:
        The function computes: D = (A * (B * scales)) for each group of
        tensors in parallel.
    """
    torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
        d,
        a,
        b,
        a_scales,
        b_scales,
        experts_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        d_strides,
        s_strides,
        chunk_size,
        topk,
    )