# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The SGLang team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
from functools import partial
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

import numpy as np
import torch
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig

from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from sglang.srt.utils import add_prefix

RawImageType = Union[Image.Image, torch.Tensor]


# sin/cos positional embedding helpers are adapted from:
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0)
) -> torch.Tensor:
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) / (H, W)
    out: (M, D) / (H, W, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    if version == (2, 0):
        pos = pos.reshape(-1)  # (M,)
        out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
        emb_sin = np.sin(out)  # (M, D/2)
        emb_cos = np.cos(out)  # (M, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    else:
        out = np.einsum("hw,d->hwd", pos, omega)  # (H, W, D/2), outer product
        emb_sin = np.sin(out)  # (H, W, D/2)
        emb_cos = np.cos(out)  # (H, W, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
    return emb


def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0)
) -> torch.Tensor:
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0], version
    )  # (H*W, D/2) or (H, W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1], version
    )  # (H*W, D/2) or (H, W, D/2)

    if version == (2, 0):
        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    else:
        emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
    return emb


def get_2d_sincos_pos_embed(
    embed_dim: int,
    grid_size: Union[int, Tuple[int, int]],
    cls_token: bool = False,
    version: Tuple[int, int] = (2, 0),
) -> torch.Tensor:
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or
        [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size[0], grid_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size)

    if version == (2, 0):
        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
        if cls_token:
            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    else:
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
    return pos_embed


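# Illustrative example (a hedged sketch, shown as comments only): for the
# MiniCPM-V 2.5/2.6 layout, `get_2d_sincos_pos_embed` keeps the (H, W) grid
# structure, while the 2.0 layout flattens it:
#
#     >>> get_2d_sincos_pos_embed(128, (4, 6), version=(2, 5)).shape
#     (4, 6, 128)
#     >>> get_2d_sincos_pos_embed(128, (4, 6), version=(2, 0)).shape
#     (24, 128)
#
# Note that these helpers return NumPy arrays despite the annotation; callers
# such as `Resampler2_5._set_2d_pos_cache` convert them with `torch.from_numpy`.

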
class Idefics2VisionMLP(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class Idefics2EncoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attn = VisionAttention(
            embed_dim=config.hidden_size,
            num_heads=self.num_heads,
            projection_size=config.intermediate_size,
            use_qkv_parallel=True,
            quant_config=quant_config,
            dropout=config.attention_dropout,
            use_context_forward=False,
            softmax_in_single_precision=True,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Idefics2VisionMLP(
            config,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens)

        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Idefics2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention
    layers, each an [`Idefics2EncoderLayer`].

    Args:
        config: Idefics2Config
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.layers = nn.ModuleList(
            [
                Idefics2EncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{i}", prefix),
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds (torch.Tensor):
                Optionally, instead of passing `input_ids`, you can choose to
                directly pass an embedded representation. This is useful if you
                want more control over how to convert `input_ids` indices into
                associated vectors than the model's internal embedding lookup
                matrix.
        """
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
            )
            hidden_states = layer_outputs
        return hidden_states


class Idefics2VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings`
    that enables images of variable resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision
    Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304),
    which allows treating images in their native aspect ratio without the
    need to resize them to the same fixed size. In particular, we start from the
    original pre-trained SigLIP model (which uses fixed-size square images) and
    adapt it by training on images of variable resolutions.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def get_position_ids(
        self,
        pixel_values: torch.FloatTensor,
        patch_attention_mask: torch.BoolTensor,
        tgt_sizes: Optional[torch.IntTensor] = None,
    ):
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        max_nb_patches_h, max_nb_patches_w = (
            max_im_h // self.patch_size,
            max_im_w // self.patch_size,
        )
        boundaries = torch.arange(
            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
        )
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
        )

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            if tgt_sizes is not None:
                nb_patches_h = tgt_sizes[batch_idx][0]
                nb_patches_w = tgt_sizes[batch_idx][1]
            else:
                nb_patches_h = p_attn_mask[:, 0].sum()
                nb_patches_w = p_attn_mask[0].sum()
            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
            bucket_coords_h = torch.bucketize(
                fractional_coords_h, boundaries, right=True
            )
            bucket_coords_w = torch.bucketize(
                fractional_coords_w, boundaries, right=True
            )
            pos_ids = (
                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
            ).flatten()
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
        position_ids = position_ids.to(self.position_embedding.weight.device)
        return position_ids

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        patch_attention_mask: torch.BoolTensor,
        tgt_sizes: Optional[torch.IntTensor] = None,
    ) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        pixel_values = pixel_values.to(
            device=self.patch_embedding.weight.device, dtype=target_dtype
        )
        patch_embeds = self.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)
        position_ids = self.get_position_ids(
            pixel_values, patch_attention_mask, tgt_sizes
        )

        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings


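# Worked example for the position-id bucketing above (a hedged sketch with toy
# numbers, not taken from a real config): suppose num_patches_per_side == 4, so
# `boundaries == [0.25, 0.5, 0.75]`. For an image that yields a 2 x 2 patch
# grid, `fractional_coords_h == fractional_coords_w == [0.0, 0.5]`, and
# `torch.bucketize(..., right=True)` maps them to bucket indices [0, 2]. The
# resulting `pos_ids` are [0, 2, 8, 10], i.e. the small grid is spread evenly
# over the 4 x 4 position-embedding table instead of occupying rows 0-1 only.

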
class Idefics2VisionTransformer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        embed_dim = config.hidden_size
        self.config = config
        self.embeddings = Idefics2VisionEmbeddings(config)
        self.encoder = Idefics2Encoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("encoder", prefix),
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings

    def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # shape: (batch_size,)
        cu_seqlens = torch.cat(
            [
                torch.tensor([0], device=patch_len.device, dtype=torch.int32),
                torch.cumsum(patch_len, dim=0, dtype=torch.int32),
            ],
            dim=0,
        ).to(tgt_sizes.device)
        return cu_seqlens

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        tgt_sizes: Optional[torch.IntTensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            tgt_sizes=tgt_sizes,
        )
        cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
        encoder_outputs = self.encoder(
            hidden_states,
            cu_seqlens=cu_seqlens,
        )
        last_hidden_state = self.post_layernorm(encoder_outputs)
        return last_hidden_state


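# Illustrative example for `compute_cu_seqlens` (a hedged sketch): for
# tgt_sizes == [[2, 3], [4, 5]] the per-image patch counts are [6, 20], so the
# cumulative sequence lengths passed to the vision attention are [0, 6, 26],
# following the usual "cu_seqlens" convention where image i occupies the token
# range [cu_seqlens[i], cu_seqlens[i + 1]).

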
class MiniCPMVImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: List[torch.Tensor]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that the image size may vary, so we pass it as a list
    instead of a batched tensor.
    """

    image_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(start, stop)` format.
    """

    tgt_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class MiniCPMVImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """

    image_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(start, stop)` format.
    """


MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs]

DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


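# Illustrative example (a hedged sketch; the numbers are made up): a batch with
# a single image occupying input token positions 5..69 and an 8 x 8 patch grid
# could be represented as
#
#     MiniCPMVImagePixelInputs(
#         type="pixel_values",
#         data=[pixel_tensor],                   # one tensor per image
#         image_bounds=torch.tensor([[5, 69]]),  # (start, stop) token positions
#         tgt_sizes=torch.tensor([[8, 8]]),      # (height, width) in patches
#     )
#
# `MiniCPMVBaseModel._parse_and_validate_inputs` below assembles this structure
# from the keyword arguments of the multimodal forward pass.

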
class BaseResampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross-attention layer over
    (grid_size**2) learnable queries and a 2D sincos positional embedding.
    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        do_post_projection: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=0.02)
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = ReplicatedLinear(
                kv_dim,
                embed_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_proj", prefix),
            )
        else:
            # Maintain the same return value as ReplicatedLinear.forward
            self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
                nn.Identity()(*args, **kwargs),
                None,
            )
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)
        self.do_post_projection = do_post_projection
        self.ln_post = norm_layer(embed_dim) if do_post_projection else None
        self.proj = (
            nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
            if do_post_projection
            else None
        )

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)


class Resampler2_5(BaseResampler):

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        max_size: Tuple[int, int] = (70, 70),
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            num_queries,
            embed_dim,
            num_heads,
            kv_dim,
            norm_layer,
            quant_config=quant_config,
            prefix=prefix,
        )

        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)

        self.apply(self._init_weights)

    def _set_2d_pos_cache(
        self, max_size: Tuple[int, int], device: torch.types.Device = "cpu"
    ) -> None:
        pos_embed_arr = get_2d_sincos_pos_embed(
            self.embed_dim, max_size, version=(2, 5)
        )
        pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(
        self, tgt_sizes: torch.Tensor, device: torch.types.Device
    ) -> None:
        max_h = tgt_sizes[:, 0].max().item()
        max_w = tgt_sizes[:, 1].max().item()
        assert isinstance(max_h, int) and isinstance(max_w, int)

        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = (
                max(max_h, self.max_size[0]),
                max(max_w, self.max_size[1]),
            )
            self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        max_patch_len = patch_len.max().item()
        assert isinstance(max_patch_len, int)

        key_padding_mask = torch.zeros(
            (bs, max_patch_len), dtype=torch.bool, device=device
        )

        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i].tolist()
            pos_embed.append(
                self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)
            )  # patches * D
            key_padding_mask[i, patch_len[i] :] = True
        pos_embed = torch.nn.utils.rnn.pad_sequence(
            pos_embed, batch_first=True, padding_value=0.0
        ).permute(
            1, 0, 2
        )  # BLD => L * B * D
        x, _ = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D + L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[0]
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x


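# Shape summary for `Resampler2_5.forward` (a hedged sketch): given vision
# features `x` of shape (B, L, kv_dim) and `tgt_sizes` of shape (B, 2), the
# resampler cross-attends `num_queries` learnable queries against the padded
# patch sequence and returns a tensor of shape (B, num_queries, embed_dim),
# e.g. 64 query tokens per image for MiniCPM-V 2.6 (config.query_num).

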
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
    version_float = getattr(config, "version", None)

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    if version_float is None:
        if config.hidden_size == 2304 and config.query_num == 64:
            return 2, 0
        return 2, 5

    version_str = str(version_float)
    return tuple(int(x) for x in version_str.split("."))


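# Illustrative example (a hedged sketch): a config with `version == 2.6` (or the
# string "2.6") resolves to (2, 6); an old config without a `version` field
# falls back to the heuristic above, so hidden_size == 2304 with query_num == 64
# resolves to (2, 0) and anything else to (2, 5).

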
class MiniCPMVBaseModel(nn.Module):
    """
    The abstract base class of MiniCPMV. It can only be inherited, not
    instantiated.
    """

    def __init__(
        self,
        *,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # All MiniCPM-V models disable `tie_word_embeddings`, but
        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
        # check `tie_word_embeddings` until SGLang integrates the MiniCPM-V
        # model and config class.
        self.config = config

        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(
            config=config, quant_config=quant_config, prefix=add_prefix("llm", prefix)
        )
        self.vpm = self.init_vision_module(
            config, quant_config, add_prefix("vpm", prefix)
        )
        self.vision_dim = (
            self.vpm.embed_dim
            if self.version == (2, 0)
            else self.vpm.embeddings.embed_dim
        )
        self.embed_dim = self.config.hidden_size

        self.resampler = self.init_resampler(
            self.embed_dim,
            self.vision_dim,
            quant_config=quant_config,
            prefix=add_prefix("resampler", prefix),
        )

        self.logits_processor = LogitsProcessor(config)

    def _get_image_bounds(
        self,
        input_ids: torch.Tensor,
        pad_values: List[int],
        im_start_id: int,
        im_end_id: int,
        slice_start_id: Optional[int] = None,
        slice_end_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Returns a tensor indicating the bounds (start and end token positions)
        of the images.
        """
        # All the images in the batch should share the same special image
        # bound token ids.
        start_cond = input_ids == im_start_id
        end_cond = input_ids == im_end_id
        if slice_start_id is not None:
            start_cond |= input_ids == slice_start_id
            end_cond |= input_ids == slice_end_id

        (image_start_tokens,) = torch.where(start_cond)
        image_start_tokens += 1
        (image_end_tokens,) = torch.where(end_cond)

        # The im_start_id can sometimes be cached as part of the prefix, but it
        # is still needed to locate the image embeddings.
        if len(image_start_tokens) != len(image_end_tokens):
            if (
                len(image_start_tokens) + 1 == len(image_end_tokens)
                and input_ids[0] in pad_values
                and len(image_start_tokens) != 0
                and len(image_end_tokens) != 0
                and image_end_tokens[0] < image_start_tokens[0]
            ):
                image_start_tokens = torch.cat(
                    [
                        torch.tensor([0], device=image_start_tokens.device),
                        image_start_tokens,
                    ]
                )
        valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))

        if valid_image_nums == 0:
            return torch.zeros((0, 2), device=input_ids.device)

        # Filter out pairs where start_token >= end_token
        valid_pairs = []
        for i in range(valid_image_nums):
            start_token = image_start_tokens[i]
            end_token = image_end_tokens[i]
            if start_token < end_token:
                valid_pairs.append((start_token, end_token))

        if not valid_pairs:
            return torch.zeros((0, 2), device=input_ids.device)

        # Convert valid pairs to tensor
        valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
        return valid_pairs_tensor

    def _parse_and_validate_inputs(
        self,
        input_ids: torch.Tensor,
        **kwargs: object,
    ) -> Optional[MiniCPMVImageInputs]:
        pixel_values = kwargs.pop("pixel_values", [])
        tgt_sizes = kwargs.pop("tgt_sizes", [])
        im_start_id = kwargs.pop("im_start_id", None)
        im_end_id = kwargs.pop("im_end_id", None)
        slice_start_id = kwargs.pop("slice_start_id", None)
        slice_end_id = kwargs.pop("slice_end_id", None)
        image_embeds = kwargs.pop("image_embeds", None)
        pad_values = kwargs.pop("pad_values", None)

        if image_embeds is not None:
            image_bounds = self._get_image_bounds(
                input_ids=input_ids,
                pad_values=pad_values,
                im_start_id=im_start_id,
                im_end_id=im_end_id,
                slice_start_id=slice_start_id,
                slice_end_id=slice_end_id,
            )
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of image embeds. "
                    f"Got type: {type(image_embeds)}"
                )

            if isinstance(image_embeds, list):
                image_embeds = torch.cat(image_embeds)

            return MiniCPMVImageEmbeddingInputs(
                image_bounds=image_bounds,
                data=image_embeds,
                type="image_embeds",
            )

        image_bounds = self._get_image_bounds(
            input_ids=input_ids,
            pad_values=pad_values,
            im_start_id=im_start_id,
            im_end_id=im_end_id,
            slice_start_id=slice_start_id,
            slice_end_id=slice_end_id,
        )
        return MiniCPMVImagePixelInputs(
            image_bounds=image_bounds.to(device=input_ids.device),
            data=pixel_values,
            tgt_sizes=tgt_sizes,
            type="pixel_values",
        )

    def get_embedding(
        self,
        input_ids: torch.Tensor,
        image_inputs: Optional[MiniCPMVImageInputs],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)

        if image_inputs is None:  # No image
            vision_hidden_states = torch.tensor([], device=input_ids.device)
        else:
            if image_inputs["type"] == "image_embeds":
                vision_hidden_states = (
                    image_inputs["data"]
                    .type(vlm_embedding.dtype)
                    .to(vlm_embedding.device)
                )
            else:
                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
            # See NOTE in _parse_and_validate_inputs
            image_bounds = image_inputs["image_bounds"]
            if len(image_bounds) > 0:
                image_indices = torch.stack(
                    [
                        torch.arange(start, end, dtype=torch.long)
                        for start, end in image_bounds.tolist()
                    ]
                ).to(vlm_embedding.device)

                vlm_embedding.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
                )

        return vlm_embedding, vision_hidden_states

    def get_input_embeddings(self) -> nn.Embedding:
        return self.llm.get_input_embedding()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: Any,
    ) -> torch.Tensor:
        inputs_embeds = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            embed_tokens=self.get_input_embeddings(),
            mm_data_embedding_func=self.get_image_features,
        )

        hidden_states = self.llm.model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.llm.lm_head, forward_batch
        )

    def init_llm(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        raise NotImplementedError

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        raise NotImplementedError

    def init_resampler(
        self,
        embed_dim: int,
        vision_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        raise NotImplementedError

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
        raise NotImplementedError


class MiniCPMV2_6(MiniCPMVBaseModel):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision encoder
        "fc1",
        "fc2",
        "out_proj",
        # language model
        "qkv_proj",  # same name as in the vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # resampler
        "kv_proj",
    ]

    # BitsAndBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
        assert self.version == (2, 6)

    def init_llm(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        return Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        model = Idefics2VisionTransformer(
            config=config.vision_config, quant_config=quant_config, prefix=prefix
        )
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        setattr(model, "embed_dim", model.embeddings.embed_dim)
        setattr(model, "patch_size", model.embeddings.patch_size)
        return model

    def init_resampler(
        self,
        embed_dim: int,
        vision_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            # The resampler in 2.6 remains consistent with the one in 2.5.
            resampler = Resampler2_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
                quant_config=quant_config,
                prefix=prefix,
            )

        return resampler.to(device="cuda", dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        vision_embedding = self.vpm(
            pixel_values,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return vision_embedding

    def get_image_features(
        self,
        image_inputs: MultimodalInputs,
    ) -> torch.Tensor:
        # list of tensors
        pixel_values = image_inputs.pixel_values

        tgt_sizes = image_inputs.tgt_sizes

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0
        )
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
        patch_attn_mask = torch.zeros(
            (B, 1, max_patches), dtype=torch.bool, device=device
        )

        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
        patch_attn_mask[:, 0, :] = torch.arange(
            patch_attn_mask.size(2), device=patch_attn_mask.device
        ).unsqueeze(0) < mask_shapes.unsqueeze(1)

        vision_embedding = self.vpm(
            all_pixel_values.type(dtype),
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return self.resampler(vision_embedding, tgt_sizes)

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        # Get all special token IDs
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id
        slice_start_id: int = image_inputs.slice_start_id
        slice_end_id: int = image_inputs.slice_end_id

        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        return pattern.pad_input_tokens(input_ids, image_inputs)


_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}


class MiniCPMV:
    """
    Different versions of MiniCPMV use different visual encoders and LLMs,
    which is not conducive to the current integration logic of LoRA and
    bitsandbytes in SGLang. Therefore, it is necessary to separate them.
    """

    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    minicpmv: nn.Module

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        if not hasattr(config, "version"):
            version = (2, 6)
        else:
            version = str(config.version).split(".")
            version = tuple([int(x) for x in version])
        # Dispatch class based on version
        instance_class = _SUPPORT_VERSION.get(version)
        if instance_class is None:
            raise ValueError("Currently, MiniCPMV only supports version 2.6")

        try:
            minicpmv = instance_class(
                config=config, quant_config=quant_config, prefix=prefix
            )
            self.minicpmv = minicpmv
        except Exception as e:
            print(f"Failed to instantiate MiniCPMV: {e}")
            raise e
        self.config = config

    def __getattr__(self, name):
        if name == "minicpmv":
            return None
        return getattr(self.minicpmv, name)

    def __call__(self, *args, **kwargs):
        return self.minicpmv(*args, **kwargs)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.minicpmv.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            # adapt to VisionAttention
            name = name.replace(r"self_attn.out_proj", r"self_attn.proj")

            if "sampler" in name:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # replace the name and load with customized loader
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


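# Illustrative example for the stacked-parameter remapping above (a hedged
# sketch; the layer index is arbitrary): a checkpoint entry named
# "llm.model.layers.0.self_attn.q_proj.weight" is renamed to
# "llm.model.layers.0.self_attn.qkv_proj.weight" and loaded into the "q" shard
# of the fused QKV parameter; "gate_proj"/"up_proj" are merged into
# "gate_up_proj" shards 0 and 1 in the same way.

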
EntryClass = MiniCPMV