658 lines
24 KiB
Python
658 lines
24 KiB
Python
# coding=utf-8
|
|
# Adapted from
|
|
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
|
# Copyright 2024 The Qwen team.
|
|
# Copyright 2023 The vLLM team.
|
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
|
import logging
|
|
from functools import lru_cache, partial
|
|
from typing import Iterable, List, Optional, Tuple, Type
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
from transformers import AutoModel, Qwen2VLConfig
|
|
from transformers.activations import ACT2FN
|
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
|
Qwen2_5_VLConfig,
|
|
Qwen2_5_VLVisionConfig,
|
|
)
|
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
Qwen2_5_VLForConditionalGeneration,
|
|
)
|
|
|
|
from sglang.srt.distributed import (
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
)
|
|
from sglang.srt.hf_transformers_utils import get_processor
|
|
from sglang.srt.layers.attention.vision import VisionAttention
|
|
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
|
from sglang.srt.managers.mm_utils import (
|
|
MultiModalityDataPaddingPatternTokenPairs,
|
|
general_mm_embed_routine,
|
|
)
|
|
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
from sglang.srt.models.qwen2 import Qwen2Model
|
|
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
|
|
from sglang.srt.utils import add_prefix
|
|
|
|
# Module-level logger for this model file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Qwen2_5_VLMLP(nn.Module):
    """Gated feed-forward network used inside each vision block.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))``. The gate and up
    projections are ``ColumnParallelLinear`` layers and the down projection is
    a ``RowParallelLinear`` layer.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        bias: bool = True,
        hidden_act="silu",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        def _make_column(proj_name: str) -> ColumnParallelLinear:
            # gate_proj and up_proj share every argument except their name.
            return ColumnParallelLinear(
                in_features,
                hidden_features,
                bias=bias,
                quant_config=quant_config,
                prefix=add_prefix(proj_name, prefix),
            )

        self.gate_proj = _make_column("gate_proj")
        self.up_proj = _make_column("up_proj")
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); only output is used.
        gate, _ = self.gate_proj(x)
        up, _ = self.up_proj(x)
        out, _ = self.down_proj(self.act(gate) * up)
        return out
|
|
|
|
|
|
class Qwen2_5_VisionBlock(nn.Module):
    """Pre-norm transformer layer of the Qwen2.5-VL vision encoder.

    Computes ``x = x + attn(norm1(x))`` followed by ``x = x + mlp(norm2(x))``.
    Both norms are ``Qwen2RMSNorm`` with eps 1e-6.
    """

    # VisionAttention options for each supported ``attn_implementation``.
    _ATTN_IMPL_OPTIONS = {
        "sdpa": dict(
            use_context_forward=False,
            softmax_in_single_precision=False,
            flatten_batch=True,
        ),
        "flash_attention_2": dict(
            use_context_forward=True,
            softmax_in_single_precision=False,
            flatten_batch=True,
        ),
        "eager": dict(
            use_context_forward=False,
            softmax_in_single_precision=True,
            flatten_batch=True,
        ),
    }

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        num_heads: int,
        hidden_act="silu",
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the block.

        Args:
            dim: hidden size of the block input/output.
            intermediate_dim: hidden size of the gated MLP.
            num_heads: number of attention heads.
            hidden_act: activation name looked up in ``ACT2FN`` by the MLP.
            norm_layer: currently unused; both norms are always
                ``Qwen2RMSNorm``.
            attn_implementation: one of ``"sdpa"``, ``"flash_attention_2"``
                or ``"eager"``.
            quant_config: forwarded to the attention and MLP layers.
            prefix: parameter-name prefix for the sub-layers.

        Raises:
            ValueError: if ``attn_implementation`` is not supported.
        """
        super().__init__()
        attn_options = self._ATTN_IMPL_OPTIONS.get(attn_implementation)
        if attn_options is None:
            # Previously an unknown value fell through the if/elif chain and
            # surfaced as an UnboundLocalError; fail with a clear message.
            raise ValueError(
                f"Unsupported attn_implementation {attn_implementation!r}; "
                f"expected one of {sorted(self._ATTN_IMPL_OPTIONS)}"
            )

        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=False,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
            **attn_options,
        )
        self.mlp = Qwen2_5_VLMLP(
            dim,
            intermediate_dim,
            hidden_act=hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply attention and MLP with residual connections.

        Args:
            x: sequence-first activations; assumed ``(seq_len, batch, dim)``.
                They are rearranged to batch-first for attention and back.
            cu_seqlens: cumulative segment boundaries passed to the attention.
            position_embeddings: rotary ``(cos, sin)`` pair passed to the
                attention.

        Returns:
            Tensor with the same layout as ``x``.
        """
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        x = x + self.mlp(self.norm2(x))
        return x
|
|
|
|
|
|
class Qwen2_5_VisionPatchEmbed(nn.Module):
    """Project flattened spatio-temporal patches to ``embed_dim`` with a Conv3d.

    Each input row holds one patch of ``in_chans * temporal_patch_size *
    patch_size * patch_size`` values. The Conv3d kernel equals its stride, so
    every patch is projected independently.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_chans: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        patch_kernel = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=patch_kernel,
            stride=patch_kernel,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(num_patches, patch_dim)`` input to ``(num_patches, embed_dim)``."""
        # Two-name unpacking also rejects inputs that are not 2-D.
        num_patches, _ = x.shape
        patches = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        embedded = self.proj(patches.to(dtype=self.proj.weight.dtype))
        return embedded.view(num_patches, self.embed_dim)
|
|
|
|
|
|
class Qwen2_5_VisionPatchMerger(nn.Module):
    """Fuse groups of ``spatial_merge_size**2`` patch tokens into one token.

    The ``context_dim`` features are RMS-normalised. Each run of
    ``spatial_merge_size**2`` consecutive tokens is concatenated into one
    ``hidden_size`` vector, then projected Linear -> GELU -> Linear to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        merged_dim = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_dim
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)

        fc1 = ColumnParallelLinear(
            merged_dim,
            merged_dim,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("mlp.0", prefix),
        )
        act = nn.GELU()
        fc2 = RowParallelLinear(
            merged_dim,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("mlp.2", prefix),
        )
        # The ModuleList indices 0/1/2 keep the submodule names (mlp.0, mlp.2)
        # consistent with the prefixes passed above.
        self.mlp = nn.ModuleList([fc1, act, fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fc1, act, fc2 = self.mlp
        merged = self.ln_q(x).view(-1, self.hidden_size)
        hidden, _ = fc1(merged)
        out, _ = fc2(act(hidden))
        return out
|
|
|
|
|
|
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
    """Rotary-angle table for the vision tower.

    ``forward(seqlen)`` returns a ``(seqlen, ceil(dim / 2))`` matrix whose
    entry ``[p, i]`` is ``p * theta ** (-2 * i / dim)``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # Non-persistent: recomputed at construction, never read from a checkpoint.
        self.register_buffer("inv_freq", 1.0 / (theta**exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)
|
|
|
|
|
|
class Qwen2_5_VisionTransformer(nn.Module):
    """Qwen2.5-VL vision encoder.

    The pipeline is: patch embedding, then a stack of vision blocks, then the
    patch merger. Blocks whose index is in
    ``vision_config.fullatt_block_indexes`` attend over each whole frame. All
    other blocks attend only within local windows (see ``get_window_index``).
    """

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        self.spatial_merge_size = spatial_merge_size
        # Number of patch tokens fused into one output token by the merger.
        self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
        in_chans: int = vision_config.in_channels
        hidden_size: int = vision_config.hidden_size
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        mlp_hidden_size: int = vision_config.intermediate_size
        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_chans=in_chans,
            embed_dim=hidden_size,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = hidden_size // num_heads
        # Row and column positions each get half of this table. forward()
        # concatenates the (row, col) angles and then duplicates them.
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Qwen2_5_VisionBlock(
                    dim=hidden_size,
                    intermediate_dim=mlp_hidden_size,
                    num_heads=num_heads,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{i}", prefix),
                )
                for i in range(depth)
            ]
        )
        self.merger = Qwen2_5_VisionPatchMerger(
            dim=vision_config.out_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )

    def get_window_index(self, grid_thw):
        """Compute the window-grouping permutation used by windowed blocks.

        Args:
            grid_thw: integer tensor with one ``(t, h, w)`` patch-grid row per
                image/video.

        Returns:
            window_index: 1-D tensor over merged-token positions (one entry per
                ``spatial_merge_unit`` patches). It reorders tokens so that each
                window's tokens are contiguous, frame by frame and item by item.
            cu_window_seqlens: Python list of cumulative window lengths in patch
                units, starting at 0. It may contain repeated values for windows
                that consist only of padding.
        """
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # Window edge length measured in merged tokens rather than pixels.
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )
        window_index: list = []
        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            # Pad the merged grid up to whole windows; -100 marks padding.
            # NOTE: a size that is already a multiple of the window still gets
            # one extra window of pure padding, which yields a zero-length
            # window below.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Count real tokens per window, then drop the padding entries.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Window boundaries in patch units: each merged token stands for
            # spatial_merge_unit patches.
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    @property
    def dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.gate_proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.blocks[0].mlp.gate_proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Return rotary angles for every patch, in patch (pre-window) order.

        Each patch gets its (row, col) position inside its frame. Positions are
        enumerated so that each ``spatial_merge_size x spatial_merge_size``
        block is contiguous, and the per-frame pattern is repeated ``t`` times.
        """
        pos_ids = []
        for i in range(grid_thw.size(0)):
            t, h, w = grid_thw[i].tolist()
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()

            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patches into merged vision tokens.

        Args:
            x: ``(num_patches, patch_dim)`` flattened patches for all items.
            grid_thw: ``(num_items, 3)`` integer tensor of ``(t, h, w)``
                patch-grid sizes.

        Returns:
            ``(num_patches // spatial_merge_unit, out_hidden_size)`` tensor in
            the original (un-windowed) token order.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=torch.int32,
        )
        # Collapse the repeated boundaries produced by all-padding windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = x.size()

        # Reorder patches (in groups of spatial_merge_unit) into window order.
        x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # compute cu_seqlens for the full-attention blocks: one segment of
        # h * w patches per frame (t frames per item), starting with a single
        # leading 0. The window permutation keeps each frame's tokens
        # contiguous and in frame order, so these boundaries still hold after
        # reordering. int32 matches cu_window_seqlens.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(
                x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
            )

        # adapter
        x = self.merger(x)

        # Undo the window permutation so output tokens follow the input order.
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]

        return x
