# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import copy
from typing import Iterable, Optional, Set, Tuple

import einops
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    ROPE_INIT_FUNCTIONS,
    AutoModel,
    Gemma3TextConfig,
    PretrainedConfig,
    PreTrainedModel,
)

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers


def get_attention_sliding_window_size(config):
    """Translate the config's sliding window into SGLang's convention.

    The HF implementation counts the window inclusively of the last (current)
    token, while SGLang assumes an exclusive window, so one position is
    dropped.

    Args:
        config: Model config exposing ``sliding_window``.

    Returns:
        ``config.sliding_window - 1``.
    """
    inclusive_window = config.sliding_window
    exclusive_window = inclusive_window - 1
    return exclusive_window


# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
def extract_layer_index(prefix: str) -> int:
    """Extract the decoder-layer index from a dotted module prefix.

    The index is the first integer component that directly follows a
    ``layers`` component, e.g. ``"model.layers.12.self_attn"`` -> ``12``.

    Args:
        prefix: Dotted module path such as ``"model.layers.3.mlp"``.

    Returns:
        The layer index, or ``-1`` when no ``layers.<int>`` pair is present.
    """
    parts = prefix.split(".")
    # Components produced by ``split(".")`` never contain a dot, so the index
    # must be read from the component *after* ``layers`` rather than from a
    # single ``"layers.<n>"`` component (which can never occur).
    for part, next_part in zip(parts, parts[1:]):
        if part != "layers":
            continue
        try:
            return int(next_part)
        except ValueError:
            continue
    return -1


class Gemma3MLP(nn.Module):
    """Gated feed-forward block used by every Gemma3 decoder layer.

    ``gate_up_proj`` produces the gate and up projections in one fused,
    column-parallel matmul; ``act_fn`` (``GeluAndMul``) combines the two
    halves and the row-parallel ``down_proj`` maps back to ``hidden_size``.
    NOTE(review): assumes ``GeluAndMul()``'s default GELU variant is the tanh
    approximation required below — confirm against its definition.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Gemma3Attention(nn.Module):
    """Gemma3 self-attention with QK-norm and per-layer local/global RoPE.

    A layer whose 1-based index is a multiple of
    ``config.sliding_window_pattern`` uses global attention; every other layer
    uses sliding-window attention. The rotary ``(cos, sin)`` tables are
    computed by the caller and passed to :meth:`forward`.
    """

    def __init__(
        self,
        layer_id: int,
        config: Gemma3TextConfig,
        max_position_embeddings: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0

        hidden_size = config.hidden_size
        head_dim = getattr(
            config, "head_dim", hidden_size // config.num_attention_heads
        )
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # The softmax scale is derived from query_pre_attn_scalar.
        self.scaling = config.query_pre_attn_scalar**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        # Determine if layer uses sliding window based on pattern: layers whose
        # 1-based index is a multiple of the pattern are global.
        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)

        # RoPE settings for this layer type. forward() uses the (cos, sin)
        # tables passed in, so these are recorded for reference only.
        if self.is_sliding:
            # Local attention. Override the values in config.json.
            self.rope_theta = config.rope_local_base_freq
            self.rope_scaling = {"rope_type": "default"}
            # FIXME(mick): idk why vllm does this
            # self.sliding_window = config.interleaved_sliding_window
            self.sliding_window = get_attention_sliding_window_size(config)
        else:
            # Global attention. Use the values in config.json.
            self.rope_theta = config.rope_theta
            self.rope_scaling = config.rope_scaling
            self.sliding_window = None

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
            # Module must also define `get_attention_sliding_window_size` to
            # correctly initialize attention backend in `ForwardBatch`.
            sliding_window_size=self.sliding_window,
            prefix=add_prefix("attn", prefix),
        )

        # Gemma3 adds normalization for q and k. Use the resolved head_dim
        # (which falls back to hidden_size // num_heads) so the norms always
        # match the projection width computed above.
        self.q_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Reference attention over packed sequences using explicit masks.

        Sequences are packed along the token dimension. ``kwargs["seq_lens"]``
        gives the per-sequence lengths, and ``kwargs["local_attn_masks"]`` or
        ``kwargs["global_attn_masks"]`` (chosen by ``self.is_sliding``) gives
        one mask per sequence. Results are written into ``out`` in place,
        which is also returned.
        """
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # [1, s, h, d] -> [1, h, s, d], the layout SDPA expects.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            # `scale` must be passed by keyword: the fifth positional parameter
            # of scaled_dot_product_attention is `dropout_p`, so passing
            # self.scaling positionally used it as a dropout probability and
            # left the default 1/sqrt(head_dim) scale in place.
            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attn_mask,
                scale=self.scaling,
            )
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> torch.Tensor:
        """Project to q/k/v, QK-normalize, apply RoPE, attend, project back.

        Args:
            hidden_states: Token activations. The reshapes below assume a 2-D
                ``[num_tokens, hidden_size]`` layout (TODO confirm with caller).
            position_embeddings: ``(cos, sin)`` tables for this layer's type.
            forward_batch: Batch metadata consumed by ``RadixAttention``.
            **kwargs: Unused here (e.g. ``positions`` from the decoder layer).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # [s, h * head_dim]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # [s, h, head_dim] -> [1, h, s, head_dim]
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = q.transpose(0, 1).unsqueeze(0)
        q = self.q_norm(q)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        k = k.transpose(0, 1).unsqueeze(0)
        k = self.k_norm(k)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # [b, h, s, head_dim] -> [b, s, h, head_dim]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)

        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Gemma3DecoderLayer(nn.Module):
    """One Gemma3 transformer block.

    The attention and MLP sub-blocks are each wrapped in a norm before
    (``input_layernorm`` / ``pre_feedforward_layernorm``) and a norm after
    (``post_attention_layernorm`` / ``post_feedforward_layernorm``), and the
    result is added back to the residual stream.
    """

    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma3Attention(
            layer_id=layer_id,
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = Gemma3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.input_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.is_sliding = self.self_attn.is_sliding
        self.layer_id = layer_id

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        position_embeddings_global: Tuple[torch.Tensor, torch.Tensor],
        position_embeddings_local: Tuple[torch.Tensor, torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        """Run attention then MLP; return a 1-tuple with the new hidden states."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layer only
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            forward_batch=forward_batch,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        return outputs


class Gemma3RotaryEmbedding(nn.Module):
    """Compute RoPE ``(cos, sin)`` tables for a batch of position ids.

    Inverse frequencies come from ``ROPE_INIT_FUNCTIONS[rope_type]``; for
    ``"dynamic"`` rope types they are recomputed when positions grow past the
    cached length and restored when they fall back below the original length.
    """

    def __init__(self, config: Gemma3TextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            # This .to() is needed if the model has been moved to a device after
            # being initialized (because the buffer is automatically moved, but
            # not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` cast to ``x.dtype``.

        ``position_ids`` is indexed as ``[batch, seq]`` below; ``x`` only
        supplies the device and output dtype.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling
        # factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class Gemma3TextModel(PreTrainedModel):
    """Gemma3 decoder stack: scaled embeddings, decoder layers, final norm.

    Two rotary tables are computed per forward pass: a global one from the
    config as-is, and a local one with ``rope_local_base_freq`` and default
    (unscaled) RoPE. Each layer picks the table matching its attention type.
    """

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # NOTE(review): HF notes that Gemma3 rounds this normalizer to the
        # model dtype (e.g. sqrt(3072)=55.4256 becomes 55.5 in float16), see
        # https://github.com/huggingface/transformers/pull/29402. Here it is
        # passed as a Python float — confirm the numerics match the reference.
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=self.config.hidden_size**0.5,
        )

        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Local (sliding-window) RoPE uses a different base frequency and no
        # scaling. Build it from a private copy so the config handed to the
        # decoder layers below keeps the global RoPE settings (previously the
        # copy replaced `config` and leaked into every layer).
        local_config = copy.deepcopy(config)
        local_config.rope_theta = config.rope_local_base_freq
        local_config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3RotaryEmbedding(config=local_config)

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma3DecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Embed (unless ``input_embeds`` is given), run all layers, normalize."""
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # The rotary modules index position ids as [batch, seq].
        if positions.dim() == 1:
            positions = einops.rearrange(positions, "s -> 1 s")

        position_embeddings_global = self.rotary_emb(hidden_states, positions)
        position_embeddings_local = self.rotary_emb_local(hidden_states, positions)

        for layer in self.layers:
            layer_outputs = layer(
                positions=positions,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        return hidden_states


class Gemma3ForCausalLM(PreTrainedModel):
    """Gemma3 text-only causal LM: ``Gemma3TextModel`` plus an LM head."""

    config_class = Gemma3TextConfig

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    base_model_prefix = "language_model"

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma3TextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.logits_processor = LogitsProcessor(config)

        # With tied embeddings the LM head *is* the input embedding module.
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def get_attention_sliding_window_size(self):
        return get_attention_sliding_window_size(self.config)

    # NOTE(review): a plain method, so it presumably shadows the `dtype`
    # property inherited from PreTrainedModel (`model.dtype` yields a bound
    # method). Kept as-is for existing `model.dtype()` callers — confirm.
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> LogitsProcessor:
        """Run the decoder stack and return the logits processor's output.

        NOTE(review): the return annotation names ``LogitsProcessor``, but the
        value is whatever ``LogitsProcessor.__call__`` returns.
        """
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, **kwargs
        )
        # self.lm_head is embed_tokens when embeddings are tied and the
        # separately loaded ParallelLMHead otherwise; always projecting with
        # embed_tokens silently ignored the head of untied checkpoints.
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors, fusing q/k/v and gate/up into stacked params.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        tie_word_embeddings = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # With tied embeddings lm_head shares embed_tokens' weight, so
                # there is no separate parameter to load into. Untied models
                # must load lm_head.weight into their ParallelLMHead.
                if "lm_head.weight" in name and tie_word_embeddings:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        # Unloaded parameters are not reported here; callers can compare the
        # returned set against self.named_parameters().
        return loaded_params


EntryClass = Gemma3ForCausalLM
AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)