# Source: sglang0.4.5.post1/python/sglang/srt/models/qwen2_vl.py
# (617 lines, 22 KiB, Python)

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import Qwen2VLConfig
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.utils import add_prefix
# Module-level logger named after this module (used for diagnostics in this file).
logger = logging.getLogger(__name__)
# === Vision Inputs === #
class Qwen2VLImageInputs(TypedDict):
    """Image inputs handed to the Qwen2-VL vision tower.

    Keys:
        pixel_values: flattened pixel patches, one row per patch. The row
            width must be ``in_chans * temporal_patch_size * patch_size *
            patch_size`` for ``Qwen2VisionPatchEmbed`` to reshape it.
        image_grid_thw: tensor of shape ``(num_images, 3)`` holding each
            image's patch grid as ``(grid_t, grid_h, grid_w)``.
    """

    # (num_patches, in_chans * temporal_patch_size * patch_size * patch_size)
    pixel_values: torch.Tensor
    # (num_images, 3) in (grid_t, grid_h, grid_w) order
    image_grid_thw: torch.Tensor
class Qwen2VLVideoInputs(TypedDict):
    """Video inputs handed to the Qwen2-VL vision tower.

    Keys:
        pixel_values_videos: flattened spatio-temporal patches, shape
            ``(num_patches, num_channels * temporal_patch_size * patch_size
            * patch_size)``.
        video_grid_thw: tensor of shape ``(num_videos, 3)`` holding each
            video's patch grid as ``(grid_t, grid_h, grid_w)``.
    """

    # (num_patches, num_channels * temporal_patch_size * patch_size * patch_size)
    pixel_values_videos: torch.Tensor
    # (num_videos, 3) in (grid_t, grid_h, grid_w) order
    video_grid_thw: torch.Tensor
# === Vision Encoder === #
class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: ``fc2(act(fc1(x)))``.

    ``fc1`` is column-parallel and ``fc2`` row-parallel, so the hidden
    activation between them stays sharded and only ``fc2`` produces the
    full-width output.

    Args:
        in_features: input (and output) width.
        hidden_features: width of the intermediate activation.
        act_layer: activation module class, instantiated with no arguments.
        quant_config: optional quantization config forwarded to both linears.
        prefix: parameter-name prefix for this module.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        fc1_prefix = add_prefix("fc1", prefix)
        fc2_prefix = add_prefix("fc2", prefix)
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, quant_config=quant_config, prefix=fc1_prefix
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features, in_features, quant_config=quant_config, prefix=fc2_prefix
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linears return (output, bias); only the output is used.
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out
class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer block: ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``.

    Args:
        dim: embedding width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        act_layer: activation class used inside the MLP.
        norm_layer: norm factory; defaults to ``LayerNorm(eps=1e-6)``.
        attn_implementation: one of ``"sdpa"``, ``"flash_attention_2"`` or
            ``"eager"``; selects the ``VisionAttention`` flags below.
        quant_config: optional quantization config for the sub-layers.
        prefix: parameter-name prefix for this block.

    Raises:
        ValueError: if ``attn_implementation`` is not a supported value.
            (Previously an unknown value left the flags unbound and crashed
            with ``UnboundLocalError``.)
    """

    # attn_implementation -> (use_context_forward, softmax_in_single_precision)
    _ATTN_IMPL_FLAGS = {
        "sdpa": (False, False),
        "flash_attention_2": (True, False),
        "eager": (False, True),
    }

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        try:
            use_context_forward, softmax_in_single_precision = self._ATTN_IMPL_FLAGS[
                attn_implementation
            ]
        except KeyError:
            raise ValueError(
                f"Unsupported attn_implementation {attn_implementation!r}; "
                f"expected one of {sorted(self._ATTN_IMPL_FLAGS)}"
            ) from None
        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=False,
            use_context_forward=use_context_forward,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the block.

        ``x`` is sequence-first ``(s, b, ...)``; the attention module is fed a
        batch-first view, and its output is transposed back before the
        residual add.
        """
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        x = x + self.mlp(self.norm2(x))
        return x
