sglang.0.4.8.post1/sglang/sgl-kernel/python/sgl_kernel/moe.py

257 lines
6.1 KiB
Python
Executable File

from typing import Any, Dict, Optional
import torch
def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
    pad_sorted_token_ids=False,
):
    """Dispatch to the ``sgl_kernel::moe_align_block_size`` custom op.

    Every argument is forwarded positionally, in signature order, to the
    registered op. This wrapper returns ``None`` and discards whatever the op
    returns, so callers presumably rely on the kernel writing into the
    buffers they pass in (``sorted_token_ids``, ``experts_ids``,
    ``num_tokens_post_pad``, ...) -- TODO confirm against the kernel source.
    """
    kernel = torch.ops.sgl_kernel.moe_align_block_size.default
    kernel(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
        pad_sorted_token_ids,
    )
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Run the ``sgl_kernel::topk_softmax`` routing kernel.

    All four tensors are forwarded positionally to the op and nothing is
    returned. ``topk_weights``, ``topk_ids`` and ``token_expert_indices`` are
    therefore presumably output buffers the kernel fills in place -- TODO
    confirm their exact layouts against the kernel.

    Args:
        topk_weights: Buffer for the selected experts' routing weights.
        topk_ids: Buffer for the selected expert ids.
        token_expert_indices: Additional index buffer populated by the kernel.
        gating_output: Router (gating) output the softmax is taken over. This
            is a tensor; it was previously mis-annotated as ``float``.
    """
    torch.ops.sgl_kernel.topk_softmax.default(
        topk_weights, topk_ids, token_expert_indices, gating_output
    )
def moe_fused_gate(
    input_tensor,
    bias,
    num_expert_group,
    topk_group,
    topk,
    num_fused_shared_experts=0,
    routed_scaling_factor=0,
):
    """Select the top-k experts hierarchically (two levels) in one fused kernel.

    The experts are split into ``num_expert_group`` groups. Each group's score
    is the sum of its two largest expert weights. Those scores pick the
    expert groups (presumably ``topk_group`` of them), and ``topk`` experts
    are then chosen from within the selected groups.

    Current kernel limits:
      * the number of experts, taken from the shape of ``input_tensor``,
        must be a power of two;
      * the number of experts must be divisible by ``num_expert_group``;
      * ``#experts / num_expert_group`` must be <= 32.
    For unsupported configurations, use ``biased_grouped_topk`` from
    ``sglang.srt.layers.moe.topk`` instead.

    Args:
        num_fused_shared_experts: If > 0, the last several experts are
            replaced with shared experts.
        routed_scaling_factor: If > 0, the shared experts are scaled by this
            factor.

    Returns:
        Whatever the ``sgl_kernel::moe_fused_gate`` op returns.
    """
    fused_gate_op = torch.ops.sgl_kernel.moe_fused_gate.default
    return fused_gate_op(
        input_tensor,
        bias,
        num_expert_group,
        topk_group,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor,
    )
def ep_moe_pre_reorder(
    input_tensor,
    gateup_input,
    src2dst,
    topk_ids,
    a1_scales,
    start_expert_id,
    end_expert_id,
    topk,
    use_per_token_if_dynamic,
):
    """Forward to ``sgl_kernel::ep_moe_pre_reorder`` and return its result.

    All nine arguments reach the op positionally, in signature order. Judging
    by the names, this is presumably the expert-parallel step that places
    ``input_tensor`` rows into ``gateup_input`` according to ``src2dst`` for
    the experts between ``start_expert_id`` and ``end_expert_id`` -- TODO
    confirm against the kernel source.
    """
    op_args = (
        input_tensor,
        gateup_input,
        src2dst,
        topk_ids,
        a1_scales,
        start_expert_id,
        end_expert_id,
        topk,
        use_per_token_if_dynamic,
    )
    return torch.ops.sgl_kernel.ep_moe_pre_reorder.default(*op_args)
def ep_moe_silu_and_mul(
    gateup_output,
    down_input,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
):
    """Forward to ``sgl_kernel::ep_moe_silu_and_mul`` and return its result.

    The six arguments are passed to the op positionally, in signature order.
    As the name suggests, this presumably applies a SiLU-and-multiply
    activation to ``gateup_output`` and stores the result in ``down_input``
    for experts between ``start_expert_id`` and ``end_expert_id`` -- TODO
    confirm against the kernel source.
    """
    op_args = (
        gateup_output,
        down_input,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
    )
    return torch.ops.sgl_kernel.ep_moe_silu_and_mul.default(*op_args)
def ep_moe_post_reorder(
    down_output,
    output,
    src2dst,
    topk_ids,
    topk_weights,
    start_expert_id,
    end_expert_id,
    topk,
):
    """Forward to ``sgl_kernel::ep_moe_post_reorder`` and return its result.

    The eight arguments are passed to the op positionally, in signature
    order. Judging by the names, this presumably gathers ``down_output`` back
    into ``output`` using ``src2dst`` and weights it with ``topk_weights`` --
    TODO confirm against the kernel source.
    """
    op_args = (
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
    )
    return torch.ops.sgl_kernel.ep_moe_post_reorder.default(*op_args)
def fp8_blockwise_scaled_grouped_mm(
    output,
    a_ptrs,
    b_ptrs,
    out_ptrs,
    a_scales_ptrs,
    b_scales_ptrs,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_c,
    layout_sfa,
    layout_sfb,
    problem_sizes,
    expert_offsets,
    workspace,
):
    """Launch the ``sgl_kernel::fp8_blockwise_scaled_grouped_mm`` op.

    All eighteen arguments are forwarded positionally, in signature order.
    This wrapper returns ``None`` and ignores the op's return value, so the
    result is presumably written into ``output`` by the kernel -- TODO
    confirm against the kernel source.
    """
    grouped_mm = torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default
    grouped_mm(
        output,
        a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs,
        a, b, scales_a, scales_b,
        stride_a, stride_b, stride_c,
        layout_sfa, layout_sfb,
        problem_sizes, expert_offsets,
        workspace,
    )
def prepare_moe_input(
    topk_ids,
    expert_offsets,
    problem_sizes1,
    problem_sizes2,
    input_permutation,
    output_permutation,
    num_experts,
    n,
    k,
    blockscale_offsets: Optional[torch.Tensor] = None,
):
    """Call the ``sgl_kernel::prepare_moe_input`` op; returns ``None``.

    The op's return value is discarded, so the passed-in buffers
    (``expert_offsets``, ``problem_sizes1/2``, the permutations and, if given,
    ``blockscale_offsets``) are presumably filled in place -- TODO confirm
    against the kernel source.

    ``blockscale_offsets`` is optional (default ``None``), which is why it
    sits last in this Python signature.
    """
    # The op's positional order differs from this signature: the optional
    # ``blockscale_offsets`` goes third, right after ``expert_offsets``.
    op_args = [topk_ids, expert_offsets, blockscale_offsets]
    op_args += [problem_sizes1, problem_sizes2]
    op_args += [input_permutation, output_permutation]
    op_args += [num_experts, n, k]
    torch.ops.sgl_kernel.prepare_moe_input.default(*op_args)
def apply_shuffle_mul_sum(
    input,
    output,
    permutation,
    factors,
):
    """Invoke the ``sgl_kernel::apply_shuffle_mul_sum`` op; returns ``None``.

    The four arguments are forwarded positionally. The op's result is
    discarded, so the kernel presumably writes into ``output`` -- TODO confirm
    against the kernel source.

    Note: the ``input`` parameter shadows the builtin. It is kept because
    callers may pass it by keyword.
    """
    shuffle_mul_sum = torch.ops.sgl_kernel.apply_shuffle_mul_sum.default
    shuffle_mul_sum(input, output, permutation, factors)
def cutlass_fp4_group_mm(
    a_fp4,
    b_fp4,
    a_blockscale,
    b_blockscale,
    alphas,
    out_dtype,
    device,
    params: Dict[str, Any],
):
    """NVFP4 block-scaled grouped GEMM used by the NVFP4-quantized fused MoE.

    Allocates an output tensor of shape ``(a_fp4.shape[0], b_fp4.shape[1])``
    with dtype ``out_dtype`` on ``device``, then has the CUTLASS
    ``sgl_kernel::cutlass_fp4_group_mm`` op fill it. The op runs one GEMM per
    expert as described by ``params``.

    Args:
        a_fp4: NVFP4-quantized inputs. ``shape[0]`` sets the output row count.
        b_fp4: NVFP4-quantized expert weights. ``shape[1]`` sets the output
            column count.
        a_blockscale: Block scales for ``a_fp4`` (FP8-E4M3).
        b_blockscale: Block scales for ``b_fp4`` (FP8-E4M3).
        alphas: Forwarded unchanged to the kernel (presumably per-expert
            global scales -- TODO confirm).
        out_dtype: dtype of the returned tensor.
        device: Device on which the output tensor is allocated.
        params: Mapping that must contain these keys:
            - ``"ab_strides"`` / ``"c_strides"``: strides between rows of the
              a/b tensors and of the output.
            - ``"problem_sizes"``: MxNxK size of each expert's multiplication.
            - ``"expert_offsets"``: token index at which each expert starts.
              Expert ``E`` computes
              ``expert_offsets[E + 1] - expert_offsets[E]`` tokens.
            - ``"blockscale_offsets"``: the same kind of offsets for the scale
              factors. Expert ``E``'s scale-factor size is
              ``blockscale_offsets[E + 1] - blockscale_offsets[E]``.

    Returns:
        The newly allocated output tensor, filled by the kernel.

    Raises:
        KeyError: If any required key is missing from ``params``. This is
            checked before any memory is allocated.
    """
    required_keys = (
        "ab_strides",
        "c_strides",
        "problem_sizes",
        "expert_offsets",
        "blockscale_offsets",
    )
    missing = [key for key in required_keys if key not in params]
    if missing:
        # Fail before allocating the output buffer, and report every missing
        # key at once rather than only the first one looked up.
        raise KeyError(
            f"cutlass_fp4_group_mm: params is missing required keys {missing}"
        )

    m_topk = a_fp4.shape[0]
    n = b_fp4.shape[1]
    c = torch.empty((m_topk, n), device=device, dtype=out_dtype)
    torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
        c,
        a_fp4,
        b_fp4,
        a_blockscale,
        b_blockscale,
        alphas,
        params["ab_strides"],
        params["c_strides"],
        params["problem_sizes"],
        params["expert_offsets"],
        params["blockscale_offsets"],
    )
    # ``c`` was allocated with ``out_dtype``, so the former
    # ``c.to(dtype=out_dtype)`` was a no-op returning ``c`` itself.
    return c