# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

import logging
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs

if torch.cuda.is_available():
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore

_is_hip = is_hip()

if _is_hip:
    from aiter import ck_moe

logger = logging.getLogger(__name__)


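# Granularities of weight scales that a quantization method may declare on a
# parameter (via its "quant_method" weight attr); FusedMoE.weight_loader uses
# this tag to decide how a loaded scale is sharded and stored. The names carry
# their usual, rough meanings: one scale per tensor, per output channel, per
# input group, or per 2-D block; see the concrete quant method for details.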
class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"


class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
    ) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
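        # Resulting parameter shapes, for reference (E = num_experts,
        # H = hidden_size, I = intermediate_size on this partition):
        #   w13_weight: [E, 2 * I, H]  (gate_proj and up_proj fused along dim 1)
        #   w2_weight:  [E, H, I]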

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if _is_hip and get_bool_env_var("CK_MOE"):
            layer.w13_weight = torch.nn.Parameter(
                permute_weight(layer.w13_weight.data),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
            layer.w2_weight = torch.nn.Parameter(
                permute_weight(layer.w2_weight.data),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
        return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        return self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            activation=activation,
            inplace=inplace,
            no_combine=no_combine,
        )

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        # Route each token to its top_k experts; both returned tensors have
        # shape [num_tokens, top_k].
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
        )

        if _is_hip and get_bool_env_var("CK_MOE"):
            assert not no_combine, "unsupported"
            return ck_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                None,
                None,
                None,
                None,
                32,
                None,
                activation,
            )
        else:
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=inplace and not no_combine,
                activation=activation,
                no_combine=no_combine,
            )

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        inplace: bool = True,
    ) -> torch.Tensor:
        return moe_forward_native(
            layer,
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            custom_routing_function,
            correction_bias,
        )

    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("The TPU backend currently does not support MoE.")

    # CustomOp dispatches forward() to a backend-specific forward_* method;
    # the plain-PyTorch ("native") path reuses the CUDA implementation here.
    forward_native = forward_cuda


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallelLinear weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w3, and w2 for the gate, up, and down projections,
    respectively. We copy that naming convention here and handle any remapping
    in the load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model.
        top_k: Number of experts selected for each token.
        hidden_size: Input hidden state size of the transformer.
        intermediate_size: Intermediate size of the experts.
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all_reduce on the output of the layer.
        renormalize: Whether to renormalize the logits in the fused_moe kernel.
        quant_config: Quantization configuration.
        inplace: Suggestion to compute inplace (modify input activation).
    """
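
    # Construction sketch (values are illustrative only, not taken from a
    # specific model config):
    #
    #     moe = FusedMoE(
    #         num_experts=8,
    #         top_k=2,
    #         hidden_size=4096,
    #         intermediate_size=14336,
    #         renormalize=True,
    #         reduce_results=True,
    #     )
    #     out = moe(hidden_states, router_logits)
    #
    # where hidden_states is [num_tokens, hidden_size] and router_logits is
    # [num_tokens, num_experts].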

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        use_presharded_weights: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.top_k = top_k
        self.num_experts = num_experts
        assert intermediate_size % self.tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function
        self.correction_bias = correction_bias
        self.activation = activation
        self.use_presharded_weights = use_presharded_weights
        self.inplace = inplace
        self.no_combine = no_combine
        self.local_num_experts = num_experts

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod()
            )
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            # FIXME: figure out which intermediate_size to use
            intermediate_size=self.intermediate_size_per_partition,
            intermediate_size_per_partition=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

    def _load_per_tensor_weight_scale(
        self,
        shard_id: str,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        expert_id: int,
    ):
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(
        self,
        shard_dim: int,
        expert_data: torch.Tensor,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        # Load grouped weight scales for group quantization
        # or model weights
        if shard_id == "w2":
            self._load_w2(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        elif shard_id in ("w1", "w3"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )

    def _load_per_channel_weight_scale(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
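
    # TP sharding arithmetic used by the two loaders below (a worked example
    # with made-up sizes): with intermediate_size=14336 and tp_size=2, each
    # rank holds shard_size=7168 rows of gate_proj and of up_proj inside w13
    # (so w13's shard dim has 2 * 7168 entries) and 7168 columns of down_proj
    # inside w2; rank r reads the slice [r * 7168, (r + 1) * 7168) from the
    # full checkpoint tensor unless the weights are already pre-sharded.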

    def _load_w13(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2

        if not self.use_presharded_weights:
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )

        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]

        if not self.use_presharded_weights:
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )

        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def _load_single_value(
        self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
    ):
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def _load_g_idx(
        self,
        shard_id: str,
        expert_data: torch.Tensor,
        shard_dim: int,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        if shard_id == "w2":
            self._load_w2(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        else:
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)
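
    # Call-path sketch for the loader below (names are illustrative): a model's
    # load_weights() iterates over make_expert_params_mapping(...) and, for a
    # checkpoint tensor such as "...experts.3.gate_proj.weight", ends up calling
    #     self.weight_loader(param, loaded_weight,
    #                        weight_name="...experts.3.gate_proj.weight",
    #                        shard_id="w1", expert_id=3)
    # which copies the (possibly TP-sharded) tensor into w13_weight[3].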

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        loaded_weight = (
            loaded_weight.t().contiguous()
            if (
                self.quant_method.__class__.__name__
                == "CompressedTensorsWNA16MoEMethod"
            )
            else loaded_weight
        )

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be one of ('w1', 'w2', 'w3') but got {shard_id}."
            )

        WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            # ~0 == -1 and ~1 == -2, which address the same two dims of a 2-D
            # expert weight from the end, i.e. this flips the shard dim.
            shard_dim = ~shard_dim

        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 2.0

            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if (
                param.data[expert_id] != 1
                and (param.data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}"
                )

            self._load_single_value(
                param=param, loaded_weight=loaded_weight, expert_id=expert_id
            )
            return

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(
                shard_dim=0,
                shard_id=shard_id,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
            return

        # Case weight scales and zero_points
        if "scale" in weight_name or "zero" in weight_name:
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                    loaded_weight = loaded_weight * 0.5

                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            elif quant_method in [
                FusedMoeWeightScaleSupported.GROUP.value,
                FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                    loaded_weight = loaded_weight * 2.0

                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
                )
            return

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(
                param=param, loaded_weight=loaded_weight, expert_id=expert_id
            )
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
            return
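
    # Forward-pass sketch (shapes are illustrative): hidden_states is
    # [num_tokens, hidden_size] and router_logits is [num_tokens, num_experts];
    # the quant method routes each token to top_k experts, runs the expert
    # MLPs, and combines the results back into [num_tokens, hidden_size].
    # With reduce_results=True and tp_size > 1, partial results are summed
    # across tensor-parallel ranks.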
    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            correction_bias=self.correction_bias,
            activation=self.activation,
        )

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states
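
    # Example of the mapping produced below (hypothetical checkpoint names
    # gate_proj / down_proj / up_proj, num_experts=2); each entry is
    # (param_name, weight_name, expert_id, shard_id):
    #   ("experts.w13_", "experts.0.gate_proj.", 0, "w1")
    #   ("experts.w2_",  "experts.0.down_proj.", 0, "w2")
    #   ("experts.w13_", "experts.0.up_proj.",   0, "w3")
    # followed by the same three entries for expert_id=1.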
    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:
        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    "experts.w13_"
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else "experts.w2_"
                ),
                f"experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def _load_fp8_scale(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight