# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py

import logging
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

import torch
from torch import nn
from transformers import (
    AutoModel,
    BatchFeature,
    Gemma3Config,
    Gemma3Processor,
    PreTrainedModel,
)
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)

cached_get_processor = lru_cache(get_processor)


class Gemma3ImagePixelInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class Gemma3MultiModalProjector(nn.Module):
    """Projector for Gemma3 multimodal."""

    def __init__(self, config: Gemma3Config):
        super().__init__()

        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(
                config.vision_config.hidden_size, config.text_config.hidden_size
            )
        )

        self.mm_soft_emb_norm = Gemma3RMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )

        self.patches_per_image = int(
            config.vision_config.image_size // config.vision_config.patch_size
        )
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(
            kernel_size=self.kernel_size, stride=self.kernel_size
        )

    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
        batch_size, seq_length, hidden_size = vision_outputs.shape

        # Reshape for pooling: (B, L, H) -> (B, H, P, P)
        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
            batch_size, hidden_size, self.patches_per_image, self.patches_per_image
        )
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()

        # Apply pooling
        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

        # Apply normalization
        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)

        # Project to text embedding space
        projected_vision_outputs = torch.matmul(
            normed_vision_outputs, self.mm_input_projection_weight
        )

        return projected_vision_outputs.type_as(vision_outputs)

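
# Worked example of the pooling arithmetic above, assuming the published Gemma3
# vision settings (image_size=896, patch_size=14, mm_tokens_per_image=256; check
# the actual config of your checkpoint): patches_per_image = 896 // 14 = 64, so
# the vision tower emits 64 * 64 = 4096 patch embeddings per image;
# tokens_per_side = sqrt(256) = 16 and kernel_size = 64 // 16 = 4, so the 4x4
# average pool reduces (B, 4096, hidden) to (B, 256, hidden) before the RMSNorm
# and the projection into (B, 256, text_hidden).
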

class Gemma3ForConditionalGeneration(PreTrainedModel):
    """Gemma3 multimodal model for conditional generation."""

    config_class = Gemma3Config

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        # Vision components
        # TODO: replace with vision attention
        # self.vision_tower = SiglipVisionModel(
        #     config.vision_config,
        #     quant_config,
        #     prefix=add_prefix("vision_tower", prefix),
        # )
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        # Text model
        self.language_model = Gemma3ForCausalLM(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )
        if self.language_model.logits_processor.logit_scale:
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.language_model.logits_processor.logit_scale *= logit_scale
        self.post_init()

    def pad_input_ids(
        self, input_ids: List[int], image_inputs: MultimodalInputs
    ) -> List[int]:
        """Pad input IDs with image tokens."""
        # Get special token IDs
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
        ids = pattern.pad_input_tokens(input_ids, image_inputs)
        return ids

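    # Sketch of the masking scheme built by prepare_attn_masks below (a
    # description of the code, not extra logic): for each sequence we build
    #   * a global mask, causal (-inf above the diagonal), except that any pair
    #     of positions that are BOTH image tokens attend to each other
    #     bidirectionally (img_mask == 2 -> 0), and
    #   * a local mask, the same global mask further restricted by the
    #     interleaved sliding window, with everything older than the window
    #     set to -inf.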
    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ) -> Dict:
        """Prepare attention masks for multimodal inputs."""
        kwargs["has_images"] = True

        # Distinguish sequences by position id 0
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []
        for i in range(num_seqs):
            start_idx = start_indices[i].item()
            if i < num_seqs - 1:
                end_idx = start_indices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)
        kwargs["seq_lens"] = seq_lens

        # Create attention masks
        global_attn_masks = []
        local_attn_masks = []
        sliding_window = self.config.text_config.interleaved_sliding_window

        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx

            # Create global causal mask
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Consider bidirectional attention between image tokens
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = input_token_ids == self.config.image_token_index
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # Create local causal mask with sliding window
            local_attn_mask = torch.ones_like(global_attn_mask)
            local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
            local_attn_mask = torch.where(
                local_attn_mask == 0, global_attn_mask, float("-inf")
            )
            local_attn_masks.append(local_attn_mask)

        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def get_input_embeddings(self) -> nn.Embedding:
        return self.language_model.get_input_embeddings()

    def get_attention_sliding_window_size(self):
        """
        This value is used to initialize attention backends in `ForwardBatch`.
        """
        return self.language_model.get_attention_sliding_window_size()

    def get_image_feature(self, image_input: MultimodalInputs):
        """
        Projects the last hidden state from the vision model into language model space.

        Args:
            image_input (`MultimodalInputs`):
                Multimodal inputs whose `pixel_values` tensor has shape
                `(batch_size, channels, height, width)`.

        Returns:
            image_features (`torch.Tensor`):
                Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        pixel_values = image_input.pixel_values
        pixel_values = pixel_values.to("cuda")
        pixel_values = pixel_values.to(dtype=self.language_model.dtype())

        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def embed_mm_inputs(
        self,
        input_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        image_input: MultimodalInputs,
    ) -> torch.Tensor:
        if input_ids is None:
            raise ValueError("Unimplemented")

        # Boolean mask of the image placeholder tokens in input_ids
        special_image_mask = torch.isin(
            input_ids,
            torch.tensor(image_input.pad_values, device=input_ids.device),
        ).unsqueeze(-1)

        num_image_tokens_in_input_ids = special_image_mask.sum()
        if num_image_tokens_in_input_ids == 0:
            # No image tokens: plain text embedding lookup
            inputs_embeds = self.get_input_embeddings()(input_ids)
            return inputs_embeds
        else:
            image_features = self.get_image_feature(image_input)
            num_image_tokens_in_embedding = (
                image_features.shape[0] * image_features.shape[1]
            )

            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
                image_features = image_features[:num_image, :]
                logger.warning(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
                    "tokens from image embeddings."
                )

            # Important: clamp after extracting the original image boundaries
            input_ids.clamp_(min=0, max=self.vocab_size - 1)
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # Scatter the projected image features into the image-token slots
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )
            image_features = image_features.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features
            )
            return inputs_embeds

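    # Toy illustration of embed_mm_inputs above (hypothetical sizes): if the
    # padded prompt contains 256 image placeholder ids and get_image_feature
    # returns a (1, 256, hidden) tensor, special_image_mask marks exactly those
    # 256 rows of inputs_embeds, and masked_scatter overwrites them with the
    # projected image features while every other row keeps its text embedding.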
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs: object,
    ) -> LogitsProcessor:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in
            `[0, ..., config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to
            `-100` are ignored (masked), the loss is only computed for the tokens with labels in
            `[0, ..., config.text_config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only
            for that token can save memory, which becomes pretty significant for long sequences or large vocabulary
            size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length
            dimension. This is useful when using packed tensor format (single dimension for batch and sequence
            length).

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
        >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

        >>> prompt = "answer en Where is the cow standing?"
        >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "answer en Where is the cow standing?\nbeach"
        ```
        """
        # Important: position_ids in Gemma3 are 1-indexed
        positions += 1

        # Replace the image token id with 0 (pad) if it is out of vocabulary, to avoid index errors
        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        inputs_embeds = general_mm_embed_routine(
            input_ids=llm_input_ids,
            forward_batch=forward_batch,
            embed_tokens=self.get_input_embeddings(),
            mm_data_embedding_func=self.get_image_feature,
        )

        outputs = self.language_model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
            **kwargs,
        )

        return outputs

    def tie_weights(self):
        return self.language_model.tie_weights()

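    # Weight-loading note (descriptive only): names under "language_model" are
    # delegated to Gemma3ForCausalLM.load_weights so its stacked-parameter
    # mapping (qkv_proj, gate_up_proj) applies; everything else (vision tower,
    # projector) is loaded directly by name, lm_head.weight is skipped because
    # it is tied to the input embeddings, and FP8 kv-scale names are remapped
    # before lookup.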
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model."""
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "language_model" in name:
                causal_loaded_params = Gemma3ForCausalLM.load_weights(
                    self, [(name, loaded_weight)]
                )
                loaded_params.update(causal_loaded_params)
                continue
            else:
                # Skip lm_head.weight as it is tied with embed_tokens
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remap FP8 kv-scale names
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: %s",
                unloaded_params,
            )
        return loaded_params


EntryClass = Gemma3ForConditionalGeneration
AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
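
# Usage sketch (an assumed invocation; the exact model path and flags depend on
# your SGLang version and checkpoint):
#   python -m sglang.launch_server --model-path google/gemma-3-4b-it
# after which multimodal requests can be sent through the OpenAI-compatible
# chat completions endpoint with image content parts.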