sglang0.4.5.post1/python/sglang/srt/models/gemma3_causal.py

695 lines
24 KiB
Python

# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Iterable, Optional, Set, Tuple
import einops
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
ROPE_INIT_FUNCTIONS,
AutoModel,
Gemma3TextConfig,
PretrainedConfig,
PreTrainedModel,
)
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers
# HF's implementation treats the sliding window as inclusive of the last token,
# whereas SGLang treats it as exclusive, so the configured size is reduced by one.
def get_attention_sliding_window_size(config):
    """Return the sliding-window size to hand to the attention layer.

    The value is ``config.sliding_window`` minus one.

    Args:
        config: Any object exposing an integer ``sliding_window`` attribute.

    Returns:
        ``config.sliding_window - 1``.
    """
    configured_window = config.sliding_window
    return configured_window - 1
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
def extract_layer_index(prefix: str) -> int:
    """Extract the layer index from a dotted module prefix.

    The prefix is split on ``"."`` and scanned for a ``layers`` component
    that is immediately followed by an integer component, e.g.
    ``"model.layers.3.self_attn"`` yields ``3``.

    Args:
        prefix: Dotted module path.

    Returns:
        The integer that follows the first ``layers`` component whose
        successor parses as an int, or ``-1`` when there is no such pair.
    """
    parts = prefix.split(".")
    # After splitting on ".", no single part can contain a dot, so the index
    # has to be read from the component *after* "layers".
    for component, candidate in zip(parts, parts[1:]):
        if component != "layers":
            continue
        try:
            return int(candidate)
        except ValueError:
            # Not a numeric index; keep scanning for a later "layers.<int>".
            continue
    return -1
