import logging
from typing import Callable, List, Optional, Tuple

import torch

# TODO: use deep_gemm masked kernel after low latency dispatch
# import deep_gemm
# from deep_gemm import (
#     get_col_major_tma_aligned_tensor,
#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
# )
from torch.nn import Module

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
    gelu_and_mul_triton_kernel,
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
    run_moe_ep_preproess,
    silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs

_is_cuda = is_cuda()

if _is_cuda:
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
    from vllm import _custom_ops as vllm_ops


logger = logging.getLogger(__name__)

_is_hip = is_hip()

_buffer = None


class GroupedGemmRunner(torch.nn.Module):
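    """Runs grouped (segmented) GEMMs over the per-expert weight matrices.

    Dispatches to the Triton grouped-GEMM kernel; the flashinfer
    SegmentGEMMWrapper path is scaffolded but not enabled yet.
    """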

    flashinfer_gemm_wrapper = None

    def __init__(self, device, use_flashinfer: bool = False):
        super().__init__()
        self.device = device
        self.use_flashinfer = use_flashinfer
        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_wrapper is None:
            GroupedGemmRunner._init_flashinfer_wrapper(device)

    @classmethod
    def _init_flashinfer_wrapper(cls, device):
        from flashinfer import SegmentGEMMWrapper

        workspace_buffer = torch.empty(
            128 * 1024 * 1024, dtype=torch.int8, device=device
        )
        cls.flashinfer_gemm_wrapper = SegmentGEMMWrapper(workspace_buffer)

    # c = a * b
    def forward(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
        c: torch.Tensor,
        batch_size: int,
        weight_column_major: bool,
        seg_indptr: Optional[torch.Tensor] = None,
        weight_indices: Optional[torch.Tensor] = None,
        use_fp8_w8a8: bool = False,
        scale_a: Optional[torch.Tensor] = None,
        scale_b: Optional[torch.Tensor] = None,
        block_shape: Optional[List[int]] = None,
    ):
        if self.use_flashinfer:
            # TODO: the flashinfer path is not wired up yet; keep it disabled.
            raise NotImplementedError("flashinfer grouped GEMM is not enabled yet")
            assert GroupedGemmRunner.flashinfer_gemm_wrapper is not None
            c = GroupedGemmRunner.flashinfer_gemm_wrapper.run(
                x=a,
                weights=b,
                batch_size=batch_size,
                weight_column_major=weight_column_major,
                seg_indptr=seg_indptr,
                weight_indices=weight_indices,
            )
        else:
            assert weight_column_major
            c = grouped_gemm_triton(
                a,
                b,
                c,
                batch_size,
                weight_column_major,
                seg_indptr,
                weight_indices,
                use_fp8_w8a8,
                scale_a,
                scale_b,
                block_shape=block_shape,
            )
        return c


class EPMoE(torch.nn.Module):
    """MoE Expert Parallel (EP) implementation.

    Experts are partitioned across tensor-parallel ranks: each rank stores and
    computes only its contiguous slice of experts
    [start_expert_id, end_expert_id].
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        correction_bias: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        activation: str = "silu",
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.tp_rank = get_tensor_model_parallel_rank()

        self.num_experts = num_experts
        assert self.num_experts % self.tp_size == 0
        self.num_experts_per_partition = self.num_experts // self.tp_size
        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1

        self.top_k = top_k
        self.intermediate_size = intermediate_size
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.correction_bias = correction_bias
        self.custom_routing_function = custom_routing_function
        self.activation = activation

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.block_shape = None
            self.activation_scheme = None
        else:
            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                quant_config
            )
            self.use_fp8_w8a8 = True
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme

        self.quant_method.create_weights(
            layer=self,
            num_experts_per_partition=self.num_experts_per_partition,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

        # Lazily constructed on the first forward pass.
        self.grouped_gemm_runner = None

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
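        """Route tokens to this rank's experts and combine the results.

        Pipeline: top-k routing -> sort (token, expert) pairs by expert ->
        pre-reorder gather -> grouped GEMM (gate/up) -> activation ->
        grouped GEMM (down) -> post-reorder weighted scatter back into the
        original token order.
        """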
        assert self.quant_method is not None

        if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
                hidden_states.device,
                use_flashinfer=False,  # TODO: use flashinfer
            )

        topk_weights, topk_ids = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            use_grouped_topk=self.use_grouped_topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            correction_bias=self.correction_bias,
            custom_routing_function=self.custom_routing_function,
        )

        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
            topk_ids, self.num_experts
        )
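        # seg_indptr segments the (token, expert) pairs, sorted by expert id:
        # expert e owns rows seg_indptr[e]:seg_indptr[e + 1] of the reordered
        # activations, and src2dst maps each original (token, k) slot to its
        # reordered row. E.g. with 4 experts and topk_ids = [[0, 2], [1, 2]],
        # seg_indptr = [0, 1, 2, 4, 4].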

        gateup_input = torch.empty(
            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
            device=hidden_states.device,
            dtype=(
                self.fp8_dtype
                if (self.use_fp8_w8a8 and not self.use_block_quant)
                else hidden_states.dtype
            ),
        )
        if self.activation_scheme == "dynamic" and not self.use_block_quant:
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

        # PreReorder
        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
            hidden_states,
            gateup_input,
            src2dst,
            topk_ids,
            self.w13_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states.shape[1],
            BLOCK_SIZE=512,
        )

        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
        weight_indices_cur_rank = torch.arange(
            0,
            self.num_experts_per_partition,
            device=hidden_states.device,
            dtype=torch.int64,
        )
        # GroupGemm-0
        gateup_output = torch.empty(
            gateup_input.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        gateup_output = self.grouped_gemm_runner(
            a=gateup_input,
            b=self.w13_weight,
            c=gateup_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w13_input_scale,
            scale_b=(
                self.w13_weight_scale_inv
                if self.use_block_quant
                else self.w13_weight_scale
            ),
            block_shape=self.block_shape,
        )

        # Act
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=(
                self.fp8_dtype
                if (self.use_fp8_w8a8 and not self.use_block_quant)
                else hidden_states.dtype
            ),
        )
        if self.w2_input_scale is None and not self.use_block_quant:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=hidden_states.device,
            )

        if self.activation == "silu":
            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
                gateup_output,
                down_input,
                gateup_output.shape[1],
                reorder_topk_ids,
                self.w2_input_scale,
                self.start_expert_id,
                self.end_expert_id,
                BLOCK_SIZE=512,
            )
        elif self.activation == "gelu":
            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
                gateup_output,
                down_input,
                gateup_output.shape[1],
                reorder_topk_ids,
                self.w2_input_scale,
                self.start_expert_id,
                self.end_expert_id,
                BLOCK_SIZE=512,
            )
        else:
            raise ValueError(f"Unsupported activation: {self.activation=}")

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        down_output = self.grouped_gemm_runner(
            a=down_input,
            b=self.w2_weight,
            c=down_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w2_input_scale,
            scale_b=(
                self.w2_weight_scale_inv
                if self.use_block_quant
                else self.w2_weight_scale
            ),
            block_shape=self.block_shape,
        )

        # PostReorder
        output = torch.empty_like(hidden_states)
        post_reorder_triton_kernel[(hidden_states.size(0),)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states.size(1),
            BLOCK_SIZE=512,
        )
        return output

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:
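        """Map checkpoint expert weight names onto the fused EP parameters.

        Returns (param_name, weight_name, expert_id, shard_id) tuples. For
        example, with ckpt_gate_proj_name="gate_proj", expert 3's gate
        projection maps to ("experts.w13_", "experts.3.gate_proj.", 3, "w1").
        """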
        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    "experts.w13_"
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else "experts.w2_"
                ),
                f"experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
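        """Copy one checkpoint shard into this rank's slice of the fused
        expert parameters. Shards for experts outside
        [start_expert_id, end_expert_id] belong to other EP ranks and are
        ignored."""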
        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
            return
        expert_id = expert_id - self.start_expert_id

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be one of ['w1', 'w2', 'w3'] but got {shard_id}."
            )

        # Special case for fp8 scales.
        if "scale" in weight_name:
            self._load_fp8_scale(
                param.data,
                loaded_weight,
                weight_name,
                shard_id,
                expert_id,
            )
            return

        if shard_id == "w2":
            param.data[expert_id] = loaded_weight
        elif shard_id == "w1":
            param.data[expert_id][: self.intermediate_size, :] = loaded_weight
        elif shard_id == "w3":
            param.data[expert_id][self.intermediate_size :, :] = loaded_weight
        else:
            raise ValueError(f"Expected shard_id w1, w2 or w3 but got {shard_id}")

    def _load_fp8_scale(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            if self.use_block_quant:
                block_n, block_k = self.block_shape[0], self.block_shape[1]
                if shard_id == "w1":
                    param_data[expert_id][
                        : (self.intermediate_size + block_n - 1) // block_n, :
                    ] = loaded_weight
                elif shard_id == "w3":
                    param_data[expert_id][
                        (self.intermediate_size + block_n - 1) // block_n :, :
                    ] = loaded_weight
                else:  # w2
                    param_data[expert_id] = loaded_weight
            # If we are in the merged column case (gate_up_proj)
            else:
                if shard_id in ("w1", "w3"):
                    # We have to keep the weight scales of w1 and w3 because
                    # we need to re-quantize w1/w3 weights after weight loading.
                    idx = 0 if shard_id == "w1" else 1
                    param_data[expert_id][idx] = loaded_weight
                # If we are in the row parallel case (down_proj)
                else:
                    param_data[expert_id] = loaded_weight


class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                2 * intermediate_size,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                hidden_size,
                intermediate_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Unit scales (each parameter gets its own tensor; sharing one tensor
        # across parameters would silently alias their storage).
        w13_input_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


class Fp8EPMoEMethod(Fp8MoEMethod):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

    def create_weights(
        self,
        layer: Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        tp_size = get_tensor_model_parallel_world_size()
        if self.block_quant:
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE(HandH1998): To ensure proper alignment of the block-wise
            # quantization scales, the output_size of the weights for both
            # the gate and up layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights.
            if intermediate_size % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1:
                # Required by row parallel
                if intermediate_size % block_k != 0:
                    raise ValueError(
                        f"The input_size of down's weight = "
                        f"{intermediate_size} is not divisible by "
                        f"weight quantization block_k = {block_k}."
                    )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                2 * intermediate_size,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                hidden_size,
                intermediate_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
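            # One fp32 scale per (block_n x block_k) weight tile; the w13
            # scale grid stacks the w1 and w3 grids along dim 1, matching the
            # fused weight layout.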
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts_per_partition,
                    2 * ((intermediate_size + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts_per_partition,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
        else:
            # WEIGHT_SCALES
            # Allocate 2 scales for w1 and w3 respectively.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)

            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
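        """Post-process weights once loading finishes.

        For fp16 checkpoints, quantize the weights to fp8 in place; for fp8
        checkpoints, collapse the per-shard w1/w3 weight scales into the
        single per-expert scale the MoE kernels expect.
        """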
        # If the checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # On ROCm, use float8_e4m3fnuz as the fp8 dtype.
            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.num_experts_per_partition,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )

            for expert in range(layer.num_experts_per_partition):
                if _is_cuda:
                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                    )
                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                    )
                else:
                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                    )
                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                    )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            return

        # If the checkpoint is fp8, the MoE kernels require a single
        # activation scale and a single weight scale for w13 per expert.
        else:
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.max(layer.w13_weight_scale, dim=1).values,
                requires_grad=False,
            )
            return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


class DeepEPMoE(EPMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
    """

    _has_printed = False

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        correction_bias: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        activation: str = "silu",
    ):
        super().__init__(
            num_experts,
            top_k,
            hidden_size,
            intermediate_size,
            params_dtype,
            renormalize,
            use_grouped_topk,
            num_expert_group,
            topk_group,
            quant_config,
            tp_size,
            prefix,
            correction_bias,
            custom_routing_function,
            activation,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        reorder_topk_ids: torch.Tensor,
        seg_indptr: torch.Tensor,
        forward_mode: ForwardMode,
    ):
        # TODO: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low-latency
        # dispatch (decode). Until then, always take the normal path.
        if True:  # not forward_mode.is_decode():
            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
        else:
            return self.forward_deepgemm_masked(
                hidden_states, reorder_topk_ids, seg_indptr
            )

    def forward_normal(
        self,
        hidden_states: torch.Tensor,
        reorder_topk_ids: torch.Tensor,
        seg_indptr: torch.Tensor,
    ):
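        """Run the grouped-GEMM expert computation on tokens that DeepEP has
        already dispatched to this rank, grouped by local expert as described
        by seg_indptr."""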
        assert self.quant_method is not None
        assert self.activation == "silu"
        if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
            )

        if self.activation_scheme == "dynamic" and not self.use_block_quant:
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
        weight_indices_cur_rank = torch.arange(
            0,
            self.num_experts_per_partition,
            device=hidden_states.device,
            dtype=torch.int64,
        )

        # GroupGemm-0
        gateup_output = torch.empty(
            hidden_states.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

        if hidden_states.shape[0] > 0:
            gateup_output = self.grouped_gemm_runner(
                a=hidden_states,
                b=self.w13_weight,
                c=gateup_output,
                batch_size=self.num_experts_per_partition,
                weight_column_major=True,
                seg_indptr=seg_indptr,
                weight_indices=weight_indices_cur_rank,
                use_fp8_w8a8=self.use_fp8_w8a8,
                scale_a=self.w13_input_scale,
                scale_b=(
                    self.w13_weight_scale_inv
                    if self.use_block_quant
                    else self.w13_weight_scale
                ),
                block_shape=self.block_shape,
            )

        # Act
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=(
                self.fp8_dtype
                if (self.use_fp8_w8a8 and not self.use_block_quant)
                else hidden_states.dtype
            ),
        )
        if self.w2_input_scale is None and not self.use_block_quant:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=hidden_states.device,
            )

        if self.activation == "silu":
            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
                gateup_output,
                down_input,
                gateup_output.shape[1],
                reorder_topk_ids,
                self.w2_input_scale,
                0,
                self.num_experts_per_partition - 1,
                BLOCK_SIZE=512,
            )
        else:
            raise ValueError(f"Unsupported activation: {self.activation=}")

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        if down_input.shape[0] > 0:
            down_output = self.grouped_gemm_runner(
                a=down_input,
                b=self.w2_weight,
                c=down_output,
                batch_size=self.num_experts_per_partition,
                weight_column_major=True,
                seg_indptr=seg_indptr,
                weight_indices=weight_indices_cur_rank,
                use_fp8_w8a8=self.use_fp8_w8a8,
                scale_a=self.w2_input_scale,
                scale_b=(
                    self.w2_weight_scale_inv
                    if self.use_block_quant
                    else self.w2_weight_scale
                ),
                block_shape=self.block_shape,
            )
        return down_output

    def forward_deepgemm_masked(
        self,
        hidden_states: torch.Tensor,
        reorder_topk_ids: torch.Tensor,
        seg_indptr: torch.Tensor,
    ):
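        """Stub for the deep_gemm masked grouped-GEMM (low-latency decode)
        path. The GEMM calls remain commented out pending the deep_gemm
        import at the top of this file; hidden_states appears to be expected
        as a (tensor, scale) pair produced by the low-latency dispatch."""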
        assert self.quant_method is not None
        assert self.activation == "silu"

        if self.activation_scheme == "dynamic" and not self.use_block_quant:
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

        # GroupGemm-0
        gateup_output = torch.empty(
            hidden_states.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        if hidden_states.shape[0] > 0:
            # Transpose earlier so that testing does not trigger transposing
            # kernels. Note: get_col_major_tma_aligned_tensor comes from the
            # deep_gemm import that is still commented out above.
            hidden_states = (
                hidden_states[0],
                get_col_major_tma_aligned_tensor(hidden_states[1]),
            )
            """
            gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
                hidden_states, self.w13_weight, out, masked_m, expected_m
            )
            """

        # Act
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=(
                self.fp8_dtype
                if (self.use_fp8_w8a8 and not self.use_block_quant)
                else hidden_states.dtype
            ),
        )
        if self.w2_input_scale is None and not self.use_block_quant:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=hidden_states.device,
            )

        if self.activation == "silu":
            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
                gateup_output,
                down_input,
                gateup_output.shape[1],
                reorder_topk_ids,
                self.w2_input_scale,
                0,
                self.num_experts_per_partition - 1,
                BLOCK_SIZE=512,
            )
        else:
            raise ValueError(f"Unsupported activation: {self.activation=}")

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        if down_input.shape[0] > 0:
            # Transpose earlier so that testing does not trigger transposing
            # kernels.
            down_input = (
                down_input[0],
                get_col_major_tma_aligned_tensor(down_input[1]),
            )
            """
            down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
                down_input, self.w2_weight, out, masked_m, expected_m
            )
            """

        return down_output
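

# Minimal usage sketch (illustrative only, not part of the module). Assumes a
# CUDA device, initialized tensor-parallel state, and made-up sizes:
#
#   moe = EPMoE(num_experts=8, top_k=2, hidden_size=4096, intermediate_size=14336)
#   hidden_states = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
#   router_logits = torch.randn(16, 8, device="cuda", dtype=torch.float32)
#   output = moe(hidden_states, router_logits)  # [16, 4096]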