# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Copied and Adapted from:
# https://github.com/deepseek-ai/Janus

import collections
import math
import os
from dataclasses import field
from enum import Enum
from functools import partial
from itertools import repeat
from typing import (
    Callable,
    Final,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, _assert, nn
from torch.nn.init import trunc_normal_
from transformers import AutoModel, PreTrainedModel

from sglang.srt.configs.janus_pro import *
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.utils import logger

#################################################################################
#                               VQ Model Configs                                #
#################################################################################


# Copied from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/vq_model.py
@dataclass
class ModelArgs:
    codebook_size: int = 16384
    codebook_embed_dim: int = 8
    codebook_l2_norm: bool = True
    codebook_show_usage: bool = True
    commit_loss_beta: float = 0.25
    entropy_loss_ratio: float = 0.0

    encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    z_channels: int = 256
    dropout_p: float = 0.0
|
|
|
|
def named_apply(
|
|
fn: Callable,
|
|
module: nn.Module,
|
|
name="",
|
|
depth_first: bool = True,
|
|
include_root: bool = False,
|
|
) -> nn.Module:
|
|
if not depth_first and include_root:
|
|
fn(module=module, name=name)
|
|
for child_name, child_module in module.named_children():
|
|
child_name = ".".join((name, child_name)) if name else child_name
|
|
named_apply(
|
|
fn=fn,
|
|
module=child_module,
|
|
name=child_name,
|
|
depth_first=depth_first,
|
|
include_root=True,
|
|
)
|
|
if depth_first and include_root:
|
|
fn(module=module, name=name)
|
|
return module
|
|
|
|
|
|
def VQ_16(**kwargs):
|
|
return VQModel(
|
|
ModelArgs(
|
|
encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
|
|
)
|
|
)
|
|
|
|
|
|
VQ_models = {"VQ-16": VQ_16}
|
|
|
|
import collections.abc
|
|
|
|
|
|
# From PyTorch internals
|
|
def _ntuple(n):
|
|
def parse(x):
|
|
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
|
return tuple(x)
|
|
return tuple(repeat(x, n))
|
|
|
|
return parse
|
|
|
|
|
|
def _trunc_normal_(tensor, mean, std, a, b):
|
|
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
|
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
|
def norm_cdf(x):
|
|
# Computes standard normal cumulative distribution function
|
|
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
|
        logger.warning(
|
|
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
|
"The distribution of values may be incorrect.",
|
|
stacklevel=2,
|
|
)
|
|
|
|
# Values are generated by using a truncated uniform distribution and
|
|
# then using the inverse CDF for the normal distribution.
|
|
# Get upper and lower cdf values
|
|
l = norm_cdf((a - mean) / std)
|
|
u = norm_cdf((b - mean) / std)
|
|
|
|
# Uniformly fill tensor with values from [l, u], then translate to
|
|
# [2l-1, 2u-1].
|
|
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
|
|
|
# Use inverse cdf transform for normal distribution to get truncated
|
|
# standard normal
|
|
if tensor.dtype in [torch.float16, torch.bfloat16]:
|
|
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
|
|
og_dtype = tensor.dtype
|
|
tensor = tensor.to(torch.float32)
|
|
tensor.erfinv_()
|
|
tensor = tensor.to(og_dtype)
|
|
else:
|
|
tensor.erfinv_()
|
|
|
|
# Transform to proper mean, std
|
|
tensor.mul_(std * math.sqrt(2.0))
|
|
tensor.add_(mean)
|
|
|
|
# Clamp to ensure it's in the proper range
|
|
if tensor.dtype == torch.float16:
|
|
# The `clamp_` op is not (yet?) defined in float16+cpu
|
|
tensor = tensor.to(torch.float32)
|
|
tensor.clamp_(min=a, max=b)
|
|
else:
|
|
tensor.clamp_(min=a, max=b)
|
|
|
|
|
|
def trunc_normal_tf_(
|
|
tensor: torch.Tensor,
|
|
mean: float = 0.0,
|
|
std: float = 1.0,
|
|
a: float = -2.0,
|
|
b: float = 2.0,
|
|
):
|
|
"""Fills the input Tensor with values drawn from a truncated
|
|
normal distribution. The values are effectively drawn from the
|
|
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
|
|
with values outside :math:`[a, b]` redrawn until they are within
|
|
the bounds. The method used for generating the random values works
|
|
best when :math:`a \\leq \text{mean} \\leq b`.
|
|
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
|
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
|
and the result is subsquently scaled and shifted by the mean and std args.
|
|
Args:
|
|
tensor: an n-dimensional `torch.Tensor`
|
|
mean: the mean of the normal distribution
|
|
std: the standard deviation of the normal distribution
|
|
a: the minimum cutoff value
|
|
b: the maximum cutoff value
|
|
"""
|
|
with torch.no_grad():
|
|
_trunc_normal_(tensor, 0, 1.0, a, b)
|
|
tensor.mul_(std).add_(mean)
|
|
|
|
|
|
to_2tuple = _ntuple(2)
|
|
|
|
|
|
class Format(str, Enum):
|
|
NCHW = "NCHW"
|
|
NHWC = "NHWC"
|
|
NCL = "NCL"
|
|
NLC = "NLC"
|
|
|
|
|
|
def nchw_to(x: torch.Tensor, fmt: Format):
|
|
if fmt == Format.NHWC:
|
|
x = x.permute(0, 2, 3, 1)
|
|
elif fmt == Format.NLC:
|
|
x = x.flatten(2).transpose(1, 2)
|
|
elif fmt == Format.NCL:
|
|
x = x.flatten(2)
|
|
return x
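

# Illustrative sketch (not part of the upstream code): shape effect of nchw_to
# on a dummy (2, 768, 24, 24) NCHW feature map. Hypothetical helper for
# documentation only; it is never called by the model.
def _example_nchw_to_shapes():
    x = torch.zeros(2, 768, 24, 24)
    assert nchw_to(x, Format.NHWC).shape == (2, 24, 24, 768)
    assert nchw_to(x, Format.NLC).shape == (2, 24 * 24, 768)
    assert nchw_to(x, Format.NCL).shape == (2, 768, 24 * 24)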
|
|
|
|
|
|
def resample_patch_embed(
|
|
patch_embed,
|
|
new_size: List[int],
|
|
interpolation: str = "bicubic",
|
|
antialias: bool = True,
|
|
verbose: bool = False,
|
|
):
|
|
"""Resample the weights of the patch embedding kernel to target resolution.
|
|
We resample the patch embedding kernel by approximately inverting the effect
|
|
of patch resizing.
|
|
|
|
Code based on:
|
|
https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
|
|
|
|
With this resizing, we can for example load a B/8 filter into a B/16 model
|
|
and, on 2x larger input image, the result will match.
|
|
|
|
Args:
|
|
patch_embed: original parameter to be resized.
|
|
        new_size (tuple(int, int)): target shape (height, width) only.
|
|
interpolation (str): interpolation for resize
|
|
antialias (bool): use anti-aliasing filter in resize
|
|
verbose (bool): log operation
|
|
Returns:
|
|
Resized patch embedding kernel.
|
|
"""
|
|
import numpy as np
|
|
|
|
try:
|
|
from torch import vmap
|
|
except ImportError:
|
|
from torch.func import vmap
|
|
|
|
assert len(patch_embed.shape) == 4, "Four dimensions expected"
|
|
assert len(new_size) == 2, "New shape should only be hw"
|
|
old_size = patch_embed.shape[-2:]
|
|
if tuple(old_size) == tuple(new_size):
|
|
return patch_embed
|
|
|
|
if verbose:
|
|
logger.info(
|
|
f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation."
|
|
)
|
|
|
|
def resize(x_np, _new_size):
|
|
x_tf = torch.Tensor(x_np)[None, None, ...]
|
|
x_upsampled = F.interpolate(
|
|
x_tf, size=_new_size, mode=interpolation, antialias=antialias
|
|
)[0, 0, ...].numpy()
|
|
return x_upsampled
|
|
|
|
def get_resize_mat(_old_size, _new_size):
|
|
mat = []
|
|
for i in range(np.prod(_old_size)):
|
|
basis_vec = np.zeros(_old_size)
|
|
basis_vec[np.unravel_index(i, _old_size)] = 1.0
|
|
mat.append(resize(basis_vec, _new_size).reshape(-1))
|
|
return np.stack(mat).T
|
|
|
|
resize_mat = get_resize_mat(old_size, new_size)
|
|
resize_mat_pinv = torch.tensor(
|
|
np.linalg.pinv(resize_mat.T), device=patch_embed.device
|
|
)
|
|
|
|
def resample_kernel(kernel):
|
|
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
|
|
return resampled_kernel.reshape(new_size)
|
|
|
|
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
|
|
orig_dtype = patch_embed.dtype
|
|
patch_embed = patch_embed.float()
|
|
patch_embed = v_resample_kernel(patch_embed)
|
|
patch_embed = patch_embed.to(orig_dtype)
|
|
return patch_embed
|
|
|
|
|
|
# Copied from:
|
|
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/siglip_vit.py
|
|
class PatchEmbed(nn.Module):
|
|
"""2D Image to Patch Embedding"""
|
|
|
|
output_fmt: Format
|
|
dynamic_img_pad: torch.jit.Final[bool]
|
|
|
|
def __init__(
|
|
self,
|
|
img_size: Optional[int] = 224,
|
|
patch_size: int = 16,
|
|
in_chans: int = 3,
|
|
embed_dim: int = 768,
|
|
norm_layer: Optional[Callable] = None,
|
|
flatten: bool = True,
|
|
output_fmt: Optional[str] = None,
|
|
bias: bool = True,
|
|
strict_img_size: bool = True,
|
|
dynamic_img_pad: bool = False,
|
|
):
|
|
super().__init__()
|
|
self.patch_size = tuple(to_2tuple(patch_size))
|
|
self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
|
|
|
|
if output_fmt is not None:
|
|
self.flatten = False
|
|
self.output_fmt = Format(output_fmt)
|
|
else:
|
|
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
|
self.flatten = flatten
|
|
self.output_fmt = Format.NCHW
|
|
self.strict_img_size = strict_img_size
|
|
self.dynamic_img_pad = dynamic_img_pad
|
|
|
|
self.proj = nn.Conv2d(
|
|
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
|
|
)
|
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
|
|
|
def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
|
|
assert self.patch_size
|
|
if img_size is None:
|
|
return None, None, None
|
|
img_size = to_2tuple(img_size)
|
|
grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
|
|
num_patches = grid_size[0] * grid_size[1]
|
|
return img_size, grid_size, num_patches
|
|
|
|
def set_input_size(
|
|
self,
|
|
img_size: Optional[Union[int, Tuple[int, int]]] = None,
|
|
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
|
|
):
|
|
new_patch_size = None
|
|
if patch_size is not None:
|
|
new_patch_size = to_2tuple(patch_size)
|
|
if new_patch_size is not None and new_patch_size != self.patch_size:
|
|
with torch.no_grad():
|
|
new_proj = nn.Conv2d(
|
|
self.proj.in_channels,
|
|
self.proj.out_channels,
|
|
kernel_size=new_patch_size,
|
|
stride=new_patch_size,
|
|
bias=self.proj.bias is not None,
|
|
)
|
|
new_proj.weight.copy_(
|
|
resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)
|
|
)
|
|
if self.proj.bias is not None:
|
|
new_proj.bias.copy_(self.proj.bias)
|
|
self.proj = new_proj
|
|
self.patch_size = new_patch_size
|
|
img_size = img_size or self.img_size
|
|
if img_size != self.img_size or new_patch_size is not None:
|
|
self.img_size, self.grid_size, self.num_patches = self._init_img_size(
|
|
img_size
|
|
)
|
|
|
|
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
|
|
if as_scalar:
|
|
return max(self.patch_size)
|
|
else:
|
|
return self.patch_size
|
|
|
|
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
|
|
"""Get grid (feature) size for given image size taking account of dynamic padding.
|
|
NOTE: must be torchscript compatible so using fixed tuple indexing
|
|
"""
|
|
if self.dynamic_img_pad:
|
|
return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(
|
|
img_size[1] / self.patch_size[1]
|
|
)
|
|
else:
|
|
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
|
|
|
|
def forward(self, x):
|
|
B, C, H, W = x.shape
|
|
if self.img_size is not None:
|
|
if self.strict_img_size:
|
|
_assert(
|
|
H == self.img_size[0],
|
|
f"Input height ({H}) doesn't match model ({self.img_size[0]}).",
|
|
)
|
|
_assert(
|
|
W == self.img_size[1],
|
|
f"Input width ({W}) doesn't match model ({self.img_size[1]}).",
|
|
)
|
|
elif not self.dynamic_img_pad:
|
|
_assert(
|
|
H % self.patch_size[0] == 0,
|
|
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).",
|
|
)
|
|
_assert(
|
|
W % self.patch_size[1] == 0,
|
|
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).",
|
|
)
|
|
if self.dynamic_img_pad:
|
|
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
|
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
|
x = F.pad(x, (0, pad_w, 0, pad_h))
|
|
x = self.proj(x)
|
|
if self.flatten:
|
|
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
|
elif self.output_fmt != Format.NCHW:
|
|
x = nchw_to(x, self.output_fmt)
|
|
x = self.norm(x)
|
|
return x
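

# Illustrative sketch (not part of the upstream code): with the defaults above,
# PatchEmbed turns a 224x224 RGB image into a 14x14 grid of 768-dim tokens.
# Hypothetical helper for documentation only.
def _example_patch_embed():
    pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    tokens = pe(torch.zeros(1, 3, 224, 224))
    assert tokens.shape == (1, 196, 768)  # flatten=True -> (B, N, C) layout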
|
|
|
|
|
|
class Mlp(nn.Module):
|
|
"""MLP as used in Vision Transformer, MLP-Mixer and related networks
|
|
|
|
NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
in_features,
|
|
hidden_features=None,
|
|
out_features=None,
|
|
act_layer=nn.GELU,
|
|
norm_layer=None,
|
|
bias=True,
|
|
drop=0.0,
|
|
use_conv=False,
|
|
):
|
|
super().__init__()
|
|
out_features = out_features or in_features
|
|
hidden_features = hidden_features or in_features
|
|
bias = to_2tuple(bias)
|
|
drop_probs = to_2tuple(drop)
|
|
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
|
|
|
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
|
self.act = act_layer()
|
|
self.drop1 = nn.Dropout(drop_probs[0])
|
|
self.norm = (
|
|
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
|
)
|
|
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
|
self.drop2 = nn.Dropout(drop_probs[1])
|
|
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
x = self.act(x)
|
|
x = self.drop1(x)
|
|
x = self.norm(x)
|
|
x = self.fc2(x)
|
|
x = self.drop2(x)
|
|
return x
|
|
|
|
|
|
def drop_path(
|
|
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
|
):
|
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
|
|
|
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
|
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
|
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
|
'survival rate' as the argument.
|
|
|
|
"""
|
|
if drop_prob == 0.0 or not training:
|
|
return x
|
|
keep_prob = 1 - drop_prob
|
|
shape = (x.shape[0],) + (1,) * (
|
|
x.ndim - 1
|
|
) # work with diff dim tensors, not just 2D ConvNets
|
|
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
|
if keep_prob > 0.0 and scale_by_keep:
|
|
random_tensor.div_(keep_prob)
|
|
return x * random_tensor
|
|
|
|
|
|
class DropPath(nn.Module):
|
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
|
|
|
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
|
super(DropPath, self).__init__()
|
|
self.drop_prob = drop_prob
|
|
self.scale_by_keep = scale_by_keep
|
|
|
|
def forward(self, x):
|
|
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
|
|
|
def extra_repr(self):
|
|
return f"drop_prob={round(self.drop_prob, 3):0.3f}"
|
|
|
|
|
|
class VisionTransformerBlock(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim: int,
|
|
num_heads: int,
|
|
mlp_ratio: float = 4.0,
|
|
qkv_bias: bool = False,
|
|
qk_norm: bool = False,
|
|
proj_drop: float = 0.0,
|
|
attn_drop: float = 0.0,
|
|
init_values: Optional[float] = None,
|
|
drop_path: float = 0.0,
|
|
act_layer: nn.Module = nn.GELU,
|
|
norm_layer: nn.Module = nn.LayerNorm,
|
|
mlp_layer: nn.Module = Mlp,
|
|
) -> None:
|
|
super().__init__()
|
|
self.norm1 = norm_layer(dim)
|
|
self.attn = VisionAttention(
|
|
embed_dim=dim,
|
|
num_heads=num_heads,
|
|
projection_size=dim,
|
|
use_qkv_parallel=True,
|
|
use_context_forward=False,
|
|
softmax_in_single_precision=False,
|
|
dropout=attn_drop,
|
|
)
|
|
|
|
self.ls1 = (
|
|
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
|
)
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
self.mlp = mlp_layer(
|
|
in_features=dim,
|
|
hidden_features=int(dim * mlp_ratio),
|
|
act_layer=act_layer,
|
|
drop=proj_drop,
|
|
)
|
|
self.ls2 = (
|
|
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
|
)
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
|
|
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
|
return x
|
|
|
|
|
|
LayerType = Union[str, Callable, Type[torch.nn.Module]]
|
|
|
|
|
|
class PatchDropout(nn.Module):
|
|
"""
|
|
https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
|
|
"""
|
|
|
|
return_indices: torch.jit.Final[bool]
|
|
|
|
def __init__(
|
|
self,
|
|
prob: float = 0.5,
|
|
num_prefix_tokens: int = 1,
|
|
ordered: bool = False,
|
|
return_indices: bool = False,
|
|
):
|
|
super().__init__()
|
|
assert 0 <= prob < 1.0
|
|
self.prob = prob
|
|
self.num_prefix_tokens = (
|
|
num_prefix_tokens # exclude CLS token (or other prefix tokens)
|
|
)
|
|
self.ordered = ordered
|
|
self.return_indices = return_indices
|
|
|
|
def forward(
|
|
self, x
|
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
|
|
if not self.training or self.prob == 0.0:
|
|
if self.return_indices:
|
|
return x, None
|
|
return x
|
|
|
|
if self.num_prefix_tokens:
|
|
prefix_tokens, x = (
|
|
x[:, : self.num_prefix_tokens],
|
|
x[:, self.num_prefix_tokens :],
|
|
)
|
|
else:
|
|
prefix_tokens = None
|
|
|
|
B = x.shape[0]
|
|
L = x.shape[1]
|
|
num_keep = max(1, int(L * (1.0 - self.prob)))
|
|
keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[
|
|
:, :num_keep
|
|
]
|
|
if self.ordered:
|
|
# NOTE does not need to maintain patch order in typical transformer use,
|
|
# but possibly useful for debug / visualization
|
|
keep_indices = keep_indices.sort(dim=-1)[0]
|
|
x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
|
|
|
|
if prefix_tokens is not None:
|
|
x = torch.cat((prefix_tokens, x), dim=1)
|
|
|
|
if self.return_indices:
|
|
return x, keep_indices
|
|
return x
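

# Illustrative sketch (not part of the upstream code): PatchDropout keeps the
# prefix (CLS) tokens and a random half of the patch tokens during training,
# and is a no-op at inference. Hypothetical helper for documentation only.
def _example_patch_dropout():
    pd = PatchDropout(prob=0.5, num_prefix_tokens=1)
    x = torch.randn(2, 1 + 196, 768)  # CLS token + 196 patch tokens
    pd.train()
    assert pd(x).shape == (2, 1 + 98, 768)  # int(196 * 0.5) patches kept
    pd.eval()
    assert pd(x).shape == x.shape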
|
|
|
|
|
|
def resample_abs_pos_embed(
|
|
posemb: torch.Tensor,
|
|
new_size: List[int],
|
|
old_size: Optional[List[int]] = None,
|
|
num_prefix_tokens: int = 1,
|
|
interpolation: str = "bicubic",
|
|
antialias: bool = True,
|
|
verbose: bool = False,
|
|
):
|
|
# sort out sizes, assume square if old size not provided
|
|
num_pos_tokens = posemb.shape[1]
|
|
num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
|
|
if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
|
|
return posemb
|
|
|
|
if old_size is None:
|
|
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
|
|
old_size = hw, hw
|
|
|
|
if num_prefix_tokens:
|
|
posemb_prefix, posemb = (
|
|
posemb[:, :num_prefix_tokens],
|
|
posemb[:, num_prefix_tokens:],
|
|
)
|
|
else:
|
|
posemb_prefix, posemb = None, posemb
|
|
|
|
# do the interpolation
|
|
embed_dim = posemb.shape[-1]
|
|
orig_dtype = posemb.dtype
|
|
posemb = posemb.float() # interpolate needs float32
|
|
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
|
|
posemb = F.interpolate(
|
|
posemb, size=new_size, mode=interpolation, antialias=antialias
|
|
)
|
|
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
|
|
posemb = posemb.to(orig_dtype)
|
|
|
|
# add back extra (class, etc) prefix tokens
|
|
if posemb_prefix is not None:
|
|
posemb = torch.cat([posemb_prefix, posemb], dim=1)
|
|
|
|
if not torch.jit.is_scripting() and verbose:
|
|
logger.info(f"Resized position embedding: {old_size} to {new_size}.")
|
|
|
|
return posemb
|
|
|
|
|
|
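# NOTE: this module-level function deliberately takes `self` -- it is assigned to
# AttentionPoolLatent.init_weights inside VisionTransformer.__init__ (the
# `global_pool == "map"` branch) so the attention-pooling head is initialized
# with torch's trunc_normal_ instead of trunc_normal_tf_.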
def init_weights(self):
|
|
if self.pos_embed is not None:
|
|
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
|
|
trunc_normal_(self.latent, std=self.latent_dim**-0.5)
|
|
|
|
|
|
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
|
|
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
|
if isinstance(module, nn.Linear):
|
|
trunc_normal_(module.weight, std=0.02)
|
|
if module.bias is not None:
|
|
nn.init.zeros_(module.bias)
|
|
elif hasattr(module, "init_weights"):
|
|
module.init_weights()
|
|
|
|
|
|
class VisionTransformer(nn.Module):
|
|
"""Vision Transformer
|
|
|
|
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
|
- https://arxiv.org/abs/2010.11929
|
|
"""
|
|
|
|
dynamic_img_size: Final[bool]
|
|
|
|
def __init__(
|
|
self,
|
|
img_size: Union[int, Tuple[int, int]] = 224,
|
|
patch_size: Union[int, Tuple[int, int]] = 16,
|
|
in_chans: int = 3,
|
|
num_classes: int = 1000,
|
|
global_pool: Literal["", "avg", "token", "map"] = "token",
|
|
embed_dim: int = 768,
|
|
depth: int = 12,
|
|
num_heads: int = 12,
|
|
mlp_ratio: float = 4.0,
|
|
qkv_bias: bool = True,
|
|
qk_norm: bool = False,
|
|
init_values: Optional[float] = None,
|
|
class_token: bool = True,
|
|
no_embed_class: bool = False,
|
|
reg_tokens: int = 0,
|
|
pre_norm: bool = False,
|
|
fc_norm: Optional[bool] = None,
|
|
dynamic_img_size: bool = False,
|
|
dynamic_img_pad: bool = False,
|
|
drop_rate: float = 0.0,
|
|
pos_drop_rate: float = 0.0,
|
|
patch_drop_rate: float = 0.0,
|
|
proj_drop_rate: float = 0.0,
|
|
attn_drop_rate: float = 0.0,
|
|
drop_path_rate: float = 0.0,
|
|
weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
|
|
embed_layer: Callable = PatchEmbed,
|
|
_norm_layer: Optional[LayerType] = None,
|
|
_act_layer: Optional[LayerType] = None,
|
|
block_fn: Type[nn.Module] = VisionTransformerBlock,
|
|
mlp_layer: Type[nn.Module] = Mlp,
|
|
ignore_head: bool = False,
|
|
) -> None:
|
|
"""
|
|
Args:
|
|
img_size: Input image size.
|
|
patch_size: Patch size.
|
|
in_chans: Number of image input channels.
|
|
            num_classes: Number of classes for classification head.
|
|
global_pool: Type of global pooling for final sequence (default: 'token').
|
|
embed_dim: Transformer embedding dimension.
|
|
depth: Depth of transformer.
|
|
num_heads: Number of attention heads.
|
|
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
|
qkv_bias: Enable bias for qkv projections if True.
|
|
init_values: Layer-scale init values (layer-scale enabled if not None).
|
|
class_token: Use class token.
|
|
no_embed_class: Don't include position embeddings for class (or reg) tokens.
|
|
reg_tokens: Number of register tokens.
|
|
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
|
drop_rate: Head dropout rate.
|
|
pos_drop_rate: Position embedding dropout rate.
|
|
attn_drop_rate: Attention dropout rate.
|
|
drop_path_rate: Stochastic depth rate.
|
|
weight_init: Weight initialization scheme.
|
|
embed_layer: Patch embedding layer.
|
|
_norm_layer: Normalization layer.
|
|
_act_layer: MLP activation layer.
|
|
block_fn: Transformer block layer.
|
|
"""
|
|
super().__init__()
|
|
assert global_pool in ("", "avg", "token", "map")
|
|
assert class_token or global_pool != "token"
|
|
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
|
|
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
|
# act_layer = get_act_layer(act_layer) or nn.GELU
|
|
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
|
act_layer = nn.GELU
|
|
|
|
self.num_classes = num_classes
|
|
self.global_pool = global_pool
|
|
self.num_features = self.embed_dim = (
|
|
embed_dim # num_features for consistency with other models
|
|
)
|
|
self.num_prefix_tokens = 1 if class_token else 0
|
|
self.num_prefix_tokens += reg_tokens
|
|
self.num_reg_tokens = reg_tokens
|
|
self.has_class_token = class_token
|
|
self.no_embed_class = (
|
|
no_embed_class # don't embed prefix positions (includes reg)
|
|
)
|
|
self.dynamic_img_size = dynamic_img_size
|
|
self.grad_checkpointing = False
|
|
self.ignore_head = ignore_head
|
|
|
|
embed_args = {}
|
|
if dynamic_img_size:
|
|
# flatten deferred until after pos embed
|
|
embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
|
|
self.patch_embed = embed_layer(
|
|
img_size=img_size,
|
|
patch_size=patch_size,
|
|
in_chans=in_chans,
|
|
embed_dim=embed_dim,
|
|
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
|
dynamic_img_pad=dynamic_img_pad,
|
|
**embed_args,
|
|
)
|
|
num_patches = self.patch_embed.num_patches
|
|
|
|
self.cls_token = (
|
|
nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
|
)
|
|
self.reg_token = (
|
|
nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
|
|
)
|
|
embed_len = (
|
|
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
|
)
|
|
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
|
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
|
if patch_drop_rate > 0:
|
|
self.patch_drop = PatchDropout(
|
|
patch_drop_rate,
|
|
num_prefix_tokens=self.num_prefix_tokens,
|
|
)
|
|
else:
|
|
self.patch_drop = nn.Identity()
|
|
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
|
|
|
dpr = [
|
|
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
|
] # stochastic depth decay rule
|
|
self.blocks = nn.Sequential(
|
|
*[
|
|
block_fn(
|
|
dim=embed_dim,
|
|
num_heads=num_heads,
|
|
mlp_ratio=mlp_ratio,
|
|
qkv_bias=qkv_bias,
|
|
qk_norm=qk_norm,
|
|
init_values=init_values,
|
|
proj_drop=proj_drop_rate,
|
|
attn_drop=attn_drop_rate,
|
|
drop_path=dpr[i],
|
|
norm_layer=norm_layer,
|
|
act_layer=act_layer,
|
|
mlp_layer=mlp_layer,
|
|
)
|
|
for i in range(depth)
|
|
]
|
|
)
|
|
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
|
|
|
# Classifier Head
|
|
if global_pool == "map":
|
|
AttentionPoolLatent.init_weights = init_weights
|
|
self.attn_pool = AttentionPoolLatent(
|
|
self.embed_dim,
|
|
num_heads=num_heads,
|
|
mlp_ratio=mlp_ratio,
|
|
norm_layer=norm_layer,
|
|
)
|
|
else:
|
|
self.attn_pool = None
|
|
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
|
self.head_drop = nn.Dropout(drop_rate)
|
|
self.head = (
|
|
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
)
|
|
|
|
if weight_init != "skip":
|
|
self.init_weights(weight_init)
|
|
|
|
def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
|
|
assert mode in ("jax", "jax_nlhb", "moco", "")
|
|
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
|
|
trunc_normal_(self.pos_embed, std=0.02)
|
|
if self.cls_token is not None:
|
|
nn.init.normal_(self.cls_token, std=1e-6)
|
|
named_apply(init_weights_vit_timm, self)
|
|
|
|
@torch.jit.ignore
|
|
def no_weight_decay(self) -> Set:
|
|
return {"pos_embed", "cls_token", "dist_token"}
|
|
|
|
@torch.jit.ignore
|
|
def group_matcher(self, coarse: bool = False) -> Dict:
|
|
return dict(
|
|
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
|
|
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
|
|
)
|
|
|
|
@torch.jit.ignore
|
|
def get_classifier(self) -> nn.Module:
|
|
return self.head
|
|
|
|
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
|
|
self.num_classes = num_classes
|
|
if global_pool is not None:
|
|
assert global_pool in ("", "avg", "token", "map")
|
|
if global_pool == "map" and self.attn_pool is None:
|
|
assert (
|
|
False
|
|
), "Cannot currently add attention pooling in reset_classifier()."
|
|
elif global_pool != "map " and self.attn_pool is not None:
|
|
self.attn_pool = None # remove attention pooling
|
|
self.global_pool = global_pool
|
|
self.head = (
|
|
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
)
|
|
|
|
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
|
|
if self.dynamic_img_size:
|
|
B, H, W, C = x.shape
|
|
pos_embed = resample_abs_pos_embed(
|
|
self.pos_embed,
|
|
[H, W],
|
|
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
|
)
|
|
x = x.view(B, -1, C)
|
|
else:
|
|
pos_embed = self.pos_embed
|
|
|
|
to_cat = []
|
|
if self.cls_token is not None:
|
|
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
|
if self.reg_token is not None:
|
|
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
|
|
|
if self.no_embed_class:
|
|
# deit-3, updated JAX (big vision)
|
|
# position embedding does not overlap with class token, add then concat
|
|
x = x + pos_embed
|
|
if to_cat:
|
|
x = torch.cat(to_cat + [x], dim=1)
|
|
else:
|
|
# original timm, JAX, and deit vit impl
|
|
# pos_embed has entry for class token, concat then add
|
|
if to_cat:
|
|
x = torch.cat(to_cat + [x], dim=1)
|
|
x = x + pos_embed
|
|
|
|
return self.pos_drop(x)
|
|
|
|
def _intermediate_layers(
|
|
self,
|
|
x: torch.Tensor,
|
|
n: Union[int, Sequence] = 1,
|
|
) -> List[torch.Tensor]:
|
|
outputs, num_blocks = [], len(self.blocks)
|
|
take_indices = set(
|
|
range(num_blocks - n, num_blocks) if isinstance(n, int) else n
|
|
)
|
|
|
|
# forward pass
|
|
x = self.patch_embed(x)
|
|
x = self._pos_embed(x)
|
|
x = self.patch_drop(x)
|
|
x = self.norm_pre(x)
|
|
for i, blk in enumerate(self.blocks):
|
|
x = blk(x)
|
|
if i in take_indices:
|
|
outputs.append(x)
|
|
|
|
return outputs
|
|
|
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
x = self.patch_embed(x)
|
|
x = self._pos_embed(x)
|
|
x = self.patch_drop(x)
|
|
x = self.norm_pre(x)
|
|
x = self.blocks(x)
|
|
x = self.norm(x)
|
|
return x
|
|
|
|
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
|
if self.attn_pool is not None:
|
|
x = self.attn_pool(x)
|
|
elif self.global_pool == "avg":
|
|
x = x[:, self.num_prefix_tokens :].mean(dim=1)
|
|
elif self.global_pool:
|
|
x = x[:, 0] # class token
|
|
x = self.fc_norm(x)
|
|
x = self.head_drop(x)
|
|
return x if pre_logits else self.head(x)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x = self.forward_features(x)
|
|
if not self.ignore_head:
|
|
x = self.forward_head(x)
|
|
return x
|
|
|
|
|
|
def model_name_to_cls(cls_name):
|
|
if "MlpProjector" in cls_name:
|
|
cls = MlpProjector
|
|
|
|
elif "CLIPVisionTower" in cls_name:
|
|
cls = CLIPVisionTower
|
|
|
|
elif "VQ" in cls_name:
|
|
|
|
cls = VQ_models[cls_name]
|
|
elif "vision_head" in cls_name:
|
|
cls = vision_head
|
|
else:
|
|
raise ValueError(f"class_name {cls_name} is invalid.")
|
|
|
|
return cls
|
|
|
|
|
|
class vision_head(torch.nn.Module):
|
|
def __init__(self, params):
|
|
super().__init__()
|
|
self.output_mlp_projector = torch.nn.Linear(
|
|
params["n_embed"], params["image_token_embed"]
|
|
)
|
|
self.vision_activation = torch.nn.GELU()
|
|
self.vision_head = torch.nn.Linear(
|
|
params["image_token_embed"], params["image_token_size"]
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.output_mlp_projector(x)
|
|
x = self.vision_activation(x)
|
|
x = self.vision_head(x)
|
|
return x
|
|
|
|
|
|
SigLIP_MODEL_CONFIG = {
|
|
"siglip_so400m_patch14_384": {
|
|
"image_size": 336,
|
|
"patch_size": 14,
|
|
"width": 1152,
|
|
"layers": 27,
|
|
"heads": 16,
|
|
"mlp_ratio": 3.7362,
|
|
"global_pool": "map",
|
|
"use_checkpoint": False,
|
|
},
|
|
"siglip_so400m_patch14_224": {
|
|
"image_size": 224,
|
|
"patch_size": 14,
|
|
"width": 1152,
|
|
"layers": 27,
|
|
"heads": 16,
|
|
"mlp_ratio": 3.7362,
|
|
"global_pool": "map",
|
|
"use_checkpoint": False,
|
|
},
|
|
"siglip_large_patch16_384": {
|
|
"image_size": 384,
|
|
"patch_size": 16,
|
|
"width": 1024,
|
|
"layers": 24,
|
|
"heads": 16,
|
|
"mlp_ratio": 4,
|
|
"global_pool": "map",
|
|
"use_checkpoint": False,
|
|
},
|
|
}
|
|
|
|
|
|
def create_siglip_vit(
|
|
model_name: str = "siglip_so400m_patch14_384",
|
|
image_size: int = 384,
|
|
select_layer: int = -1,
|
|
ckpt_path: str = "",
|
|
**kwargs,
|
|
):
|
|
assert (
|
|
model_name in SigLIP_MODEL_CONFIG.keys()
|
|
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
|
|
|
|
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
|
|
|
|
if select_layer <= 0:
|
|
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
|
|
else:
|
|
layers = min(vision_cfg.layers, select_layer)
|
|
|
|
model = VisionTransformer(
|
|
img_size=image_size,
|
|
patch_size=vision_cfg.patch_size,
|
|
embed_dim=vision_cfg.width,
|
|
depth=layers,
|
|
num_heads=vision_cfg.heads,
|
|
mlp_ratio=vision_cfg.mlp_ratio,
|
|
class_token=vision_cfg.class_token,
|
|
global_pool=vision_cfg.global_pool,
|
|
ignore_head=kwargs.get("ignore_head", True),
|
|
weight_init=kwargs.get("weight_init", "skip"),
|
|
num_classes=0,
|
|
)
|
|
|
|
if ckpt_path:
|
|
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
|
|
|
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
|
print(
|
|
f"SigLIP-ViT restores from {ckpt_path},\n"
|
|
f"\tincompatible_keys:', {incompatible_keys}."
|
|
)
|
|
|
|
return model
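

# Illustrative sketch (not part of the upstream code): how CLIPVisionTower uses
# the factory above for Janus-Pro, i.e. as a headless feature extractor. With
# patch_size=14 and image_size=384 the output should be a (b, 27 * 27, 1152)
# tensor of patch features, since ignore_head defaults to True here.
def _example_siglip_features(images: torch.Tensor) -> torch.Tensor:
    vit = create_siglip_vit("siglip_so400m_patch14_384", image_size=384)
    return vit(images)  # images: (b, 3, 384, 384) -> (b, 729, 1152)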
|
|
|
|
|
|
class Normalize(torch.nn.Module):
|
|
"""Normalize a tensor image with mean and standard deviation.
|
|
This transform does not support PIL Image.
|
|
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
|
|
channels, this transform will normalize each channel of the input
|
|
``torch.*Tensor`` i.e.,
|
|
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
|
|
|
|
.. note::
|
|
This transform acts out of place, i.e., it does not mutate the input tensor.
|
|
|
|
Args:
|
|
mean (sequence): Sequence of means for each channel.
|
|
std (sequence): Sequence of standard deviations for each channel.
|
|
inplace(bool,optional): Bool to make this operation in-place.
|
|
|
|
"""
|
|
|
|
def __init__(self, mean, std, inplace=False):
|
|
super().__init__()
|
|
# _log_api_usage_once(self)
|
|
self.mean = mean
|
|
self.std = std
|
|
self.inplace = inplace
|
|
|
|
def forward(self, tensor: Tensor) -> Tensor:
|
|
"""
|
|
Args:
|
|
tensor (Tensor): Tensor image to be normalized.
|
|
|
|
Returns:
|
|
Tensor: Normalized Tensor image.
|
|
"""
|
|
return F.normalize(tensor, self.mean, self.std, self.inplace)
|
|
|
|
def __repr__(self) -> str:
|
|
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
|
|
|
|
|
|
class CLIPVisionTower(nn.Module):
|
|
def __init__(
|
|
self,
|
|
model_name: str = "siglip_large_patch16_384",
|
|
image_size: Union[Tuple[int, int], int] = 336,
|
|
select_feature: str = "patch",
|
|
select_layer: int = -2,
|
|
select_layers: list = None,
|
|
ckpt_path: str = "",
|
|
pixel_mean: Optional[List[float]] = None,
|
|
pixel_std: Optional[List[float]] = None,
|
|
**kwargs,
|
|
):
|
|
super().__init__()
|
|
|
|
self.model_name = model_name
|
|
self.select_feature = select_feature
|
|
self.select_layer = select_layer
|
|
self.select_layers = select_layers
|
|
|
|
vision_tower_params = {
|
|
"model_name": model_name,
|
|
"image_size": image_size,
|
|
"ckpt_path": ckpt_path,
|
|
"select_layer": select_layer,
|
|
}
|
|
vision_tower_params.update(kwargs)
|
|
self.vision_tower, self.forward_kwargs = self.build_vision_tower(
|
|
vision_tower_params
|
|
)
|
|
|
|
if pixel_mean is not None and pixel_std is not None:
|
|
image_norm = Normalize(mean=pixel_mean, std=pixel_std)
|
|
else:
|
|
image_norm = None
|
|
|
|
self.image_norm = image_norm
|
|
|
|
@property
|
|
def device(self) -> torch.device:
|
|
return next(self.vision_tower.parameters()).device
|
|
|
|
@property
|
|
def dtype(self):
|
|
return next(self.vision_tower.parameters()).dtype
|
|
|
|
def build_vision_tower(self, vision_tower_params):
|
|
if self.model_name.startswith("siglip"):
|
|
self.select_feature = "same"
|
|
vision_tower = create_siglip_vit(**vision_tower_params)
|
|
forward_kwargs = dict()
|
|
|
|
elif self.model_name.startswith("sam"):
|
|
# vision_tower = create_sam_vit(**vision_tower_params)
|
|
forward_kwargs = dict()
|
|
|
|
else: # huggingface
|
|
from transformers import CLIPVisionModel
|
|
|
|
vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
|
|
forward_kwargs = dict(output_hidden_states=True)
|
|
|
|
return vision_tower, forward_kwargs
|
|
|
|
def feature_select(self, image_forward_outs):
|
|
if isinstance(image_forward_outs, torch.Tensor):
|
|
            # the output is already the selected layer's features
|
|
image_features = image_forward_outs
|
|
else:
|
|
image_features = image_forward_outs.hidden_states[self.select_layer]
|
|
|
|
if self.select_feature == "patch":
|
|
# if the output has cls_token
|
|
image_features = image_features[:, 1:]
|
|
elif self.select_feature == "cls_patch":
|
|
image_features = image_features
|
|
elif self.select_feature == "same":
|
|
image_features = image_features
|
|
|
|
else:
|
|
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
|
return image_features
|
|
|
|
def forward(self, images):
|
|
"""
|
|
|
|
Args:
|
|
images (torch.Tensor): [b, 3, H, W]
|
|
|
|
Returns:
|
|
image_features (torch.Tensor): [b, n_patch, d]
|
|
"""
|
|
|
|
if self.image_norm is not None:
|
|
images = self.image_norm(images)
|
|
|
|
image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
|
|
image_features = self.feature_select(image_forward_outs)
|
|
return image_features
|
|
|
|
|
|
class MlpProjector(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
self.cfg = cfg
|
|
|
|
if cfg["projector_type"] == "identity":
|
|
modules = nn.Identity()
|
|
|
|
elif cfg["projector_type"] == "linear":
|
|
modules = nn.Linear(cfg["input_dim"], cfg["n_embed"])
|
|
|
|
elif cfg["projector_type"] == "mlp_gelu":
|
|
mlp_depth = cfg.get("depth", 1)
|
|
modules = [nn.Linear(cfg["input_dim"], cfg["n_embed"])]
|
|
for _ in range(1, mlp_depth):
|
|
modules.append(nn.GELU())
|
|
modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
|
|
modules = nn.Sequential(*modules)
|
|
|
|
elif cfg["projector_type"] == "low_high_hybrid_split_mlp_gelu":
|
|
mlp_depth = cfg.get("depth", 1)
|
|
self.high_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
|
|
self.low_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
|
|
|
|
modules = []
|
|
for _ in range(1, mlp_depth):
|
|
modules.append(nn.GELU())
|
|
modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
|
|
modules = nn.Sequential(*modules)
|
|
|
|
else:
|
|
raise ValueError(f"Unknown projector type: {cfg['projector_type']}")
|
|
|
|
self.layers = modules
|
|
|
|
def forward(
|
|
self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
|
|
):
|
|
"""
|
|
|
|
Args:
|
|
x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor,
|
|
then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x);
|
|
otherwise it is the feature from the single vision encoder.
|
|
|
|
Returns:
|
|
x (torch.Tensor): [b, s, c]
|
|
"""
|
|
|
|
if isinstance(x_or_tuple, tuple):
|
|
# self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
|
|
high_x, low_x = x_or_tuple
|
|
high_x = self.high_up_proj(high_x)
|
|
low_x = self.low_up_proj(low_x)
|
|
x = torch.cat([high_x, low_x], dim=-1)
|
|
else:
|
|
x = x_or_tuple
|
|
|
|
return self.layers(x)
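

# Illustrative sketch (not part of the upstream code): the aligner in Janus-Pro
# is an MlpProjector mapping vision features into the LLM embedding space. The
# config values below are assumptions for illustration; the real ones come from
# the HF config (aligner_config.params).
def _example_mlp_projector() -> "MlpProjector":
    cfg = {"projector_type": "mlp_gelu", "depth": 2, "input_dim": 1152, "n_embed": 4096}
    # -> Sequential(Linear(1152, 4096), GELU(), Linear(4096, 4096))
    return MlpProjector(cfg)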
|
|
|
|
|
|
class LayerScale(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim: int,
|
|
init_values: float = 1e-5,
|
|
inplace: bool = False,
|
|
) -> None:
|
|
super().__init__()
|
|
self.inplace = inplace
|
|
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
|
|
|
|
|
# use torch.scaled_dot_product_attention where possible
|
|
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
|
if "TIMM_FUSED_ATTN" in os.environ:
|
|
_USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"])
|
|
else:
|
|
_USE_FUSED_ATTN = (
|
|
1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
|
|
)
|
|
|
|
# Set to True if exporting a model with Same padding via ONNX
|
|
_EXPORTABLE = False
|
|
|
|
|
|
def use_fused_attn(experimental: bool = False) -> bool:
|
|
# NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
|
|
if not _HAS_FUSED_ATTN or _EXPORTABLE:
|
|
return False
|
|
if experimental:
|
|
return _USE_FUSED_ATTN > 1
|
|
return _USE_FUSED_ATTN > 0
|
|
|
|
|
|
class AttentionPoolLatent(nn.Module):
|
|
"""Attention pooling w/ latent query"""
|
|
|
|
fused_attn: torch.jit.Final[bool]
|
|
|
|
def __init__(
|
|
self,
|
|
in_features: int,
|
|
out_features: int = None,
|
|
embed_dim: int = None,
|
|
num_heads: int = 8,
|
|
feat_size: Optional[int] = None,
|
|
mlp_ratio: float = 4.0,
|
|
qkv_bias: bool = True,
|
|
qk_norm: bool = False,
|
|
latent_len: int = 1,
|
|
latent_dim: int = None,
|
|
pos_embed: str = "",
|
|
pool_type: str = "token",
|
|
norm_layer: Optional[nn.Module] = None,
|
|
drop: float = 0.0,
|
|
):
|
|
super().__init__()
|
|
embed_dim = embed_dim or in_features
|
|
out_features = out_features or in_features
|
|
assert embed_dim % num_heads == 0
|
|
self.num_heads = num_heads
|
|
self.head_dim = embed_dim // num_heads
|
|
self.feat_size = feat_size
|
|
self.scale = self.head_dim**-0.5
|
|
self.pool = pool_type
|
|
self.fused_attn = use_fused_attn()
|
|
|
|
if pos_embed == "abs":
|
|
assert feat_size is not None
|
|
self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
|
|
else:
|
|
self.pos_embed = None
|
|
|
|
self.latent_dim = latent_dim or embed_dim
|
|
self.latent_len = latent_len
|
|
self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
|
|
|
|
self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
|
|
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
|
|
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
self.proj = nn.Linear(embed_dim, embed_dim)
|
|
self.proj_drop = nn.Dropout(drop)
|
|
|
|
self.norm = (
|
|
norm_layer(out_features) if norm_layer is not None else nn.Identity()
|
|
)
|
|
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
|
|
|
|
self.init_weights()
|
|
|
|
def init_weights(self):
|
|
if self.pos_embed is not None:
|
|
trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
|
|
trunc_normal_tf_(self.latent, std=self.latent_dim**-0.5)
|
|
|
|
def forward(self, x):
|
|
B, N, C = x.shape
|
|
|
|
if self.pos_embed is not None:
|
|
# FIXME interpolate
|
|
x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
|
|
|
|
q_latent = self.latent.expand(B, -1, -1)
|
|
q = (
|
|
self.q(q_latent)
|
|
.reshape(B, self.latent_len, self.num_heads, self.head_dim)
|
|
.transpose(1, 2)
|
|
)
|
|
|
|
kv = (
|
|
self.kv(x)
|
|
.reshape(B, N, 2, self.num_heads, self.head_dim)
|
|
.permute(2, 0, 3, 1, 4)
|
|
)
|
|
k, v = kv.unbind(0)
|
|
|
|
q, k = self.q_norm(q), self.k_norm(k)
|
|
|
|
if self.fused_attn:
|
|
x = F.scaled_dot_product_attention(q, k, v)
|
|
else:
|
|
q = q * self.scale
|
|
attn = q @ k.transpose(-2, -1)
|
|
attn = attn.softmax(dim=-1)
|
|
x = attn @ v
|
|
x = x.transpose(1, 2).reshape(B, self.latent_len, C)
|
|
x = self.proj(x)
|
|
x = self.proj_drop(x)
|
|
|
|
x = x + self.mlp(self.norm(x))
|
|
|
|
# optional pool if latent seq_len > 1 and pooled output is desired
|
|
if self.pool == "token":
|
|
x = x[:, 0]
|
|
elif self.pool == "avg":
|
|
x = x.mean(1)
|
|
|
|
|
|
class Encoder(nn.Module):
|
|
def __init__(
|
|
self,
|
|
in_channels=3,
|
|
ch=128,
|
|
ch_mult=(1, 1, 2, 2, 4),
|
|
num_res_blocks=2,
|
|
norm_type="group",
|
|
dropout=0.0,
|
|
resamp_with_conv=True,
|
|
z_channels=256,
|
|
):
|
|
super().__init__()
|
|
self.num_resolutions = len(ch_mult)
|
|
self.num_res_blocks = num_res_blocks
|
|
self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
|
|
|
|
# downsampling
|
|
in_ch_mult = (1,) + tuple(ch_mult)
|
|
self.conv_blocks = nn.ModuleList()
|
|
for i_level in range(self.num_resolutions):
|
|
conv_block = nn.Module()
|
|
# res & attn
|
|
res_block = nn.ModuleList()
|
|
attn_block = nn.ModuleList()
|
|
block_in = ch * in_ch_mult[i_level]
|
|
block_out = ch * ch_mult[i_level]
|
|
for _ in range(self.num_res_blocks):
|
|
res_block.append(
|
|
ResnetBlock(
|
|
block_in, block_out, dropout=dropout, norm_type=norm_type
|
|
)
|
|
)
|
|
block_in = block_out
|
|
if i_level == self.num_resolutions - 1:
|
|
attn_block.append(AttnBlock(block_in, norm_type))
|
|
conv_block.res = res_block
|
|
conv_block.attn = attn_block
|
|
# downsample
|
|
if i_level != self.num_resolutions - 1:
|
|
conv_block.downsample = Downsample(block_in, resamp_with_conv)
|
|
self.conv_blocks.append(conv_block)
|
|
|
|
# middle
|
|
self.mid = nn.ModuleList()
|
|
self.mid.append(
|
|
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
|
|
)
|
|
self.mid.append(AttnBlock(block_in, norm_type=norm_type))
|
|
self.mid.append(
|
|
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
|
|
)
|
|
|
|
# end
|
|
self.norm_out = Normalize(block_in, norm_type)
|
|
self.conv_out = nn.Conv2d(
|
|
block_in, z_channels, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x):
|
|
h = self.conv_in(x)
|
|
# downsampling
|
|
for i_level, block in enumerate(self.conv_blocks):
|
|
for i_block in range(self.num_res_blocks):
|
|
h = block.res[i_block](h)
|
|
if len(block.attn) > 0:
|
|
h = block.attn[i_block](h)
|
|
if i_level != self.num_resolutions - 1:
|
|
h = block.downsample(h)
|
|
|
|
# middle
|
|
for mid_block in self.mid:
|
|
h = mid_block(h)
|
|
|
|
# end
|
|
h = self.norm_out(h)
|
|
h = nonlinearity(h)
|
|
h = self.conv_out(h)
|
|
return h
|
|
|
|
|
|
class Decoder(nn.Module):
|
|
def __init__(
|
|
self,
|
|
z_channels=256,
|
|
ch=128,
|
|
ch_mult=(1, 1, 2, 2, 4),
|
|
num_res_blocks=2,
|
|
norm_type="group",
|
|
dropout=0.0,
|
|
resamp_with_conv=True,
|
|
out_channels=3,
|
|
):
|
|
super().__init__()
|
|
self.num_resolutions = len(ch_mult)
|
|
self.num_res_blocks = num_res_blocks
|
|
|
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
|
# z to block_in
|
|
self.conv_in = nn.Conv2d(
|
|
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
# middle
|
|
self.mid = nn.ModuleList()
|
|
self.mid.append(
|
|
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
|
|
)
|
|
self.mid.append(AttnBlock(block_in, norm_type=norm_type))
|
|
self.mid.append(
|
|
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
|
|
)
|
|
|
|
# upsampling
|
|
self.conv_blocks = nn.ModuleList()
|
|
for i_level in reversed(range(self.num_resolutions)):
|
|
conv_block = nn.Module()
|
|
# res & attn
|
|
res_block = nn.ModuleList()
|
|
attn_block = nn.ModuleList()
|
|
block_out = ch * ch_mult[i_level]
|
|
for _ in range(self.num_res_blocks + 1):
|
|
res_block.append(
|
|
ResnetBlock(
|
|
block_in, block_out, dropout=dropout, norm_type=norm_type
|
|
)
|
|
)
|
|
block_in = block_out
|
|
if i_level == self.num_resolutions - 1:
|
|
attn_block.append(AttnBlock(block_in, norm_type))
|
|
conv_block.res = res_block
|
|
conv_block.attn = attn_block
|
|
            # upsample
|
|
if i_level != 0:
|
|
conv_block.upsample = Upsample(block_in, resamp_with_conv)
|
|
self.conv_blocks.append(conv_block)
|
|
|
|
# end
|
|
self.norm_out = Normalize(block_in, norm_type)
|
|
self.conv_out = nn.Conv2d(
|
|
block_in, out_channels, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
@property
|
|
def last_layer(self):
|
|
return self.conv_out.weight
|
|
|
|
def forward(self, z):
|
|
# z to block_in
|
|
h = self.conv_in(z)
|
|
|
|
# middle
|
|
for mid_block in self.mid:
|
|
h = mid_block(h)
|
|
|
|
# upsampling
|
|
for i_level, block in enumerate(self.conv_blocks):
|
|
for i_block in range(self.num_res_blocks + 1):
|
|
h = block.res[i_block](h)
|
|
if len(block.attn) > 0:
|
|
h = block.attn[i_block](h)
|
|
if i_level != self.num_resolutions - 1:
|
|
h = block.upsample(h)
|
|
|
|
# end
|
|
h = self.norm_out(h)
|
|
h = nonlinearity(h)
|
|
h = self.conv_out(h)
|
|
return h
|
|
|
|
|
|
class VectorQuantizer(nn.Module):
|
|
def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
|
|
super().__init__()
|
|
self.n_e = n_e
|
|
self.e_dim = e_dim
|
|
self.beta = beta
|
|
self.entropy_loss_ratio = entropy_loss_ratio
|
|
self.l2_norm = l2_norm
|
|
self.show_usage = show_usage
|
|
|
|
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
|
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
|
if self.l2_norm:
|
|
self.embedding.weight.data = F.normalize(
|
|
self.embedding.weight.data, p=2, dim=-1
|
|
)
|
|
if self.show_usage:
|
|
# self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
|
|
self.codebook_used = nn.Parameter(torch.zeros(65536))
|
|
|
|
def forward(self, z):
|
|
# reshape z -> (batch, height, width, channel) and flatten
|
|
z = torch.einsum("b c h w -> b h w c", z).contiguous()
|
|
z_flattened = z.view(-1, self.e_dim)
|
|
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
|
|
|
if self.l2_norm:
|
|
z = F.normalize(z, p=2, dim=-1)
|
|
z_flattened = F.normalize(z_flattened, p=2, dim=-1)
|
|
embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
|
|
else:
|
|
embedding = self.embedding.weight
|
|
|
|
d = (
|
|
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
|
+ torch.sum(embedding**2, dim=1)
|
|
- 2
|
|
* torch.einsum(
|
|
"bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
|
|
)
|
|
)
|
|
|
|
min_encoding_indices = torch.argmin(d, dim=1)
|
|
z_q = embedding[min_encoding_indices].view(z.shape)
|
|
perplexity = None
|
|
min_encodings = None
|
|
vq_loss = None
|
|
commit_loss = None
|
|
entropy_loss = None
|
|
|
|
# compute loss for embedding
|
|
if self.training:
|
|
vq_loss = torch.mean((z_q - z.detach()) ** 2)
|
|
commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
|
|
entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
|
|
|
|
# preserve gradients
|
|
z_q = z + (z_q - z).detach()
|
|
|
|
# reshape back to match original input shape
|
|
z_q = torch.einsum("b h w c -> b c h w", z_q)
|
|
|
|
return (
|
|
z_q,
|
|
(vq_loss, commit_loss, entropy_loss),
|
|
(perplexity, min_encodings, min_encoding_indices),
|
|
)
|
|
|
|
def get_codebook_entry(self, indices, shape=None, channel_first=True):
|
|
# shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
|
|
if self.l2_norm:
|
|
embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
|
|
else:
|
|
embedding = self.embedding.weight
|
|
z_q = embedding[indices] # (b*h*w, c)
|
|
|
|
if shape is not None:
|
|
if channel_first:
|
|
z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
|
|
# reshape back to match original input shape
|
|
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
|
else:
|
|
z_q = z_q.view(shape)
|
|
return z_q
|
|
|
|
|
|
class ResnetBlock(nn.Module):
|
|
def __init__(
|
|
self,
|
|
in_channels,
|
|
out_channels=None,
|
|
conv_shortcut=False,
|
|
dropout=0.0,
|
|
norm_type="group",
|
|
):
|
|
super().__init__()
|
|
self.in_channels = in_channels
|
|
out_channels = in_channels if out_channels is None else out_channels
|
|
self.out_channels = out_channels
|
|
self.use_conv_shortcut = conv_shortcut
|
|
|
|
self.norm1 = Normalize(in_channels, norm_type)
|
|
self.conv1 = nn.Conv2d(
|
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
)
|
|
self.norm2 = Normalize(out_channels, norm_type)
|
|
self.dropout = nn.Dropout(dropout)
|
|
self.conv2 = nn.Conv2d(
|
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
if self.in_channels != self.out_channels:
|
|
if self.use_conv_shortcut:
|
|
self.conv_shortcut = nn.Conv2d(
|
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
)
|
|
else:
|
|
self.nin_shortcut = nn.Conv2d(
|
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
|
)
|
|
|
|
def forward(self, x):
|
|
h = x
|
|
h = self.norm1(h)
|
|
h = nonlinearity(h)
|
|
h = self.conv1(h)
|
|
h = self.norm2(h)
|
|
h = nonlinearity(h)
|
|
h = self.dropout(h)
|
|
h = self.conv2(h)
|
|
|
|
if self.in_channels != self.out_channels:
|
|
if self.use_conv_shortcut:
|
|
x = self.conv_shortcut(x)
|
|
else:
|
|
x = self.nin_shortcut(x)
|
|
return x + h
|
|
|
|
|
|
class AttnBlock(nn.Module):
|
|
def __init__(self, in_channels, norm_type="group"):
|
|
super().__init__()
|
|
self.norm = Normalize(in_channels, norm_type)
|
|
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
self.proj_out = nn.Conv2d(
|
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
)
|
|
|
|
def forward(self, x):
|
|
h_ = x
|
|
h_ = self.norm(h_)
|
|
q = self.q(h_)
|
|
k = self.k(h_)
|
|
v = self.v(h_)
|
|
|
|
# compute attention
|
|
b, c, h, w = q.shape
|
|
q = q.reshape(b, c, h * w)
|
|
q = q.permute(0, 2, 1) # b,hw,c
|
|
k = k.reshape(b, c, h * w) # b,c,hw
|
|
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
|
w_ = w_ * (int(c) ** (-0.5))
|
|
w_ = F.softmax(w_, dim=2)
|
|
|
|
# attend to values
|
|
v = v.reshape(b, c, h * w)
|
|
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
|
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
|
h_ = h_.reshape(b, c, h, w)
|
|
|
|
h_ = self.proj_out(h_)
|
|
|
|
return x + h_
|
|
|
|
|
|
def nonlinearity(x):
|
|
# swish
|
|
return x * torch.sigmoid(x)
|
|
|
|
|
|
def Normalize(in_channels, norm_type="group"):
|
|
assert norm_type in ["group", "batch"]
|
|
if norm_type == "group":
|
|
return nn.GroupNorm(
|
|
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
|
)
|
|
elif norm_type == "batch":
|
|
return nn.SyncBatchNorm(in_channels)
|
|
|
|
|
|
class Upsample(nn.Module):
|
|
def __init__(self, in_channels, with_conv):
|
|
super().__init__()
|
|
self.with_conv = with_conv
|
|
if self.with_conv:
|
|
self.conv = nn.Conv2d(
|
|
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x):
|
|
if x.dtype != torch.float32:
|
|
x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
|
|
torch.bfloat16
|
|
)
|
|
else:
|
|
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
|
|
|
if self.with_conv:
|
|
x = self.conv(x)
|
|
return x
|
|
|
|
|
|
class Downsample(nn.Module):
|
|
def __init__(self, in_channels, with_conv):
|
|
super().__init__()
|
|
self.with_conv = with_conv
|
|
if self.with_conv:
|
|
# no asymmetric padding in torch conv, must do it ourselves
|
|
self.conv = nn.Conv2d(
|
|
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
|
)
|
|
|
|
def forward(self, x):
|
|
if self.with_conv:
|
|
pad = (0, 1, 0, 1)
|
|
x = F.pad(x, pad, mode="constant", value=0)
|
|
x = self.conv(x)
|
|
else:
|
|
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
|
return x
|
|
|
|
|
|
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss


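# The VQ-VAE image tokenizer used by Janus for image generation:
#   Encoder -> 1x1 quant_conv (z_channels -> codebook_embed_dim)
#           -> VectorQuantizer -> 1x1 post_quant_conv -> Decoder.
# decode_code() maps discrete token ids back to pixels through the codebook.
# Rough usage sketch (shapes follow the upstream Janus demo for 384x384 images
# with a 16x downsampling factor; adjust to the actual checkpoint):
#
#   tokens = torch.randint(0, 16384, (1, 24 * 24))                # [batch, h*w] code ids
#   imgs = vq_model.decode_code(tokens, shape=(1, 8, 24, 24))     # -> [1, 3, 384, 384]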
class VQModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.encoder = Encoder(
            ch_mult=config.encoder_ch_mult,
            z_channels=config.z_channels,
            dropout=config.dropout_p,
        )
        self.decoder = Decoder(
            ch_mult=config.decoder_ch_mult,
            z_channels=config.z_channels,
            dropout=config.dropout_p,
        )

        self.quantize = VectorQuantizer(
            config.codebook_size,
            config.codebook_embed_dim,
            config.commit_loss_beta,
            config.entropy_loss_ratio,
            config.codebook_l2_norm,
            config.codebook_show_usage,
        )
        self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(
            config.codebook_embed_dim, config.z_channels, 1
        )

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b, shape=None, channel_first=True):
        quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff


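# Thin HF PreTrainedModel base carrying the Janus MultiModalityConfig, so the
# causal-LM wrapper below integrates with the transformers loading machinery.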
class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


# Copied and adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/modeling_vlm.py
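# SGLang's serving entry point for Janus-Pro. A vision tower (SigLIP in the
# released checkpoints) plus an aligner MLP project image features into the
# embedding space of a Llama language model for multimodal understanding. The
# gen_* modules (VQ tokenizer, aligner, head, embedding) belong to the image
# generation path; their checkpoint weights are skipped in load_weights below.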
class MultiModalityCausalLM(MultiModalityPreTrainedModel):

    def __init__(
        self,
        config: MultiModalityConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config)

        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params["image_token_size"],
            gen_vision_config.params["n_embed"],
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(
            language_config, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)

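    # Encode all images of a request in one vision-tower pass: the (b, n) image
    # grid is flattened to (b * n), run through the vision tower and aligner,
    # then regrouped so each sample contributes n * T2 consecutive image tokens
    # of the language-model hidden size.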
    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
        pixel_values = image_input.pixel_values
        bs, n = pixel_values.shape[0:2]
        pixel_values = pixel_values.to(
            device=self.vision_model.device, dtype=self.vision_model.dtype
        )
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")

        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)

        return images_embeds

    def get_input_embeddings(self) -> nn.Embedding:
        return self.language_model.model.embed_tokens

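    # Forward pass used by SGLang's model runner: general_mm_embed_routine
    # merges the image features into the token embeddings at the image
    # placeholder positions, and the wrapped LlamaForCausalLM then runs on
    # embeddings only (input_ids=None).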
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:

        inputs_embeds = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            embed_tokens=self.get_input_embeddings(),
            mm_data_embedding_func=self.get_image_feature,
        )

        return self.language_model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
            get_embedding=False,
        )

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        return self.gen_aligner(self.gen_embed(image_ids))

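    # Rewrites the token span between each image-start/image-end pair with the
    # pad values produced by the generic token-pair padding helper, so prefix
    # caching treats image placeholders consistently across requests.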
    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        im_start_id = image_inputs.im_start_id
        im_end_id = image_inputs.im_end_id
        media_token_pairs = [(im_start_id, im_end_id)]

        helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        return helper.pad_input_tokens(input_ids, image_inputs)

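    # Checkpoint loading: separate q/k/v and gate/up projections are mapped
    # onto SGLang's fused qkv_proj / gate_up_proj parameters, rotary caches and
    # the generation sub-model are skipped, and remaining tensors fall back to
    # the default loader.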
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            # skip generation sub model
            if "gen" in name:
                continue

            # adapt to VisionAttention
            name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
            if "vision_model.vision_tower" in name:
                name = name.replace("attn.qkv", "attn.qkv_proj")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # replace the name and load with customized loader
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


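# Register with transformers' AutoModel and expose the class to SGLang's model
# loader via EntryClass.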
AutoModel.register(config_class=MultiModalityConfig, model_class=MultiModalityCausalLM)

EntryClass = [MultiModalityCausalLM]