sglang0.4.5.post1/python/sglang/srt/models/chatglm.py

427 lines
14 KiB
Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from torch.nn import LayerNorm
from sglang.srt.configs import ChatGLMConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
LoraConfig = None
class GLMAttention(nn.Module):
def __init__(
self,
config,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (
config.multi_query_group_num
if config.multi_query_attention
else config.num_attention_heads
)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config,
prefix=add_prefix("query_key_value", prefix),
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=add_prefix("dense", prefix),
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions,
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
q,
k,
v,
forward_batch,
)
attn_output, _ = self.dense(context_layer)
return attn_output
class GLMMLP(nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.add_bias = config.add_bias_linear
# Project to 4h.
self.dense_h_to_4h = MergedColumnParallelLinear(
config.hidden_size,
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=add_prefix("dense_h_to_4h", prefix),
)
self.activation_func = SiluAndMul()
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=add_prefix("dense_4h_to_h", prefix),
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, _ = self.dense_4h_to_h(intermediate_parallel)
return output
class GLMBlock(nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
config,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm
)
self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon
)
# Self attention.
self.self_attention = GLMAttention(
config, layer_id, quant_config, prefix=add_prefix("self_attention", prefix)
)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon
)
# MLP
self.mlp = GLMMLP(config, quant_config, prefix=add_prefix("mlp", prefix))
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.self_attention(
hidden_states=layernorm_output,
position_ids=position_ids,
forward_batch=forward_batch,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
layernorm_input = residual + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = self.mlp(layernorm_output) + residual
return output
class GLMTransformer(nn.Module):
"""Transformer class."""
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
# Number of layers.
self.num_layers = config.num_layers
# Transformer layers.
self.layers = nn.ModuleList(
[
GLMBlock(
config,
i,
quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(self.num_layers)
]
)
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
for i in range(self.num_layers):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
forward_batch=forward_batch,
)
# Final layer norm.
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class ChatGLMM(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.embedding = VocabParallelEmbedding(
config.padded_vocab_size,
config.hidden_size,
prefix=add_prefix("embedding", prefix),
)
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(
config, quant_config, add_prefix("encoder", prefix)
)
self.output_layer = ParallelLMHead(
config.padded_vocab_size,
config.hidden_size,
prefix=add_prefix("output_layer", prefix),
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
# Run encoder.
hidden_states = self.encoder(
hidden_states=inputs_embeds,
position_ids=position_ids,
forward_batch=forward_batch,
)
return hidden_states
class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"],
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
self.transformer = ChatGLMM(
config, quant_config, prefix=add_prefix("transformer", prefix)
)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
name = name.replace(".word_embeddings", "")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
class ChatGLMModel(ChatGLMForCausalLM):
pass
EntryClass = [ChatGLMModel]