class Gemma3MLP(nn.Module):
    """Gated feed-forward block of a Gemma3 decoder layer.

    The gate and up projections are fused into a single merged column-parallel
    linear layer; its output goes through ``GeluAndMul`` and is then projected
    back to ``hidden_size`` by a row-parallel linear layer.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and validate the activation name.

        Args:
            hidden_size: Input/output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_activation: Must be ``"gelu_pytorch_tanh"``.
            quant_config: Optional quantization config passed to both linears.
            prefix: Module-name prefix passed through ``add_prefix``.

        Raises:
            ValueError: If ``hidden_activation`` is not ``"gelu_pytorch_tanh"``.
        """
        super().__init__()
        # Gate and up projections share one fused weight (two equal shards).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        activation_supported = hidden_activation == "gelu_pytorch_tanh"
        if not activation_supported:
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fused gate/up projection, activation, then down projection."""
        # Both linear layers return a pair; only the first element is used.
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(activated)
        return projected
class Gemma3Attention(nn.Module):
def __init__(
self,
layer_id: int,
config: Gemma3TextConfig,
max_position_embeddings: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.config = config
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
hidden_size = config.hidden_size
head_dim = getattr(
config, "head_dim", hidden_size // config.num_attention_heads
)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = config.query_pre_attn_scalar**-0.5
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
# Determine if layer uses sliding window based on pattern
self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
# Initialize the rotary embedding.
if self.is_sliding:
# Local attention. Override the values in config.json.
self.rope_theta = config.rope_local_base_freq
self.rope_scaling = {"rope_type": "default"}
# FIXME(mick): idk why vllm does this
# self.sliding_window = config.interleaved_sliding_window
self.sliding_window = get_attention_sliding_window_size(config)
else:
# Global attention. Use the values in config.json.
self.rope_theta = config.rope_theta
self.rope_scaling = config.rope_scaling
self.sliding_window = None
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
logit_cap=getattr(self.config, "attn_logit_softcapping", None),
# Module must also define `get_attention_sliding_window_size` to correctly initialize
# attention backend in `ForwardBatch`.
sliding_window_size=self.sliding_window,
prefix=add_prefix("attn", prefix),
)
# Gemma3 adds normalization for q and k
self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
def naive_attn_with_masks(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
**kwargs,
) -> torch.Tensor:
q = q.view(-1, self.num_heads, self.head_dim)
# Expand the key and value to handle GQA.
num_queries_per_kv = self.num_heads // self.num_kv_heads
k = k.view(-1, self.num_kv_heads, self.head_dim)
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
v = v.view(-1, self.num_kv_heads, self.head_dim)
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
if self.is_sliding:
attn_masks = kwargs["local_attn_masks"]
else:
attn_masks = kwargs["global_attn_masks"]
seq_lens = kwargs["seq_lens"]
start_idx = 0
for seq_len, attn_mask in zip(seq_lens, attn_masks):
end_idx = start_idx + seq_len
query = q[start_idx:end_idx].unsqueeze(0)
key = k[start_idx:end_idx].unsqueeze(0)
value = v[start_idx:end_idx].unsqueeze(0)
# Transpose.
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask,
self.scaling,
)
output = output.transpose(1, 2).flatten(-2, -1)
out[start_idx:end_idx] = output
start_idx = end_idx
return out
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
forward_batch: ForwardBatch,
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
# [s, h * head_dim]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# [s, h, head_dim]
q = q.unflatten(-1, (self.num_heads, self.head_dim))
# -> [h, s, head_dim]
q = q.transpose(0, 1).unsqueeze(0)
q = self.q_norm(q)
k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
# -> [h, s, head_dim]
k = k.transpose(0, 1).unsqueeze(0)
k = self.k_norm(k)
# q, k = self.rotary_emb(positions, q, k)
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# [b, h, s, head_dim] -> [b, s, h, head_dim]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
attn_output = self.attn(q, k, v, forward_batch=forward_batch)
output, _ = self.o_proj(attn_output)
return output
class Gemma3DecoderLayer(nn.Module):
    """One Gemma3 decoder layer: attention and MLP, each in a norm sandwich.

    Both sub-blocks are wrapped as ``x + post_norm(block(pre_norm(x)))``.
    """

    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, the MLP and the four RMS norms.

        Args:
            layer_id: Index of this layer; forwarded to the attention block.
            config: Model configuration.
            quant_config: Optional quantization config for the sub-modules.
            prefix: Module-name prefix passed through ``add_prefix``.
        """
        super().__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.hidden_size = width

        self.self_attn = Gemma3Attention(
            layer_id=layer_id,
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = Gemma3MLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

        # Norms before/after attention and before/after the feed-forward.
        self.input_layernorm = Gemma3RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = Gemma3RMSNorm(width, eps=norm_eps)
        self.pre_feedforward_layernorm = Gemma3RMSNorm(width, eps=norm_eps)
        self.post_feedforward_layernorm = Gemma3RMSNorm(width, eps=norm_eps)

        self.is_sliding = self.self_attn.is_sliding
        self.layer_id = layer_id

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Apply the attention and feed-forward sub-blocks with residuals.

        Args:
            positions: Token positions, forwarded to the attention block.
            hidden_states: Layer input.
            position_embeddings_global: Rotary embeddings for global layers.
            position_embeddings_local: Rotary embeddings for sliding layers.
            forward_batch: Batch metadata forwarded to the attention block.
            **kwargs: Forwarded unchanged to the attention block.

        Returns:
            A one-element tuple holding the layer output.
        """
        # Sliding-window layers use the local RoPE table, others the global one.
        if self.self_attn.is_sliding:
            rope_for_layer = position_embeddings_local
        else:
            rope_for_layer = position_embeddings_global

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=rope_for_layer,
            forward_batch=forward_batch,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        hidden_states = hidden_states + self.post_feedforward_layernorm(mlp_out)

        return (hidden_states,)
class Gemma3RotaryEmbedding(nn.Module):
    """Computes rotary position embedding ``(cos, sin)`` tables.

    Inverse frequencies come from the ``ROPE_INIT_FUNCTIONS`` entry selected
    by the config's rope type. For "dynamic" rope types the frequencies are
    recomputed when the observed positions grow past the cached length.
    """

    def __init__(self, config: Gemma3TextConfig, device=None):
        """Select the rope init function and compute initial frequencies.

        Args:
            config: Text config; ``rope_scaling`` (if set) chooses the rope
                type via its ``"rope_type"`` key, or legacy ``"type"`` key.
            device: Passed through to the rope init function.
        """
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        initial_inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device
        )
        self.register_buffer("inv_freq", initial_inv_freq, persistent=False)
        # Kept so the dynamic variant can restore the initial frequencies.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """Recompute or restore ``inv_freq`` for dynamic rope types.

        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1

        grew_past_cache = seq_len > self.max_seq_len_cached
        if grew_past_cache:
            new_inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            # TODO joao: may break with compilation
            self.register_buffer("inv_freq", new_inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        back_in_original_range = (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        )
        if back_in_original_range:
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids`` in the dtype of ``x``.

        Args:
            x: Tensor used only for its device and dtype.
            position_ids: Position indices; indexed as ``[batch, seq]``.

        Returns:
            Tuple ``(cos, sin)``, each multiplied by ``attention_scaling`` and
            cast to ``x.dtype``.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        pos = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (inv_freq.float().to(x.device) @ pos.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        scale = self.attention_scaling
        cos = cos * scale
        sin = sin * scale
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """``nn.Embedding`` whose looked-up vectors are multiplied by a scale."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        embed_scale: Optional[float] = 1.0,
    ):
        """Create the embedding table and remember the output scale.

        Args:
            num_embeddings: Number of rows in the table.
            embedding_dim: Width of each embedding vector.
            padding_idx: Padding index forwarded to ``nn.Embedding``.
            embed_scale: Multiplier applied to every looked-up vector.
        """
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        """Look up ``input_ids`` and multiply the result by ``embed_scale``."""
        unscaled = super().forward(input_ids)
        return unscaled * self.embed_scale
class Gemma3TextModel(PreTrainedModel):
    """Gemma3 decoder stack: scaled embeddings, decoder layers, final norm.

    Two rotary embedding modules are kept: ``rotary_emb`` built from the
    config as given (global layers) and ``rotary_emb_local`` built from a
    copy whose ``rope_theta`` is ``rope_local_base_freq`` and whose
    ``rope_scaling`` is the default type (sliding-window layers).
    """

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build embeddings, rotary modules, decoder layers and final norm.

        Args:
            config: Gemma3 text configuration.
            quant_config: Optional quantization config for the decoder layers.
            prefix: Module-name prefix passed through ``add_prefix``.
        """
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Embeddings are scaled by sqrt(hidden_size).
        # Upstream note: Gemma3 downcasts the below to float16, causing
        # sqrt(3072)=55.4256 to become 55.5.
        # See https://github.com/huggingface/transformers/pull/29402
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=self.config.hidden_size**0.5,
        )
        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # The local RoPE module needs its own config. Config defaults hold the
        # values for global RoPE, so the overrides go on a deep copy under a
        # separate name; the decoder layers below must still see the
        # unmodified `config`.
        local_rope_config = copy.deepcopy(config)
        local_rope_config.rope_theta = config.rope_local_base_freq
        local_rope_config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3RotaryEmbedding(config=local_rope_config)

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma3DecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the decoder stack and the final norm.

        Args:
            input_ids: Token ids; used only when ``input_embeds`` is None.
            positions: Token positions; a 1-D tensor gets a leading batch
                dimension of size 1 before the rotary tables are computed.
            forward_batch: Batch metadata forwarded to every layer.
            input_embeds: Optional precomputed embeddings that replace the
                embedding lookup.
            **kwargs: Forwarded unchanged to every decoder layer.

        Returns:
            Hidden states after the final RMS norm.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        if positions.dim() == 1:
            positions = einops.rearrange(positions, "s -> 1 s")

        # Compute both RoPE tables once; each layer picks the one it needs.
        position_embeddings_global = self.rotary_emb(hidden_states, positions)
        position_embeddings_local = self.rotary_emb_local(hidden_states, positions)

        for layer in self.layers:
            layer_outputs = layer(
                positions=positions,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        return hidden_states
class Gemma3ForCausalLM(PreTrainedModel):
    """Gemma3 text model plus logits processor.

    Logits are always computed against ``self.model.embed_tokens``; checkpoint
    entries whose name contains ``lm_head.weight`` are skipped on load.
    """

    config_class = Gemma3TextConfig
    base_model_prefix = "language_model"
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the text model, the logits processor and the LM head.

        Args:
            config: Gemma3 text configuration.
            quant_config: Optional quantization config.
            prefix: Module-name prefix passed through ``add_prefix``.
        """
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma3TextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.logits_processor = LogitsProcessor(config)

        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        else:
            # Tied embeddings: reuse the input embedding module as the head.
            self.lm_head = self.model.embed_tokens
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module of the text model."""
        return self.model.embed_tokens

    def get_attention_sliding_window_size(self):
        """Return the sliding-window size derived from this model's config."""
        return get_attention_sliding_window_size(self.config)

    def dtype(self) -> torch.dtype:
        """Return the dtype of the first parameter of the model."""
        first_param = next(self.parameters())
        return first_param.dtype

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs,
    ) -> LogitsProcessor:
        """Run the text model and pass its output to the logits processor.

        Args:
            input_ids: Token ids.
            positions: Token positions.
            forward_batch: Batch metadata.
            input_embeds: Optional precomputed input embeddings.
            **kwargs: Forwarded unchanged to the text model.

        Returns:
            Whatever ``self.logits_processor`` returns for this batch.
        """
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, **kwargs
        )
        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters via each parameter's
        ``weight_loader``.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            Set of (remapped) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                fused_param = params_dict[name]
                fused_param.weight_loader(fused_param, loaded_weight, shard_id)
                break
            else:
                # No stacked shard was loaded: treat as a regular parameter.
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                load_fn = getattr(param, "weight_loader", default_weight_loader)
                load_fn(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params
# Module-level alias for the model class defined in this file
# (presumably the name SGLang's model loader looks up — confirm against the loader).
EntryClass = Gemma3ForCausalLM
# Register the class with transformers' AutoModel for Gemma3TextConfig;
# exist_ok=True is passed so an existing registration is tolerated.
AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)