462 lines
16 KiB
Python
462 lines
16 KiB
Python
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from:
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
|
|
|
|
from typing import Iterable, Optional, Set, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
|
from sglang.srt.layers.activation import GeluAndMul
|
|
from sglang.srt.layers.layernorm import GemmaRMSNorm
|
|
from sglang.srt.layers.linear import (
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.layers.rotary_embedding import get_rope
|
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.model_loader.weight_utils import (
|
|
default_weight_loader,
|
|
maybe_remap_kv_scale_name,
|
|
)
|
|
from sglang.srt.utils import add_prefix, make_layers
|
|
|
|
|
|
# HF's implementation counts the sliding window inclusive of the last token,
# while SGLang assumes an exclusive window, hence the adjustment by one.
def get_attention_sliding_window_size(config):
    """Return ``config.sliding_window`` reduced by one."""
    hf_window = config.sliding_window
    return hf_window - 1
|
|
|
|
|
|
class Gemma2MLP(nn.Module):
    """Feed-forward block: merged gate/up projection, ``GeluAndMul``, down projection.

    Both ``hidden_act`` and ``hidden_activation`` must be
    ``"gelu_pytorch_tanh"``; anything else raises ``ValueError``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # One merged projection producing two `intermediate_size` outputs.
        gate_up_prefix = add_prefix("gate_up_proj", prefix)
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=gate_up_prefix,
        )

        down_prefix = add_prefix("down_proj", prefix)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=down_prefix,
        )

        # The two activation settings have to agree and name the tanh GELU.
        required_act = "gelu_pytorch_tanh"
        if hidden_act != hidden_activation or hidden_activation != required_act:
            raise ValueError(
                "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, the activation, then the down projection."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
|
|
|
class Gemma2Attention(nn.Module):
    """Self-attention for a single Gemma2 layer.

    Builds a fused QKV projection, an output projection, rotary embeddings and
    a ``RadixAttention`` backend. Layers with an even ``layer_id`` receive a
    sliding window when the config has a ``sliding_window`` attribute; every
    other layer passes ``sliding_window_size=None``.
    """

    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        rope_theta: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        self.hidden_size = hidden_size

        world_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: the KV heads are replicated across
            # multiple tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.rope_theta = rope_theta

        use_bias = config.attention_bias
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=use_bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )

        # Only even-numbered layers get a window, and only when the config
        # actually carries a `sliding_window` attribute.
        window_size = None
        if layer_id % 2 == 0 and hasattr(config, "sliding_window"):
            window_size = get_attention_sliding_window_size(config)

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            logit_cap=self.config.attn_logit_softcapping,
            sliding_window_size=window_size,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend and project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dimension into Q, K and V.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, forward_batch)
        projected, _ = self.o_proj(context)
        return projected
|
|
|
|
|
|
class Gemma2DecoderLayer(nn.Module):
    """One Gemma2 decoder layer: attention and MLP, each wrapped by two norms.

    ``forward`` returns ``(hidden_states, residual)``; the residual is threaded
    through the two-argument calls to the input and pre-feedforward norms.
    """

    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        attn_prefix = add_prefix("self_attn", prefix)
        self.self_attn = Gemma2Attention(
            layer_id=layer_id,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            quant_config=quant_config,
            prefix=attn_prefix,
        )

        mlp_prefix = add_prefix("mlp", prefix)
        self.mlp = Gemma2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=mlp_prefix,
        )

        # All four norms share the same width and epsilon.
        def _new_norm():
            return GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.input_layernorm = _new_norm()
        self.post_attention_layernorm = _new_norm()
        self.pre_feedforward_layernorm = _new_norm()
        self.post_feedforward_layernorm = _new_norm()

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention then the MLP and return ``(hidden_states, residual)``."""
        if residual is not None:
            # Two-argument norm call also yields the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # No residual yet: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        attn_out = self.post_attention_layernorm(attn_out)

        mlp_in, residual = self.pre_feedforward_layernorm(attn_out, residual)
        mlp_out = self.post_feedforward_layernorm(self.mlp(mlp_in))
        return mlp_out, residual
|
|
|
|
|
|
class Gemma2Model(nn.Module):
    """Gemma2 backbone: token embedding, a stack of decoder layers, final norm.

    Args:
        config: Model configuration; reads ``vocab_size``, ``hidden_size``,
            ``num_hidden_layers`` and ``rms_norm_eps``.
        quant_config: Optional quantization configuration forwarded to every
            decoder layer.
        prefix: Name prefix of this module; decoder layers are created under
            ``add_prefix("layers", prefix)``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = make_layers(
            config.num_hidden_layers,
            # Forward the per-layer prefix handed to the factory; without it
            # every decoder layer (and its sub-modules) would be built with
            # the default empty prefix.
            lambda idx, prefix: Gemma2DecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed (or take ``input_embeds``), scale, run all layers and normalize.

        Args:
            input_ids: Token ids; only used when ``input_embeds`` is ``None``.
            positions: Passed unchanged to every decoder layer.
            forward_batch: Passed unchanged to every decoder layer.
            input_embeds: Optional precomputed embeddings used instead of the
                embedding lookup. The tensor is not modified.

        Returns:
            The hidden states after the final norm.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # Build the sqrt(hidden_size) scale in the dtype of the activations
        # rather than a fixed float16, in line with the note in __init__.
        normalizer = torch.tensor(
            self.config.hidden_size**0.5, dtype=hidden_states.dtype
        )
        # Out-of-place multiply so a caller-supplied `input_embeds` is left
        # untouched.
        hidden_states = hidden_states * normalizer

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
|
|
|
|
|
class Gemma2ForCausalLM(nn.Module):
    """Gemma2 causal language model: ``Gemma2Model`` backbone plus logits processor.

    The embedding module of the backbone is handed to the logits processor in
    ``forward``; ``load_weights`` skips any ``lm_head.weight`` entry.
    """

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name: (weight_name, index)
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return the logits processor's output."""
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def get_hidden_dim(self, module_name):
        """Return ``(input_dim, output_dim)`` for the named projection.

        Raises:
            NotImplementedError: If ``module_name`` is not one of the handled
                projection names.
        """
        cfg = self.config
        if module_name in ("q_proj", "qkv_proj"):
            return cfg.hidden_size, cfg.head_dim * cfg.num_attention_heads
        if module_name == "o_proj":
            return cfg.head_dim * cfg.num_attention_heads, cfg.hidden_size
        if module_name == "kv_proj":
            return cfg.hidden_size, cfg.head_dim * cfg.num_key_value_heads
        if module_name == "gate_up_proj":
            return cfg.hidden_size, cfg.intermediate_size
        if module_name == "down_proj":
            return cfg.intermediate_size, cfg.hidden_size
        raise NotImplementedError(
            f"get_hidden_dim does not support module {module_name!r}"
        )

    def get_module_name(self, name):
        """Map a per-shard projection name to its fused module name.

        Names without a fused counterpart are returned unchanged.
        """
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return params_mapping.get(name, name)

    def get_attention_sliding_window_size(self):
        """Return the sliding-window size derived from this model's config."""
        return get_attention_sliding_window_size(self.config)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Per-shard projection weights (``q_proj``/``k_proj``/``v_proj`` and
        ``gate_proj``/``up_proj``) are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters through each parameter's ``weight_loader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # `continue` resumes the scan over the remaining mapping
                # entries with the already-rewritten name.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        # Hand the collected names back instead of discarding them.
        return loaded_params
|
|
|
|
|
|
# Module-level alias of the causal-LM class defined in this file.
# NOTE(review): presumably the name the model loader looks up to discover
# this architecture; confirm against the registry code.
EntryClass = Gemma2ForCausalLM
|