# Copyright 2024 The LGcns AI Engineering Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from llama2.py
"""Inference-only Exaone model compatible with HuggingFace weights."""

from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix


class ExaoneGatedMLP(nn.Module):
    """Gated feed-forward block of an Exaone decoder layer.

    The two input projections (``c_fc_0`` / ``c_fc_1`` in the HF checkpoint,
    see ``ExaoneForCausalLM.load_weights``) are fused into one
    column-parallel matmul, combined by ``SiluAndMul`` and projected back to
    ``hidden_size`` by the row-parallel ``c_proj``.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Output is 2 * intermediate_size wide: both gate halves in one GEMM.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("c_proj", prefix),
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-multiply, then ``c_proj``."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class ExaoneAttention(nn.Module):
    """Multi-head / grouped-query self-attention with rotary embeddings.

    Query heads are sharded across tensor-parallel ranks; KV heads are either
    sharded (when there are at least ``tp_size`` of them) or replicated.
    """

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 500000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # The config may set head_dim explicitly; otherwise derive it from
        # hidden_size split evenly across the attention heads.
        self.head_dim = getattr(
            config, "head_dim", self.hidden_size // self.total_num_heads
        )
        # Only the first rotary_dim channels of each head receive RoPE when
        # the config declares a partial_rotary_factor < 1.
        self.rotary_dim = int(
            self.head_dim * getattr(config, "partial_rotary_factor", 1)
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.out_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("out_proj", prefix),
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE, attend, and project back out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Fused projection output is laid out as [q | k | v] on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.out_proj(attn_output)
        return output


class ExaoneDecoderLayer(nn.Module):
    """One pre-norm transformer block: ``ln_1`` -> attention -> ``ln_2`` -> MLP."""

    def __init__(
        self,
        config,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 500000)
        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_position_embeddings = getattr(
            config, "original_max_position_embeddings", None
        )
        if rope_scaling is not None and original_max_position_embeddings:
            # Work on a copy: this constructor runs once per layer and must
            # not mutate the rope_scaling dict owned by the shared HF config.
            rope_scaling = dict(rope_scaling)
            rope_scaling["original_max_position_embeddings"] = (
                original_max_position_embeddings
            )
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
        self.self_attn = ExaoneAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = ExaoneGatedMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.activation_function,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        rms_norm_eps = config.layer_norm_epsilon
        self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
        self.ln_2 = RMSNorm(config.hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(hidden_states, residual)``.

        The residual stream is threaded through the norms: when ``residual``
        is given, ``RMSNorm`` is called with it and returns a
        ``(normalized, residual)`` pair, so the residual add is not done here
        explicitly. The first layer receives ``residual=None``.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class ExaoneModel(nn.Module):
    """Exaone decoder stack: token embedding ``wte``, layers ``h``, final ``ln_f``."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList(
            [
                ExaoneDecoderLayer(
                    config,
                    i,
                    quant_config=quant_config,
                    prefix=add_prefix(f"h.{i}", prefix),
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        rms_norm_eps = config.layer_norm_epsilon
        self.ln_f = RMSNorm(config.hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return final hidden states.

        ``input_embeds``, when given, is used instead of looking up
        ``input_ids`` in ``wte``.
        """
        if input_embeds is None:
            hidden_states = self.wte(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        for layer in self.h:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        # The last layer's pending residual is folded in by the final norm.
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class ExaoneForCausalLM(nn.Module):
    """Exaone decoder with a language-model head and logits processor.

    When ``config.tie_word_embeddings`` is set, the LM head reuses the input
    embedding module instead of allocating its own weight.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = ExaoneModel(
            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
        )
        if getattr(config, "tie_word_embeddings", False):
            # Tied checkpoints typically carry no separate lm_head weight;
            # without sharing, the head would never be loaded.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorOutput:
        """Run the decoder and turn its hidden states into logits."""
        hidden_states = self.transformer(
            input_ids, positions, forward_batch, input_embeds
        )
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        """Load HF-format Exaone weights into this model's parameters.

        Checkpoint names are rewritten to match this module tree:
        ``attn.attention`` -> ``self_attn``; ``q_proj``/``k_proj``/``v_proj``
        are loaded as shards of ``qkv_proj``; ``c_fc_0``/``c_fc_1`` as shards
        0/1 of ``gate_up_proj``.

        Raises:
            KeyError: if a (renamed) checkpoint tensor has no matching
                parameter and is not in one of the skipped categories.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "c_fc_0", 0),
            ("gate_up_proj", "c_fc_1", 1),
        ]
        tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if tie_word_embeddings and "lm_head.weight" in name:
                # lm_head shares transformer.wte's parameter, which is only
                # registered under the wte name in params_dict.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            name = name.replace("attn.attention", "self_attn")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping applied: load the tensor as-is.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = ExaoneForCausalLM