# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

import logging
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs

if torch.cuda.is_available():
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore

_is_hip = is_hip()

if _is_hip:
    from aiter import ck_moe

logger = logging.getLogger(__name__)


class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"


class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
    ) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if _is_hip and get_bool_env_var("CK_MOE"):
            layer.w13_weight = torch.nn.Parameter(
                permute_weight(layer.w13_weight.data),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
            layer.w2_weight = torch.nn.Parameter(
                permute_weight(layer.w2_weight.data),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
        return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        return self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            activation=activation,
            inplace=inplace,
            no_combine=no_combine,
        )
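
    # Note: apply() goes through self.forward(), which CustomOp dispatches to the
    # backend-specific implementation below (forward_cuda, forward_cpu, forward_tpu,
    # or forward_native) depending on the active platform.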
    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
        )

        if _is_hip and get_bool_env_var("CK_MOE"):
            assert not no_combine, "unsupported"
            return ck_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                None,
                None,
                None,
                None,
                32,
                None,
                activation,
            )
        else:
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=inplace and not no_combine,
                activation=activation,
                no_combine=no_combine,
            )

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        inplace: bool = True,
    ) -> torch.Tensor:
        return moe_forward_native(
            layer,
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            custom_routing_function,
            correction_bias,
        )

    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("The TPU backend currently does not support MoE.")

    forward_native = forward_cuda


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj / w13)
    and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We copy that
    naming convention here and handle any remapping in the load_weights function
    in each model implementation.

    Args:
        num_experts: Number of experts in the model.
        top_k: Number of experts selected for each token.
        hidden_size: Input hidden state size of the transformer.
        intermediate_size: Intermediate size of the experts.
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all-reduce the output of the layer.
        renormalize: Whether to renormalize the logits in the fused_moe kernel.
        quant_config: Quantization configuration.
        inplace: Suggestion to compute in place (modify the input activation).
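
    Example (illustrative sketch only; the sizes and device below are hypothetical
    and the distributed / tensor-parallel state is assumed to be initialized):

        moe = FusedMoE(
            num_experts=8,
            top_k=2,
            hidden_size=4096,
            intermediate_size=14336,
            renormalize=True,
        )
        hidden_states = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
        router_logits = torch.randn(16, 8, dtype=torch.float32, device="cuda")
        out = moe(hidden_states, router_logits)  # [16, 4096]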
""" def __init__( self, num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = False, renormalize: bool = True, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", use_presharded_weights: bool = False, inplace: bool = True, no_combine: bool = False, ): super().__init__() if params_dtype is None: params_dtype = torch.get_default_dtype() self.tp_size = ( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() ) self.top_k = top_k self.num_experts = num_experts assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize self.use_grouped_topk = use_grouped_topk if self.use_grouped_topk: assert num_expert_group is not None and topk_group is not None self.num_expert_group = num_expert_group self.topk_group = topk_group self.custom_routing_function = custom_routing_function self.correction_bias = correction_bias self.activation = activation self.use_presharded_weights = use_presharded_weights self.inplace = inplace self.no_combine = no_combine self.local_num_experts = num_experts if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( UnquantizedFusedMoEMethod() ) else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None self.quant_method.create_weights( layer=self, num_experts=num_experts, hidden_size=hidden_size, # FIXME: figure out which intermediate_size to use intermediate_size=self.intermediate_size_per_partition, intermediate_size_per_partition=self.intermediate_size_per_partition, params_dtype=params_dtype, weight_loader=self.weight_loader, ) def _load_per_tensor_weight_scale( self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, ): param_data = param.data # for per tensor weight quantization if shard_id in ("w1", "w3"): # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. 
    def _load_per_tensor_weight_scale(
        self,
        shard_id: str,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        expert_id: int,
    ):
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(
        self,
        shard_dim: int,
        expert_data: torch.Tensor,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        # Load grouped weight scales for group quantization
        # or model weights
        if shard_id == "w2":
            self._load_w2(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        elif shard_id in ("w1", "w3"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )

    def _load_per_channel_weight_scale(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )

    def _load_w13(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        if not self.use_presharded_weights:
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )

        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        if not self.use_presharded_weights:
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)
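
    # Illustrative sharding arithmetic (hypothetical sizes): with
    # intermediate_size=14336 and tp_size=2, each rank's per-expert w13 slice has
    # 2 * 14336 / 2 = 14336 rows, so _load_w13 narrows the checkpoint gate_proj /
    # up_proj to shard_size=7168 rows at offset shard_size * tp_rank, while
    # _load_w2 narrows down_proj to 7168 input columns per rank.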
    def _load_single_value(
        self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
    ):
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def _load_g_idx(
        self,
        shard_id: str,
        expert_data: torch.Tensor,
        shard_dim: int,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        if shard_id == "w2":
            self._load_w2(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        else:
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)
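
    # weight_loader dispatches on the checkpoint tensor name: "input_scale" and
    # "weight_shape" are loaded as single per-expert values, "g_idx" via
    # _load_g_idx, weight scales / zero points according to the parameter's
    # quant_method (see FusedMoeWeightScaleSupported), and everything else
    # containing "weight" as a sharded model weight.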
    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        loaded_weight = (
            loaded_weight.t().contiguous()
            if (
                self.quant_method.__class__.__name__
                == "CompressedTensorsWNA16MoEMethod"
            )
            else loaded_weight
        )

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
            )

        WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors;
        # should be whatever dimension intermediate_size is.
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            # ~0 == -1 and ~1 == -2, which flips dims 0 and 1 for the 2-D
            # per-expert slice.
            shard_dim = ~shard_dim

        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 2.0

            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if (
                param.data[expert_id] != 1
                and (param.data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}"
                )

            self._load_single_value(
                param=param, loaded_weight=loaded_weight, expert_id=expert_id
            )
            return

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(
                shard_dim=0,
                shard_id=shard_id,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
            return

        # Case weight scales and zero_points
        if "scale" in weight_name or "zero" in weight_name:
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                    loaded_weight = loaded_weight * 0.5

                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            elif quant_method in [
                FusedMoeWeightScaleSupported.GROUP.value,
                FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                    loaded_weight = loaded_weight * 2.0

                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
                )
            return

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(
                param=param, loaded_weight=loaded_weight, expert_id=expert_id
            )
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
            return
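
    # Forward pass: the quantization method selects experts from the router logits
    # and runs the fused expert kernels; when reduce_results is set and tp_size > 1,
    # the partial outputs are combined with an all-reduce across tensor-parallel ranks.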
    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            correction_bias=self.correction_bias,
            activation=self.activation,
        )

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:
        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    "experts.w13_"
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else "experts.w2_"
                ),
                f"experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def _load_fp8_scale(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight
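

# Illustrative output of FusedMoE.make_expert_params_mapping for a hypothetical
# checkpoint using Mixtral-style names (ckpt_gate_proj_name="w1",
# ckpt_down_proj_name="w2", ckpt_up_proj_name="w3") and num_experts=2:
#   ("experts.w13_", "experts.0.w1.", 0, "w1"),
#   ("experts.w2_",  "experts.0.w2.", 0, "w2"),
#   ("experts.w13_", "experts.0.w3.", 0, "w3"),
#   ("experts.w13_", "experts.1.w1.", 1, "w1"),
#   ("experts.w2_",  "experts.1.w2.", 1, "w2"),
#   ("experts.w13_", "experts.1.w3.", 1, "w3"),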