695 lines
24 KiB
Python
695 lines
24 KiB
Python
# Copyright 2025 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
import copy
|
|
from typing import Iterable, Optional, Set, Tuple
|
|
|
|
import einops
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
from transformers import (
|
|
ROPE_INIT_FUNCTIONS,
|
|
AutoModel,
|
|
Gemma3TextConfig,
|
|
PretrainedConfig,
|
|
PreTrainedModel,
|
|
)
|
|
|
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
|
from sglang.srt.layers.activation import GeluAndMul
|
|
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
|
from sglang.srt.layers.linear import (
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
|
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.model_loader.weight_utils import (
|
|
default_weight_loader,
|
|
maybe_remap_kv_scale_name,
|
|
)
|
|
from sglang.srt.utils import add_prefix, make_layers
|
|
|
|
|
|
def get_attention_sliding_window_size(config):
    """Convert the HF sliding-window size into the form SGLang expects.

    HF's ``config.sliding_window`` is inclusive of the last (current) token,
    while SGLang treats the window size as exclusive, hence the ``- 1``.
    """
    inclusive_window = config.sliding_window
    exclusive_window = inclusive_window - 1
    return exclusive_window
|
|
|
|
|
|
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
def extract_layer_index(prefix: str) -> int:
    """Extract the layer index from a dotted module prefix.

    The index is the integer component immediately following the first
    ``layers`` component whose successor parses as an int, e.g.
    ``"model.layers.3.self_attn"`` -> ``3``.

    Args:
        prefix: Dotted module path such as ``"model.layers.3.mlp"``.

    Returns:
        The layer index, or ``-1`` if the prefix has no ``layers.<int>`` pair.
    """
    parts = prefix.split(".")
    # split(".") leaves no dots inside a component, so "layers" and its index
    # are two adjacent components; walk each component with its successor.
    for name, index_str in zip(parts, parts[1:]):
        if name != "layers":
            continue
        try:
            return int(index_str)
        except ValueError:
            continue
    return -1
|
|
|
|
|
|
class Gemma3MLP(nn.Module):
    """Gemma3 feed-forward block: fused gate/up projection, GeluAndMul, down projection.

    Only the ``gelu_pytorch_tanh`` activation is accepted; any other value
    raises ``ValueError`` (after the projection layers have been built).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are merged into one column-parallel matmul
        # producing two intermediate_size-wide outputs.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        supported_activation = hidden_activation == "gelu_pytorch_tanh"
        if not supported_activation:
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, the fused activation, then down projection."""
        fused_gate_up, _bias = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _bias = self.down_proj(activated)
        return projected
|
|
|
|
|
|
class Gemma3Attention(nn.Module):
    """Gemma3 self-attention with q/k RMSNorm and per-layer sliding window.

    A layer is global when ``(layer_id + 1)`` is a multiple of
    ``config.sliding_window_pattern``; every other layer uses local
    sliding-window attention. RoPE ``(cos, sin)`` tables are computed by the
    caller and passed to ``forward``.
    """

    def __init__(
        self,
        layer_id: int,
        config: Gemma3TextConfig,
        max_position_embeddings: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        # NOTE: ``max_position_embeddings`` is accepted for interface
        # compatibility but is not used by this module.
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0

        hidden_size = config.hidden_size
        self.head_dim = getattr(
            config, "head_dim", hidden_size // config.num_attention_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Gemma scales attention logits by query_pre_attn_scalar**-0.5
        # rather than head_dim**-0.5.
        self.scaling = config.query_pre_attn_scalar**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        # Determine if layer uses sliding window based on pattern.
        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)

        # rope_theta / rope_scaling are recorded for reference only; the RoPE
        # tables actually used arrive through ``forward``'s position_embeddings.
        if self.is_sliding:
            # Local attention. Override the values in config.json.
            self.rope_theta = config.rope_local_base_freq
            self.rope_scaling = {"rope_type": "default"}
            # NOTE(review): vLLM reads config.interleaved_sliding_window here;
            # this port uses config.sliding_window instead.
            self.sliding_window = get_attention_sliding_window_size(config)
        else:
            # Global attention. Use the values in config.json.
            self.rope_theta = config.rope_theta
            self.rope_scaling = config.rope_scaling
            self.sliding_window = None

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
            # Module must also define `get_attention_sliding_window_size` to
            # correctly initialize attention backend in `ForwardBatch`.
            sliding_window_size=self.sliding_window,
            prefix=add_prefix("attn", prefix),
        )

        # Gemma3 adds normalization for q and k. Use the resolved head_dim so
        # configs that omit `head_dim` (fallback above) still work.
        self.q_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Reference attention over packed sequences using explicit masks.

        Args:
            q, k, v: Packed projections; viewed as (tokens, heads, head_dim).
            out: Pre-allocated output; each sequence's rows are written in place.
            **kwargs: Must contain ``seq_lens`` and, depending on
                ``self.is_sliding``, ``local_attn_masks`` or
                ``global_attn_masks`` (one mask per sequence).

        Returns:
            ``out``, filled in place.
        """
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            # (1, seq, heads, head_dim) -> (1, heads, seq, head_dim)
            query = q[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
            key = k[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
            value = v[start_idx:end_idx].unsqueeze(0).transpose(1, 2)

            # `scale` must be passed by keyword: the 5th positional parameter
            # of scaled_dot_product_attention is `dropout_p`.
            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attn_mask,
                scale=self.scaling,
            )
            # (1, heads, seq, head_dim) -> (seq, heads * head_dim)
            output = output.transpose(1, 2).flatten(-2, -1).squeeze(0)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> torch.Tensor:
        """Project, normalize q/k, apply RoPE, attend, and project back.

        ``position_embeddings`` is the ``(cos, sin)`` pair for this layer type
        (local or global). Extra keyword arguments (e.g. ``positions``) are
        accepted and ignored.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # [s, h * head_dim]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # [s, h, head_dim] -> [1, h, s, head_dim]
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = q.transpose(0, 1).unsqueeze(0)
        q = self.q_norm(q)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        k = k.transpose(0, 1).unsqueeze(0)
        k = self.k_norm(k)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # [b, h, s, head_dim] -> [b, s, h, head_dim]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)

        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
        output, _ = self.o_proj(attn_output)
        return output
|
|
|
|
|
|
class Gemma3DecoderLayer(nn.Module):
    """One Gemma3 transformer block.

    Attention and MLP are each wrapped by a pre-norm and a post-norm, with
    residual connections around both sub-blocks.
    """

    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma3Attention(
            layer_id=layer_id,
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = Gemma3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.input_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.is_sliding = self.self_attn.is_sliding
        self.layer_id = layer_id

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        position_embeddings_global: Tuple[torch.Tensor, torch.Tensor],
        position_embeddings_local: Tuple[torch.Tensor, torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        """Run attention and MLP sub-blocks.

        Args:
            positions: Token positions, forwarded to the attention module.
            hidden_states: Layer input.
            position_embeddings_global: ``(cos, sin)`` for global layers.
            position_embeddings_local: ``(cos, sin)`` for sliding-window layers.
            forward_batch: Batch metadata for the attention backend.

        Returns:
            A 1-tuple containing the layer output.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Sliding-window layers use local RoPE; global layers use global RoPE.
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            forward_batch=forward_batch,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return (hidden_states,)
|
|
|
|
|
|
class Gemma3RotaryEmbedding(nn.Module):
    """Compute rotary-embedding ``(cos, sin)`` tables for given positions.

    Inverse frequencies come from ``ROPE_INIT_FUNCTIONS`` keyed by the config's
    RoPE type ("default" when ``rope_scaling`` is missing or None). RoPE types
    whose name contains "dynamic" recompute frequencies as the observed
    sequence length changes.
    """

    def __init__(self, config: Gemma3TextConfig, device=None):
        super().__init__()
        scaling_cfg = getattr(config, "rope_scaling", None)
        if scaling_cfg is None:
            self.rope_type = "default"
        else:
            # BC: "rope_type" was originally "type"
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        initial_inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device
        )
        self.register_buffer("inv_freq", initial_inv_freq, persistent=False)
        # Same tensor object as the buffer; kept so a dynamic reset can restore it.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """Recompute or restore ``inv_freq`` for dynamic RoPE.

        Frequencies are recomputed when the sequence grows beyond the cached
        length (allowing scaling), and the original frequencies are restored
        when the sequence is back within the original maximum (avoiding
        precision loss on short sequences).
        """
        seq_len = torch.max(position_ids) + 1

        if seq_len > self.max_seq_len_cached:  # growth
            grown_inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            # TODO joao: may break with compilation
            self.register_buffer("inv_freq", grown_inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        within_original = seq_len < self.original_max_seq_len
        if within_original and self.max_seq_len_cached > self.original_max_seq_len:
            # reset. The buffer follows module .to(device) moves, but the saved
            # original copy does not, so move it explicitly first.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``x`` supplies only the device and output dtype.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
            batch_size, -1, 1
        )
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            freqs = torch.matmul(
                inv_freq_expanded.float().to(x.device),
                position_ids_expanded.float(),
            ).transpose(1, 2)
            angles = torch.cat((freqs, freqs), dim=-1)
            # Advanced RoPE types (e.g. yarn) apply a post-processing scaling
            # factor, equivalent to scaling attention.
            cos = angles.cos() * self.attention_scaling
            sin = angles.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
|
class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """``nn.Embedding`` whose output is multiplied by ``embed_scale``.

    As in the reference Gemma3 implementation, the scale is first rounded to
    the embedding dtype before the multiply (e.g. sqrt(3072) = 55.4256
    becomes 55.5 in bfloat16); see
    https://github.com/huggingface/transformers/pull/29402.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Kept as a plain Python float (or None = no scaling) for callers
        # that read it directly.
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up ``input_ids`` and apply the dtype-rounded embedding scale."""
        embeddings = super().forward(input_ids)
        if self.embed_scale is None:
            return embeddings
        # A 0-dim CPU tensor rounds the scale to the embedding dtype and is
        # treated as a scalar operand, so no host-to-device copy is issued
        # (safe under CUDA-graph capture).
        scale = torch.tensor(self.embed_scale, dtype=embeddings.dtype)
        return embeddings * scale
|
|
|
|
|
|
class Gemma3TextModel(PreTrainedModel):
    """Gemma3 decoder stack: scaled token embeddings, decoder layers, final norm.

    Two rotary embeddings are owned here:
    - a global one built from ``config``;
    - a local one built from a copy of ``config`` with ``rope_local_base_freq``
      and default RoPE scaling.

    Both are evaluated once per forward pass and passed to every layer.
    """

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to
        # become 55.5. See https://github.com/huggingface/transformers/pull/29402
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=self.config.hidden_size**0.5,
        )

        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Local RoPE uses a private copy of the config so that `config` itself
        # (also handed to every decoder layer below) keeps the global values.
        local_config = copy.deepcopy(config)
        local_config.rope_theta = config.rope_local_base_freq
        local_config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3RotaryEmbedding(config=local_config)

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma3DecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Embed (unless ``input_embeds`` is given), run all layers, and normalize.

        A 1-D ``positions`` tensor is reshaped to ``(1, s)`` before computing
        the rotary tables. Extra keyword arguments are forwarded to each layer.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        if positions.dim() == 1:
            positions = einops.rearrange(positions, "s -> 1 s")

        position_embeddings_global = self.rotary_emb(hidden_states, positions)
        position_embeddings_local = self.rotary_emb_local(hidden_states, positions)
        for layer in self.layers:
            layer_outputs = layer(
                positions=positions,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        return hidden_states
|
|
|
|
|
|
class Gemma3ForCausalLM(PreTrainedModel):
    """Gemma3 text-only causal LM: ``Gemma3TextModel`` plus an LM head.

    With ``config.tie_word_embeddings`` the LM head is the input embedding
    module itself. Otherwise a separate ``ParallelLMHead`` is created, used for
    logits, and loaded from ``lm_head.weight``.
    """

    config_class = Gemma3TextConfig

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    base_model_prefix = "language_model"

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma3TextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.logits_processor = LogitsProcessor(config)

        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def get_attention_sliding_window_size(self):
        """Exclusive sliding-window size used to set up the attention backend."""
        return get_attention_sliding_window_size(self.config)

    def dtype(self) -> torch.dtype:
        """Return the dtype of the first parameter.

        NOTE(review): this is a plain method and shadows the ``dtype``
        property inherited from ``PreTrainedModel``. It is kept as a method
        because callers may invoke ``dtype()``; confirm before converting.
        """
        return next(self.parameters()).dtype

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> LogitsProcessor:
        """Run the decoder and compute logits with this model's LM head.

        NOTE(review): the annotation names ``LogitsProcessor``, but the value
        is whatever ``self.logits_processor(...)`` returns.
        """
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, **kwargs
        )

        # Use self.lm_head, which is the embedding module when tied and the
        # separate ParallelLMHead otherwise.
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors, fusing q/k/v and gate/up shards.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (after renaming) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # When embeddings are tied, lm_head shares embed_tokens'
                # parameter (registered only under model.embed_tokens), so a
                # checkpoint lm_head.weight has no target here.
                if "lm_head.weight" in name and self.config.tie_word_embeddings:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
# Model class exported from this module.
EntryClass = Gemma3ForCausalLM
# Map Gemma3TextConfig to this model in transformers' AutoModel; exist_ok=True
# tolerates an existing registration for the same config.
AutoModel.register(Gemma3TextConfig, EntryClass, exist_ok=True)
|