class Qwen2VisionPatchEmbed(nn.Module):
    """Linear patch embedding implemented as a non-overlapping 3-D convolution.

    Each input row is one flattened patch of ``in_chans * temporal_patch_size
    * patch_size * patch_size`` values. It is reshaped into a single voxel
    block and projected to ``embed_dim`` by a Conv3d whose kernel equals its
    stride, i.e. a per-patch linear map with no bias.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_chans: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim
        kernel = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=kernel,
            stride=kernel,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking (rather than indexing) keeps the "input must be 2-D" check.
        num_patches, _ = x.shape
        voxels = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(voxels).view(num_patches, self.embed_dim)
class Qwen2VisionPatchMerger(nn.Module):
    """Merge each group of ``spatial_merge_size**2`` patch embeddings into one token.

    The normalised patch embeddings (width ``context_dim``) are regrouped so
    that ``spatial_merge_size**2`` consecutive rows form one row of width
    ``context_dim * spatial_merge_size**2``. That row is then passed through
    ``Linear -> GELU -> Linear`` to produce a ``d_model``-wide output.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Type[nn.Module] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * spatial_merge_size**2
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.ln_q = make_norm(context_dim)
        fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("mlp.0", prefix),
        )
        act = nn.GELU()
        fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("mlp.2", prefix),
        )
        # Positional ModuleList so parameter names line up with the
        # "mlp.0" / "mlp.2" prefixes passed to the linears above.
        self.mlp = nn.ModuleList([fc1, act, fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fc1, act, fc2 = self.mlp
        grouped = self.ln_q(x).view(-1, self.hidden_size)
        hidden, _ = fc1(grouped)
        out, _ = fc2(act(hidden))
        return out
class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary-embedding angles for the vision tower.

    ``forward(seqlen)`` returns a float tensor of shape
    ``(seqlen, ceil(dim / 2))`` whose row ``p`` equals ``p * inv_freq``.

    ``seqlen`` may be a Python int or a 0-dim integer tensor. It is coerced
    to ``int`` first. Before this fix, a tensor argument was doubled *in place*
    on a cache miss (``seqlen *= 2``). That mutated the caller's tensor and made
    ``forward`` return ``2 * seqlen`` rows.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Make sure the cached table covers at least ``seqlen`` positions.

        On a miss the table is rebuilt for ``2 * seqlen`` positions, so slowly
        growing requests do not trigger a rebuild on every call. ``inv_freq``
        is recomputed in float32 here. Presumably this is because a
        module-wide ``.to(dtype)`` may have downcast the buffer; TODO confirm.
        """
        seqlen = int(seqlen)
        if seqlen <= self._seq_len_cached:
            return
        cache_len = seqlen * 2
        self._seq_len_cached = cache_len
        device = self.inv_freq.device
        self.inv_freq = 1.0 / (
            self.theta
            ** (
                torch.arange(0, self.dim, 2, dtype=torch.float, device=device)
                / self.dim
            )
        )
        seq = torch.arange(cache_len, device=device, dtype=self.inv_freq.dtype)
        self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the angle table for positions ``0 .. seqlen - 1``."""
        seqlen = int(seqlen)
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]
class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision tower: patch embedding, transformer blocks, patch merger.

    Takes flattened pixel patches plus per-item ``(t, h, w)`` patch-grid sizes.
    Returns one ``vision_config.hidden_size``-wide embedding per group of
    ``spatial_merge_size ** 2`` patches.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        in_chans: int = vision_config.in_chans
        hidden_size: int = vision_config.hidden_size
        embed_dim: int = vision_config.embed_dim
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        mlp_ratio: float = vision_config.mlp_ratio
        self.spatial_merge_size = spatial_merge_size
        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # The h and w coordinates each get head_dim // 4 frequencies.
        # rot_pos_emb concatenates them to head_dim // 2, and forward
        # duplicates that to head_dim (for head_dim divisible by 4).
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{i}", prefix),
                )
                for i in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    @staticmethod
    def _window_major(ids: torch.Tensor, merge: int) -> torch.Tensor:
        """Flatten an ``(h, w)`` grid so each ``merge x merge`` window is contiguous."""
        h, w = ids.shape
        return (
            ids.reshape(h // merge, merge, w // merge, merge)
            .permute(0, 2, 1, 3)
            .flatten()
        )

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-patch rotary angles from each patch's (row, col) position.

        Patches are ordered window-major (matching the merger's grouping). The
        same 2-D positions are repeated for every temporal frame.
        """
        merge = self.spatial_merge_size
        pos_ids = []
        for t, h, w in grid_thw.tolist():
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            frame_ids = torch.stack(
                [self._window_major(hpos_ids, merge), self._window_major(wpos_ids, merge)],
                dim=-1,
            )
            pos_ids.append(frame_ids.repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # Pass a Python int rather than a 0-dim tensor, so the rotary cache
        # never stores or mutates a tensor as its length.
        max_grid_size = int(grid_thw[:, 1:].max())
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        return rotary_pos_emb_full[pos_ids].flatten(1)

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened pixel patches.

        Args:
            x: patch features, one row per patch. Moved to this tower's device
                and dtype.
            grid_thw: ``(num_items, 3)`` tensor of ``(t, h, w)`` patch grids.

        Returns:
            Merged embeddings, one row per ``spatial_merge_size ** 2`` patches.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        # Each temporal frame of each item is its own attention segment of
        # h * w patches. cu_seqlens holds the cumulative boundaries, starting at 0.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
        # Blocks operate on a sequence-first layout with batch size 1.
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
        # adapter
        return self.merger(x)
# Memoised processor loader: one processor per model path, at most 128 kept
# (these are lru_cache's defaults, spelled out explicitly).
cached_get_processor = lru_cache(maxsize=128, typed=False)(get_processor)
class Qwen2VLForConditionalGeneration(nn.Module):
    """Qwen2-VL multimodal model: vision tower + Qwen2 language model.

    Image patches are encoded by ``self.visual`` and merged into the token
    embeddings by ``general_mm_embed_routine``. The result is run through
    ``self.model`` and then either the LM head (logits) or the pooler
    (embeddings).
    """

    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
        """Return how many LM tokens an image with the given patch grid uses.

        Args:
            image_grid_thw: ``(grid_t, grid_h, grid_w)`` patch-grid size.

        Returns:
            ``grid_t * grid_h * grid_w`` integer-divided twice by the
            processor's ``merge_size``.
        """
        processor = cached_get_processor(self.config._name_or_path)
        grid_t, grid_h, grid_w = image_grid_thw
        merge_size = processor.image_processor.merge_size
        num_image_tokens = grid_t * grid_h * grid_w // merge_size // merge_size
        return num_image_tokens

    def __init__(
        self,
        config: Qwen2VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.visual = Qwen2VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen2-VL vision encoder does not support any
            # quantization method now.
            quant_config=None,
            prefix=add_prefix("visual", prefix),
        )
        self.model = Qwen2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        if config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def _is_mrope_enabled(self) -> bool:
        """Return True when the config requests multimodal rotary ("mrope") positions.

        ``rope_scaling`` may be missing *or* explicitly ``None``. Both count as
        "not mrope" instead of crashing on ``None.get``.
        """
        rope_scaling = getattr(self.config, "rope_scaling", None) or {}
        return rope_scaling.get("type", None) == "mrope"

    def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
        """Expand image spans in ``input_ids`` into placeholder padding.

        Spans delimited by ``im_start_id`` / ``im_end_id`` are handled by
        ``MultiModalityDataPaddingPatternTokenPairs``. Per the original
        author's note, each image is padded according to its
        ``grid_t * grid_h * grid_w`` size, with pad values derived from a
        unique image hash. That logic lives in the helper and is not visible
        here.
        """
        im_start_id: int = multi_modal_inputs.im_start_id
        im_end_id: int = multi_modal_inputs.im_end_id
        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
        return pattern.pad_input_tokens(input_ids, multi_modal_inputs)

    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
        """Encode the batch's image pixels with the vision tower."""
        pixel_values = image_input.pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
        return image_embeds

    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        """Encode video pixels with the vision tower."""
        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
        )
        return video_embeds

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to
                a batch. **NOTE**: If mrope is enabled (default setting for
                Qwen2-VL opensource models), this argument is replaced by
                ``forward_batch.mrope_positions``. The shape is then expected
                to be ``(3, seq_len)``; otherwise it is ``(seq_len,)``.
            forward_batch: batch metadata.
            get_embedding: if True return pooled embeddings instead of logits.
        """
        is_mrope = self._is_mrope_enabled()
        if is_mrope:
            positions = forward_batch.mrope_positions

        if not forward_batch.forward_mode.is_decode() and forward_batch.contains_image_inputs():
            if is_mrope:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

        inputs_embeds = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            embed_tokens=self.get_input_embeddings(),
            mm_data_embedding_func=self.get_image_feature,
        )

        hidden_states = self.model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
        )

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def _reorder_visual_qkv(self, name: str, loaded_weight: torch.Tensor) -> torch.Tensor:
        """Regroup a fused vision ``qkv`` weight/bias from qkv-major to head-major.

        The checkpoint tensor is viewed as ``(3, num_heads, head_size, ...)`` and
        returned as ``(num_heads, 3, head_size, ...)`` flattened back to its
        original rank. This is the layout the ``qkv_proj`` parameter is loaded
        with. Any other tensor is returned unchanged.
        """
        if "visual" not in name:
            return loaded_weight
        vision_config = self.config.vision_config
        num_heads = vision_config.num_heads
        embed_dim = vision_config.embed_dim
        head_size = embed_dim // num_heads
        if "qkv.weight" in name:
            return (
                loaded_weight.view(3, num_heads, head_size, embed_dim)
                .transpose(0, 1)
                .reshape(-1, embed_dim)
            )
        if "qkv.bias" in name:
            return (
                loaded_weight.view(3, num_heads, head_size)
                .transpose(0, 1)
                .reshape(-1)
            )
        return loaded_weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model's parameters.

        Separate q/k/v and gate/up projections are routed into their fused
        parameters with a shard id. Vision ``qkv`` tensors are re-laid-out and
        renamed to ``qkv_proj``. Extra GPTQ biases are skipped.

        Raises:
            KeyError: if a weight has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                loaded_weight = self._reorder_visual_qkv(name, loaded_weight)
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                try:
                    param = params_dict[name]
                except KeyError:
                    logger.error(
                        "Checkpoint weight %r has no matching parameter. "
                        "Available parameters: %s",
                        name,
                        list(params_dict.keys()),
                    )
                    raise
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
# Model class exported under the name ``EntryClass``; presumably looked up by
# sglang's model registry when loading this module — confirm against the loader.
EntryClass = Qwen2VLForConditionalGeneration