# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type, TypedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import Qwen2VLConfig
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


# === Vision Inputs === #


class Qwen2VLImageInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape:
    `(num_patches, num_channels * patch_size * patch_size)`
    """

    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


class Qwen2VLVideoInputs(TypedDict):
    pixel_values_videos: torch.Tensor
    """Shape:
    `(num_patches, num_channels * temporal_patch_size * patch_size * patch_size)`
    """

    video_grid_thw: torch.Tensor
    """Shape: `(num_videos, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


class Qwen2VisionBlock(nn.Module):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if attn_implementation == "sdpa":
            use_context_forward = False
            softmax_in_single_precision = False
        elif attn_implementation == "flash_attention_2":
            softmax_in_single_precision = False
            use_context_forward = True
        elif attn_implementation == "eager":
            softmax_in_single_precision = True
            use_context_forward = False

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=False,
            use_context_forward=use_context_forward,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        x = x + self.mlp(self.norm2(x))
        return x
class Qwen2VisionPatchEmbed(nn.Module):

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_chans: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
        x = self.proj(x).view(L, self.embed_dim)
        return x


class Qwen2VisionPatchMerger(nn.Module):

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Type[nn.Module] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp.0", prefix),
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp.2", prefix),
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        if seqlen > self._seq_len_cached:
            seqlen *= 2
            self._seq_len_cached = seqlen
            self.inv_freq = 1.0 / (
                self.theta
                ** (
                    torch.arange(
                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
                    )
                    / self.dim
                )
            )
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]
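# Illustrative note (added for clarity, not in the upstream file): the rotary
# embedding above caches `freqs[s, j] = s * inv_freq[j]`. A minimal sketch,
# assuming dim=4 and the default theta=10000.0:
#
#     rope = Qwen2VisionRotaryEmbedding(dim=4)
#     freqs = rope(3)  # shape (3, 2); row s equals s * [1.0, 10000.0 ** -0.5]
#
# The vision transformer below indexes these frequencies with 2D (h, w) patch
# positions and duplicates them along the last dim before taking cos/sin.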
class Qwen2VisionTransformer(nn.Module):

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        in_chans: int = vision_config.in_chans
        hidden_size: int = vision_config.hidden_size
        embed_dim: int = vision_config.embed_dim
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        mlp_ratio: float = vision_config.mlp_ratio

        self.spatial_merge_size = spatial_merge_size

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{i}", prefix),
                )
                for i in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        pos_ids = []
        for i in range(grid_thw.size(0)):
            t, h, w = grid_thw[i].tolist()
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

        # adapter
        x = self.merger(x)

        return x


cached_get_processor = lru_cache(get_processor)


class Qwen2VLForConditionalGeneration(nn.Module):

    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
        processor = cached_get_processor(self.config._name_or_path)
        grid_t, grid_h, grid_w = image_grid_thw
        num_image_tokens = (
            grid_t
            * grid_h
            * grid_w
            // processor.image_processor.merge_size
            // processor.image_processor.merge_size
        )
        return num_image_tokens
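    # Worked example (added for clarity, values are illustrative): with the
    # default merge_size of 2, an image whose grid is (grid_t, grid_h, grid_w)
    # = (1, 16, 16) yields 1 * 16 * 16 // 2 // 2 = 64 image tokens, i.e. one
    # token per 2x2 group of patches after spatial merging.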
    def __init__(
        self,
        config: Qwen2VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.visual = Qwen2VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen2-VL vision encoder does not support any
            # quantization method now.
            quant_config=None,
            prefix=add_prefix("visual", prefix),
        )

        self.model = Qwen2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    # Use grid_t * grid_w * grid_h to pad tokens for each image, then replace
    # the padding with a unique image hash.
    def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
        # Get all special token IDs
        im_start_id: int = multi_modal_inputs.im_start_id
        im_end_id: int = multi_modal_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        return pattern.pad_input_tokens(input_ids, multi_modal_inputs)

    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
        pixel_values = image_input.pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
        return image_embeds

    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
        )
        return video_embeds

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch. **NOTE**: If mrope is enabled (the default setting for
                Qwen2-VL open-source models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
                (Use input_metadata.mrope_positions to replace it)
        """
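        # Illustrative note (added for clarity, not in the upstream file): with
        # mrope, each token carries a (temporal, height, width) position triple.
        # For a purely textual prefix of length 4 the three rows coincide:
        #     positions = [[0, 1, 2, 3],
        #                  [0, 1, 2, 3],
        #                  [0, 1, 2, 3]]
        # The rows only diverge inside image/video token spans.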
        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
            positions = forward_batch.mrope_positions

        if not (
            forward_batch.forward_mode.is_decode()
            or not forward_batch.contains_image_inputs()
        ):
            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

        inputs_embeds = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            embed_tokens=self.get_input_embeddings(),
            mm_data_embedding_func=self.get_image_feature,
        )

        hidden_states = self.model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
        )

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name and "qkv.weight" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.embed_dim
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(
                        3, visual_num_heads, head_size, visual_embed_dim
                    )
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
                elif "visual" in name and "qkv.bias" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.embed_dim
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)

                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                except KeyError:
                    print(params_dict.keys())
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = Qwen2VLForConditionalGeneration