import torch


def get_cutlass_w4a8_moe_mm_data(
    topk_ids: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    input_permutation: torch.Tensor,
    output_permutation: torch.Tensor,
    num_experts: int,
    n: int,
    k: int,
):
    """
    Prepare data necessary to perform CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in topk_ids (token-expert mapping) and uses it to
    compute:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation after the input is sorted with
                      input_permutation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used in
                                      the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
                         before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the output
                          after executing the MMs.
    """
    torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
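

# Illustrative sketch (not part of the kernel API): one way the metadata
# buffers passed to get_cutlass_w4a8_moe_mm_data could be allocated before the
# call. The int32 dtypes and the permutation length (one entry per
# (token, expert) assignment in topk_ids) are assumptions inferred from the
# docstring above, not guarantees of the kernel interface.
def _example_get_w4a8_moe_mm_data(
    topk_ids: torch.Tensor, num_experts: int, n: int, k: int
):
    device = topk_ids.device
    num_assignments = topk_ids.numel()  # tokens * topk routing decisions
    expert_offsets = torch.empty(num_experts + 1, dtype=torch.int32, device=device)
    problem_sizes1 = torch.empty(num_experts, 3, dtype=torch.int32, device=device)
    problem_sizes2 = torch.empty(num_experts, 3, dtype=torch.int32, device=device)
    input_permutation = torch.empty(num_assignments, dtype=torch.int32, device=device)
    output_permutation = torch.empty(num_assignments, dtype=torch.int32, device=device)
    get_cutlass_w4a8_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
    return (
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
    )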


def cutlass_w4a8_moe_mm(
    d: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    experts_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    d_strides: torch.Tensor,
    s_strides: torch.Tensor,
    chunk_size: int = 128,
    topk: int = 8,
):
    """
    Perform grouped matrix multiplication between int4 weights and fp8 activations.

    This function executes multiple GEMM operations in parallel, which is useful for
    scenarios like Mixture of Experts (MoE) where different inputs go through different
    experts. The implementation leverages NVIDIA Hopper architecture features for
    optimal performance with quantized weights.

    Args:
        d: Output matrices of shape [total_m, total_n]
        a: Activation matrices in FP8 (float_e4m3_t) format
           Each tensor should be of shape [total_m, K] in row-major layout
        b: Weight matrices in packed int4 format
           Each tensor should be of shape [E, N, K//2] in column-major layout
           where each byte contains two 4-bit integers
        a_scales: Scale factors for the inputs
        b_scales: Scale factors for the quantized weights
                  Each tensor should be of shape [E, K//512, N*8]
        experts_offsets: Tensor containing expert offsets for determining group boundaries
        problem_sizes: int32 tensor of shape [num_experts, 3] holding (M, N, K) for each group
        a_strides: Stride information for the A matrices
        b_strides: Stride information for the B matrices
        d_strides: Stride information for the D matrices
        s_strides: Stride information for the b_scales matrices
        chunk_size: Number of elements each scale value applies to (K//512); defaults to 128
        topk: Number of experts selected per token; defaults to 8

    Requirements:
        - All tensors must be on a CUDA device
        - Requires an NVIDIA Hopper GPU (H100)
        - A tensors must be in float8_e4m3fn format
        - B tensors must contain packed int4 values (stored as int8)

    Note:
        The function computes: D = (A * (B * scales))
        for each group of tensors in parallel.
    """
    torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
        d,
        a,
        b,
        a_scales,
        b_scales,
        experts_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        d_strides,
        s_strides,
        chunk_size,
        topk,
    )
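

# Conceptual reference (illustrative sketch only): the per-expert math that the
# grouped W4A8 GEMM performs, D = A @ (B * scales)^T, written in plain PyTorch
# with *unpacked* int4 weights and a simplified per-group scale layout. The
# real kernel consumes packed int4 weights ([E, N, K//2]) and the interleaved
# b_scales layout described above; group_size here is an assumption matching
# the default chunk_size, and activation scaling is omitted.
def _reference_w4a8_expert_gemm(
    a_fp8: torch.Tensor,  # [M, K] activations in torch.float8_e4m3fn
    b_int4: torch.Tensor,  # [N, K] unpacked int4 values stored as torch.int8
    scales: torch.Tensor,  # [N, K // group_size] per-group weight scales
    group_size: int = 128,
) -> torch.Tensor:
    a = a_fp8.to(torch.float32)
    # Broadcast each per-group scale across its group_size columns of K, then
    # dequantize the weights.
    s = scales.to(torch.float32).repeat_interleave(group_size, dim=1)
    b = b_int4.to(torch.float32) * s
    # D = A @ B^T, mirroring the Note above for a single expert/group.
    return a @ b.t()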