# Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" import functools import json import logging import math import os import warnings from typing import Iterable, Optional, Tuple import numpy as np import torch from torch import nn from transformers import PretrainedConfig from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.router import fused_moe_router_shim from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import dump_to_file logger = logging.getLogger(__name__) debug_tensor_dump_output_folder = None debug_tensor_dump_inject = False class Grok1MoE(nn.Module): """A tensor-parallel MoE implementation for Grok1 that shards each expert across all ranks. Each expert's weights are sharded across all ranks and a fused MoE kernel is used for the forward pass, and finally we reduce the outputs across ranks. """ def __init__( self, config: PretrainedConfig, num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, reduce_results=True, use_presharded_weights: bool = False, inplace: bool = True, no_combine: bool = False, ): super().__init__() self.hidden_size = hidden_size # Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961) self.gate = ReplicatedLinear( hidden_size, num_experts, bias=False, params_dtype=torch.float32, quant_config=None, ) self.router_logit_softcapping = getattr( config, "router_logit_softcapping", 30.0 ) custom_routing_function = functools.partial( fused_moe_router_shim, self.router_logit_softcapping ) kwargs = {} if global_server_args_dict["enable_ep_moe"]: MoEImpl = EPMoE else: MoEImpl = FusedMoE kwargs["reduce_results"] = reduce_results kwargs["use_presharded_weights"] = use_presharded_weights kwargs["inplace"] = inplace kwargs["no_combine"] = no_combine self.experts = MoEImpl( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, renormalize=False, quant_config=quant_config, tp_size=tp_size, custom_routing_function=custom_routing_function, activation="gelu", **kwargs, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # need to assert self.gate.quant_method is unquantized return self.experts(hidden_states, self.gate.weight) class Grok1Attention(nn.Module): def __init__( self, config: PretrainedConfig, hidden_size: int, num_heads: int, num_kv_heads: int, layer_id: int = 0, max_position: int = 4096 * 32, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, load_presharded_attn: bool = False, ) -> None: super().__init__() self.config = config self.layer_id = layer_id self.hidden_size = hidden_size attn_tp_rank = get_tensor_model_parallel_rank() attn_tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % attn_tp_size == 0 self.num_heads = self.total_num_heads // attn_tp_size self.total_num_kv_heads = num_kv_heads if self.total_num_kv_heads >= attn_tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. assert self.total_num_kv_heads % attn_tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert attn_tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) self.head_dim = getattr(config, "head_dim", 128) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.load_presharded_attn = load_presharded_attn self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=False, quant_config=quant_config, tp_rank=attn_tp_rank, tp_size=attn_tp_size, load_presharded_attn=self.load_presharded_attn, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, reduce_results=reduce_results, tp_rank=attn_tp_rank, tp_size=attn_tp_size, use_presharded_weights=self.load_presharded_attn, ) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position, base=int(self.rope_theta), is_neox_style=True, ) logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0) self.attn = RadixAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, logit_cap=logit_cap, ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: if hidden_states.shape[0] == 0: assert ( not self.o_proj.reduce_results ), "short-circuiting allreduce will lead to hangs" return hidden_states if debug_tensor_dump_output_folder: dump_to_file( debug_tensor_dump_output_folder, f"attn_input_{self.layer_id}", hidden_states, ) if debug_tensor_dump_inject: name = os.path.join( debug_tensor_dump_output_folder, f"jax_dump_attn_input_{self.layer_id}.npy", ) logger.info(f"Load {name} from jax.") x = np.load(name) hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to( hidden_states ) qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) if debug_tensor_dump_output_folder: num_tokens = q.shape[0] num_heads_q = self.num_heads head_dim = self.head_dim num_heads_kv = k.numel() // (num_tokens * head_dim) dump_to_file( debug_tensor_dump_output_folder, f"q_{self.layer_id}", tensor_model_parallel_all_gather( q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1 ).contiguous(), ) dump_to_file( debug_tensor_dump_output_folder, f"k_{self.layer_id}", tensor_model_parallel_all_gather( k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1 ).contiguous(), ) dump_to_file( debug_tensor_dump_output_folder, f"v_{self.layer_id}", tensor_model_parallel_all_gather( v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1 ).contiguous(), ) attn_output = self.attn(q, k, v, forward_batch) if debug_tensor_dump_output_folder: dump_to_file( debug_tensor_dump_output_folder, f"attn_output_{self.layer_id}", tensor_model_parallel_all_gather( attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1, ).contiguous(), ) output, _ = self.o_proj(attn_output) return output class Grok1DecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, load_presharded_moe: bool = False, load_presharded_attn: bool = False, load_presharded_mlp: bool = False, ) -> None: super().__init__() self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.layer_id = layer_id rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = Grok1Attention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, layer_id=layer_id, rope_theta=rope_theta, quant_config=quant_config, reduce_results=False, load_presharded_attn=load_presharded_attn, ) self.block_sparse_moe = Grok1MoE( config=config, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=getattr( config, "moe_intermediate_size", getattr(config, "intermediate_size", None), ), quant_config=quant_config, reduce_results=True, use_presharded_weights=load_presharded_moe, inplace=True, no_combine=False, # just a suggestion to not combine topk ) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ffn = self.block_sparse_moe def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, residual: Optional[torch.Tensor] = None, deferred_norm: Optional[RMSNorm] = None, ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]: # Self Attention if deferred_norm is not None: assert residual is not None # here hidden_states is output of ffn, residual is residual from after previous attn layer hidden_states, residual = fused_dual_residual_rmsnorm( hidden_states, residual, deferred_norm.weight, self.pre_attn_norm.weight, deferred_norm.variance_epsilon, ) else: # here hidden_states is the residual hidden_states, residual = ( fused_rmsnorm( hidden_states, self.pre_attn_norm.weight, self.pre_attn_norm.variance_epsilon, ), hidden_states, ) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, ) if get_tensor_model_parallel_world_size() > 1: hidden_states = tensor_model_parallel_all_reduce(hidden_states) hidden_states, residual = fused_dual_residual_rmsnorm( hidden_states, residual, self.post_attn_norm.weight, self.pre_moe_norm.weight, self.post_attn_norm.variance_epsilon, ) # Fully Connected hidden_states = self.ffn(hidden_states) return hidden_states, residual, self.post_moe_norm # defer layernorm class Grok1Model(nn.Module): def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, load_presharded_moe: bool = False, load_presharded_embedding: bool = False, load_presharded_attn: bool = False, load_presharded_mlp: bool = False, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, use_presharded_weights=load_presharded_embedding, ) self.layers = nn.ModuleList( [ Grok1DecoderLayer( config, i, quant_config=quant_config, load_presharded_moe=load_presharded_moe, load_presharded_attn=load_presharded_attn, load_presharded_mlp=load_presharded_mlp, ) for i in range(config.num_hidden_layers) ] ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) hidden_states.mul_(self.config.embedding_multiplier_scale) else: hidden_states = input_embeds residual, deferred_norm = None, None for i in range(len(self.layers)): hidden_states, residual, deferred_norm = self.layers[i]( positions, hidden_states, forward_batch, residual, deferred_norm ) if debug_tensor_dump_output_folder: hidden_states = ( fused_rmsnorm( hidden_states, deferred_norm.weight, deferred_norm.variance_epsilon, ) + residual ) dump_to_file( debug_tensor_dump_output_folder, "last_hidden_before_norm", hidden_states, ) hidden_states = fused_rmsnorm( hidden_states, self.norm.weight, self.norm.variance_epsilon, ) dump_to_file( debug_tensor_dump_output_folder, "last_hidden_after_norm", hidden_states, ) else: hidden_states, _ = fused_dual_residual_rmsnorm( hidden_states, residual, deferred_norm.weight, self.norm.weight, deferred_norm.variance_epsilon, ) return hidden_states class Grok1ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config # Get presharded weights. self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False) self.load_presharded_moe = ( self.config.num_local_experts > 0 and get_tensor_model_parallel_world_size() > 1 ) self.load_presharded_attn = getattr(config, "load_presharded_attn", False) self.load_presharded_embedding = getattr( config, "load_presharded_embedding", False ) self.is_weights_presharded = ( self.load_presharded_mlp or self.load_presharded_moe or self.load_presharded_attn or self.load_presharded_embedding ) if self.is_weights_presharded: setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) default_replicate_lm_head = False self.replicate_lm_head = getattr( config, "replicate_lm_head", default_replicate_lm_head ) self.model = Grok1Model( config, quant_config=quant_config, load_presharded_moe=self.load_presharded_moe, load_presharded_embedding=self.load_presharded_embedding, load_presharded_attn=self.load_presharded_attn, load_presharded_mlp=self.load_presharded_mlp, ) lm_head_params_dtype = None if self.replicate_lm_head: self.lm_head = ReplicatedLinear( config.hidden_size, config.vocab_size, bias=False, params_dtype=lm_head_params_dtype, ) self.logits_processor = LogitsProcessor(config, skip_all_gather=True) else: self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, use_presharded_weights=self.load_presharded_embedding, params_dtype=lm_head_params_dtype, ) self.logits_processor = LogitsProcessor(config) # Dump tensors for debugging global debug_tensor_dump_output_folder, debug_tensor_dump_inject debug_tensor_dump_output_folder = global_server_args_dict[ "debug_tensor_dump_output_folder" ] debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"] warnings.filterwarnings("ignore", category=FutureWarning) if get_tensor_model_parallel_rank() == 0: logger.info( f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, " f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B" ) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if debug_tensor_dump_output_folder: dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], num_experts: Optional[int] = None, ignore_parent_name: bool = False, ) -> dict[str, torch.Tensor]: if num_experts is None: num_experts = self.config.num_local_experts stacked_params_mapping = [] stacked_params_mapping += [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] stacked_params_mapping += [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", num_experts=num_experts, ) params_dict = dict(self.named_parameters()) all_names = set(params_dict.keys()) hit_names = set() def load_weight_wrapper( name: str, loaded_weight: torch.Tensor, *args, **kwargs ): if ignore_parent_name: name = name.split(".")[-1] if name not in params_dict: return # Fuse constant multipliers into the weights if "lm_head" in name: loaded_weight = ( loaded_weight.to(torch.float32) * self.config.output_multiplier_scale ) param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight, *args, **kwargs) hit_names.add(name) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue load_weight_wrapper(name, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) load_weight_wrapper( name, loaded_weight, name, shard_id=shard_id, expert_id=expert_id, ) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if name is None: continue load_weight_wrapper(name=name, loaded_weight=loaded_weight) if len(hit_names) > 5: missing = all_names - hit_names missing_exclude_scales = {x for x in missing if "scale" not in x} logger.info( f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}", ) if len(missing_exclude_scales) > 0: raise ValueError( f"load_weights failed because some weights are missing: {missing_exclude_scales=}." ) elif len(hit_names) == 0: raise ValueError("load_weights failed because it did not hit any names.") return hit_names def get_num_params_analytical(self): cfg = self.config moe_intermediate_size = getattr( cfg, "moe_intermediate_size", getattr(cfg, "intermediate_size", None), ) num_experts = cfg.num_local_experts wq = ( cfg.num_hidden_layers * cfg.hidden_size * cfg.num_attention_heads * cfg.head_dim ) wkv = ( cfg.num_hidden_layers * cfg.hidden_size * cfg.num_key_value_heads * cfg.head_dim * 2 ) out = ( cfg.num_hidden_layers * cfg.hidden_size * cfg.num_attention_heads * cfg.head_dim ) ffn1 = ( cfg.num_hidden_layers * num_experts * cfg.hidden_size * moe_intermediate_size * 2 ) ffn2 = ( cfg.num_hidden_layers * num_experts * cfg.hidden_size * moe_intermediate_size ) embed = cfg.hidden_size * cfg.vocab_size * 2 return wq + wkv + out + ffn1 + ffn2 + embed def get_num_params_torch(self): return ( sum(p.numel() for p in self.parameters()) * get_tensor_model_parallel_world_size() ) old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") def _prepare_presharded_weights( self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool ) -> Tuple[str, list[str], bool]: import glob import os if get_tensor_model_parallel_world_size() == 1: return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt) if not os.path.isdir(model_name_or_path): from sglang.srt.model_loader.weight_utils import download_weights_from_hf allow_patterns = ["*.safetensors", "*.bin"] hf_folder = download_weights_from_hf( model_name_or_path, self.load_config.download_dir, allow_patterns, revision, ignore_patterns=self.load_config.ignore_patterns, ) else: hf_folder = model_name_or_path tp_rank = get_tensor_model_parallel_rank() # The old format allow_patterns = [f"*-{tp_rank:03d}.bin"] # The new format allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"] hf_weights_files = [] for pattern in allow_patterns: hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) if hf_weights_files[0].endswith("safetensors"): use_safetensors = True else: use_safetensors = False return hf_folder, hf_weights_files, use_safetensors class Grok1ModelForCausalLM(Grok1ForCausalLM): """An alias for backward-compatbility.""" pass EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]