|
|
|
|
|
|
# Memoized wrapper around get_processor. functools.lru_cache applied directly
# to a function uses its default maxsize of 128; all arguments must be hashable.
cached_get_processor = lru_cache(get_processor)
|
|
|
|
|
|
class Qwen2_5_VLForConditionalGeneration(nn.Module):
    """Qwen2.5-VL: a Qwen2.5 vision encoder in front of a Qwen2 language model.

    During prefill with image inputs, image features from ``self.visual`` are
    merged into the token embeddings by ``general_mm_embed_routine``. The
    language model then produces logits, or pooled embeddings when
    ``get_embedding=True``.
    """

    def __init__(
        self,
        config: Qwen2_5_VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.visual = Qwen2_5_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen2-VL vision encoder does not support any
            # quantization method now.
            quant_config=None,
            prefix=add_prefix("visual", prefix),
        )

        self.model = Qwen2Model(
            config,
            quant_config,
            prefix=add_prefix("model", prefix),
        )

        if config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        """Pad the image spans delimited by ``im_start_id``/``im_end_id``."""
        # Get all special token IDs
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        return pattern.pad_input_tokens(input_ids, image_inputs)

    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
        """Encode image pixel values with the vision tower."""
        pixel_values = image_input.pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
        return image_embeds

    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        """Encode video frames with the vision tower (not called in this class)."""
        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
        )
        return video_embeds

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def _uses_mrope(self) -> bool:
        """Return True when ``config.rope_scaling`` requests mrope.

        ``rope_scaling`` may be absent or explicitly ``None``; both mean no
        mrope. The previous ``getattr(..., {}).get`` crashed on ``None``.
        """
        rope_scaling = getattr(self.config, "rope_scaling", None) or {}
        return rope_scaling.get("type", None) == "mrope"

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen2_5-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch, of shape ``(seq_len,)``.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), this argument is replaced by
                ``forward_batch.mrope_positions``, whose shape is
                ``(3, seq_len)``.
            forward_batch: batch metadata for this step.
            get_embedding: if True, return pooled embeddings (last token,
                normalized) instead of logits.
        """
        is_mrope = self._uses_mrope()
        if is_mrope:
            positions = forward_batch.mrope_positions

        if (
            not forward_batch.forward_mode.is_decode()
            and forward_batch.contains_image_inputs()
        ):
            if is_mrope:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

        inputs_embeds = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            embed_tokens=self.get_input_embeddings(),
            mm_data_embedding_func=self.get_image_feature,
        )

        hidden_states = self.model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
        )

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def _interleave_visual_qkv(
        self, loaded_weight: torch.Tensor, is_bias: bool
    ) -> torch.Tensor:
        """Reorder a fused vision qkv tensor from [q | k | v] to per-head groups.

        The input is viewed as ``(3, heads, head_size, ...)`` and rearranged
        to ``(heads, 3, head_size, ...)`` before flattening back. This is
        presumably the layout VisionAttention's ``qkv_proj`` expects
        (TODO confirm).
        """
        vision_config = self.config.vision_config
        num_heads = vision_config.num_heads
        embed_dim = vision_config.hidden_size
        head_size = embed_dim // num_heads
        # Weights carry a trailing input-feature axis; biases do not.
        trailing = () if is_bias else (embed_dim,)
        loaded_weight = loaded_weight.view(3, num_heads, head_size, *trailing)
        loaded_weight = loaded_weight.transpose(0, 1)
        return loaded_weight.reshape(-1, *trailing)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Language-model q/k/v and gate/up projections are loaded as shards of
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters. Vision-tower
        weights are loaded unfused. The exception is the vision ``attn.qkv``
        tensors, which are reordered and renamed to ``attn.qkv_proj``.

        Raises:
            KeyError: if a non-bias checkpoint tensor has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Vision weights are never fused (the vision MLP keeps separate
                # gate_proj / up_proj parameters).
                if "visual" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name and "qkv.weight" in name:
                    loaded_weight = self._interleave_visual_qkv(
                        loaded_weight, is_bias=False
                    )
                elif "visual" in name and "qkv.bias" in name:
                    loaded_weight = self._interleave_visual_qkv(
                        loaded_weight, is_bias=True
                    )

                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict.get(name)
                if param is None:
                    logger.error(
                        "No parameter named %r in %s; available parameters: %s",
                        name,
                        type(self).__name__,
                        sorted(params_dict),
                    )
                    raise KeyError(name)

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
|
|
|
|
|
# Model class(es) exported by this module; presumably discovered by sglang's
# model registry via this name (TODO confirm).
EntryClass = [Qwen2_5_VLForConditionalGeneration]
|