# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from torch.nn import LayerNorm

from sglang.srt.configs import ChatGLMConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix

LoraConfig = None


class GLMAttention(nn.Module):
    """Self-attention sub-layer of a ChatGLM transformer block.

    Runs a fused QKV projection, applies rotary position embeddings to q/k,
    calls ``RadixAttention`` and projects the result back to ``hidden_size``
    with a row-parallel ``dense`` layer. Query heads are split evenly across
    tensor-parallel ranks; KV heads are either split or replicated depending
    on how their count compares to the TP world size.
    """

    def __init__(
        self,
        config,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        # With multi-query (grouped) attention the checkpoint stores only
        # `multi_query_group_num` KV heads; otherwise there is one KV head per
        # query head.
        self.total_num_kv_heads = (
            config.multi_query_group_num
            if config.multi_query_attention
            else config.num_attention_heads
        )
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank, even when they are replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
            prefix=add_prefix("query_key_value", prefix),
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=add_prefix("dense", prefix),
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # ChatGLM uses a partial, non-NeoX rotary embedding: rotary_dim is
        # half of head_dim and the base is scaled by `rope_ratio`.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return the projected output.

        Args:
            hidden_states: Normalized layer input.
            position_ids: Positions passed to the rotary embedding.
            forward_batch: Batch metadata forwarded to ``RadixAttention``.

        Returns:
            The output of the ``dense`` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            forward_batch,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated feed-forward sub-layer of a ChatGLM transformer block.

    ``dense_h_to_4h`` projects the hidden state to two fused halves of width
    ``config.ffn_hidden_size``. ``SiluAndMul`` combines them into one
    ``ffn_hidden_size``-wide activation, and ``dense_4h_to_h`` projects that
    back to ``config.hidden_size``. (The ``4h`` names come from the reference
    implementation; the actual width is ``ffn_hidden_size``.)
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Project to ffn_hidden_size, as two fused halves for the gated activation.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=add_prefix("dense_h_to_4h", prefix),
        )

        self.activation_func = SiluAndMul()

        # Project back to hidden_size.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=add_prefix("dense_4h_to_h", prefix),
        )

    def forward(self, hidden_states):
        """Apply the gated MLP to ``hidden_states`` and return the projection."""
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single pre-norm ChatGLM transformer layer.

    Takes hidden states of shape ``[num_tokens, hidden_size]`` and returns a
    tensor of the same shape: layernorm -> attention -> residual add ->
    layernorm -> MLP -> residual add.
    """

    def __init__(
        self,
        config,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )

        # Stored for parity with the reference model; not read in forward().
        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon
        )

        # Self attention.
        self.self_attention = GLMAttention(
            config, layer_id, quant_config, prefix=add_prefix("self_attention", prefix)
        )
        # Stored for parity with the reference model; not read in forward().
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon
        )

        # MLP
        self.mlp = GLMMLP(config, quant_config, prefix=add_prefix("mlp", prefix))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run one transformer layer.

        When ``apply_residual_connection_post_layernorm`` is set, each residual
        branch adds the *normalized* tensor instead of the un-normalized
        input.
        """
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            forward_batch=forward_batch,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Stack of ``config.num_layers`` GLMBlocks with an optional final norm."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [
                GLMBlock(
                    config,
                    i,
                    quant_config,
                    prefix=add_prefix(f"layers.{i}", prefix),
                )
                for i in range(self.num_layers)
            ]
        )

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run every layer in order, then the final norm if configured."""
        # self.layers holds exactly self.num_layers blocks, so iterate it directly.
        for layer in self.layers:
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                forward_batch=forward_batch,
            )

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMM(nn.Module):
    """ChatGLM backbone: token embedding, transformer stack and LM head.

    ``output_layer`` is built here (matching the checkpoint layout) but is
    applied by the causal-LM wrapper, not by ``forward``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.embedding = VocabParallelEmbedding(
            config.padded_vocab_size,
            config.hidden_size,
            prefix=add_prefix("embedding", prefix),
        )

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(
            config, quant_config, add_prefix("encoder", prefix)
        )

        self.output_layer = ParallelLMHead(
            config.padded_vocab_size,
            config.hidden_size,
            prefix=add_prefix("output_layer", prefix),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run the encoder and return final hidden states."""
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            forward_batch=forward_batch,
        )
        return hidden_states


class ChatGLMForCausalLM(nn.Module):
    """Causal-LM wrapper around ``ChatGLMM`` that produces logits."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.quant_config = quant_config
        # NOTE(review): GLMAttention sizes its rotary table from `seq_length`,
        # while this reads `max_sequence_length`; confirm which key the
        # config actually provides.
        self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
        self.transformer = ChatGLMM(
            config, quant_config, prefix=add_prefix("transformer", prefix)
        )
        # The LM head is shared with (not copied from) the backbone's output layer.
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the backbone and return the logits processor's output."""
        hidden_states = self.transformer(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        The method:
        - skips ``rotary_pos_emb.inv_freq`` entries;
        - strips the ``.word_embeddings`` segment from embedding names;
        - skips ``.bias`` tensors with no matching parameter (extra biases in
          GPTQ checkpoints).

        Each parameter's own ``weight_loader`` is used when present, falling
        back to ``default_weight_loader`` otherwise.

        Raises:
            KeyError: if any other checkpoint name has no matching parameter.
        """
        # remove_duplicate=False keeps every alias of a shared parameter
        # (e.g. lm_head / transformer.output_layer) addressable by name.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


class ChatGLMModel(ChatGLMForCausalLM):
    """Alias of ``ChatGLMForCausalLM`` registered under the ChatGLMModel name."""

    pass


EntryClass = [ChatGLMModel]