# sglang_v0.5.2/sglang/sgl-kernel/python/sgl_kernel/cutlass_moe.py
import torch
def get_cutlass_w4a8_moe_mm_data(
    topk_ids: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    input_permutation: torch.Tensor,
    output_permutation: torch.Tensor,
    num_experts: int,
    n: int,
    k: int,
):
    """Populate the bookkeeping tensors required by the CUTLASS grouped GEMMs
    used in the CUTLASS-based fused MoE operation.

    Starting from ``topk_ids`` (the token -> expert routing), the underlying
    kernel fills the provided output tensors in place:

    - ``expert_offsets``: index at which each expert's tokens begin once the
      input has been shuffled with ``input_permutation``; expert ``E`` owns
      ``expert_offsets[E + 1] - expert_offsets[E]`` tokens.
    - ``problem_sizes1`` / ``problem_sizes2``: per-expert MxNxK sizes for the
      two grouped matrix multiplications of the fused MoE op.
    - ``input_permutation``: ordering to apply to the input before the MMs run.
    - ``output_permutation``: ordering to apply to the output after the MMs run.
    """
    # Dispatch straight to the registered sgl_kernel custom op; all results
    # are written into the tensors passed in, nothing is returned.
    torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
def cutlass_w4a8_moe_mm(
    d: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    # FIX: these six were annotated `torch.tensor` (the factory function),
    # not the `torch.Tensor` type — meaningless to type checkers and
    # inconsistent with the rest of the file.
    experts_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    d_strides: torch.Tensor,
    s_strides: torch.Tensor,
    chunk_size: int = 128,
    topk: int = 8,
):
    """Perform grouped matrix multiplication between int4 weights and fp8 activations.

    This function executes multiple GEMM operations in parallel, which is useful
    for scenarios like Mixture of Experts (MoE) where different inputs go through
    different experts. The implementation leverages NVIDIA Hopper architecture
    features for optimal performance with quantized weights.

    Args:
        d: Output matrices of shape [total_m, total_n]
        a: Activation matrices in FP8 (float_e4m3_t) format.
            Each tensor should be of shape [total_m, K] in row-major layout
        b: Weight matrices in packed int4 format.
            Each tensor should be of shape [E, N, K//2] in column-major layout
            where each byte contains two 4-bit integers
        a_scales: Scale factors for the inputs
        b_scales: Scale factors for the quantized weights.
            Each tensor should be of shape [E, K//512, N*8]
        experts_offsets: Tensor containing expert offsets for determining group
            boundaries
        problem_sizes: with shape [num_experts, 3] (M, N, K for each group) (int32)
        a_strides: Strides information for A matrices
        b_strides: Strides information for B matrices
        d_strides: Strides information for D matrices
        s_strides: Strides information for b_scales matrices
        chunk_size: Number of elements each scale value applies to (K//512),
            default to 128
        topk: Number of experts selected per token, default to 8

    Requirements:
        - All tensors must be on a CUDA device
        - Requires an NVIDIA Hopper GPU (H100)
        - A tensors must be in float8_e4m3fn format
        - B tensors must contain packed int4 values (stored as int8)

    Note:
        The function computes: D = (A * (B * scales))
        for each group of tensors in parallel
    """
    # All computation happens in the registered sgl_kernel custom op;
    # results are written into `d` in place.
    torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
        d,
        a,
        b,
        a_scales,
        b_scales,
        experts_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        d_strides,
        s_strides,
        chunk_size,
        topk,
    )