sglang0.4.5.post1/python/sglang/srt/models/deepseek_janus_pro.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copied and Adapted from:
# https://github.com/deepseek-ai/Janus
import collections
import math
import os
from dataclasses import field
from enum import Enum
from functools import partial
from itertools import repeat
from typing import (
Callable,
Final,
Iterable,
Literal,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, _assert, nn
from torch.nn.init import trunc_normal_
from transformers import AutoModel, PreTrainedModel
from sglang.srt.configs.janus_pro import *
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.utils import logger
#################################################################################
# VQ Model Configs #
#################################################################################
# Copied from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/vq_model.py
@dataclass
class ModelArgs:
codebook_size: int = 16384
codebook_embed_dim: int = 8
codebook_l2_norm: bool = True
codebook_show_usage: bool = True
commit_loss_beta: float = 0.25
entropy_loss_ratio: float = 0.0
encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
z_channels: int = 256
dropout_p: float = 0.0
def named_apply(
fn: Callable,
module: nn.Module,
name="",
depth_first: bool = True,
include_root: bool = False,
) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(
fn=fn,
module=child_module,
name=child_name,
depth_first=depth_first,
include_root=True,
)
if depth_first and include_root:
fn(module=module, name=name)
return module
def VQ_16(**kwargs):
return VQModel(
ModelArgs(
encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
)
)
VQ_models = {"VQ-16": VQ_16}
import collections.abc
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
        logger.warning(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
    if tensor.dtype in [torch.float16, torch.bfloat16]:
        # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu.
        # Compute in float32 and copy the result back so the caller's tensor is
        # updated in place rather than on a temporary copy.
        tmp = tensor.to(torch.float32)
        tmp.erfinv_()
        tensor.copy_(tmp)
    else:
        tensor.erfinv_()
    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)
    # Clamp to ensure it's in the proper range
    if tensor.dtype == torch.float16:
        # The `clamp_` op is not (yet?) defined in float16+cpu; clamp in float32 and copy back.
        tensor.copy_(tensor.to(torch.float32).clamp_(min=a, max=b))
    else:
        tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
):
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \\leq \text{mean} \\leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsquently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
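# Illustrative note: with the defaults a=-2.0, b=2.0, a call such as
# trunc_normal_tf_(w, std=0.02) draws from a standard normal truncated to [-2, 2] and
# then scales by std, so the resulting values land in roughly [-0.04, 0.04].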
to_2tuple = _ntuple(2)
class Format(str, Enum):
NCHW = "NCHW"
NHWC = "NHWC"
NCL = "NCL"
NLC = "NLC"
def nchw_to(x: torch.Tensor, fmt: Format):
if fmt == Format.NHWC:
x = x.permute(0, 2, 3, 1)
elif fmt == Format.NLC:
x = x.flatten(2).transpose(1, 2)
elif fmt == Format.NCL:
x = x.flatten(2)
return x
def resample_patch_embed(
patch_embed,
new_size: List[int],
interpolation: str = "bicubic",
antialias: bool = True,
verbose: bool = False,
):
"""Resample the weights of the patch embedding kernel to target resolution.
We resample the patch embedding kernel by approximately inverting the effect
of patch resizing.
Code based on:
https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
With this resizing, we can for example load a B/8 filter into a B/16 model
and, on 2x larger input image, the result will match.
Args:
patch_embed: original parameter to be resized.
        new_size (tuple(int, int)): target shape (height, width) only.
interpolation (str): interpolation for resize
antialias (bool): use anti-aliasing filter in resize
verbose (bool): log operation
Returns:
Resized patch embedding kernel.
"""
import numpy as np
try:
from torch import vmap
except ImportError:
from torch.func import vmap
assert len(patch_embed.shape) == 4, "Four dimensions expected"
assert len(new_size) == 2, "New shape should only be hw"
old_size = patch_embed.shape[-2:]
if tuple(old_size) == tuple(new_size):
return patch_embed
if verbose:
logger.info(
f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation."
)
def resize(x_np, _new_size):
x_tf = torch.Tensor(x_np)[None, None, ...]
x_upsampled = F.interpolate(
x_tf, size=_new_size, mode=interpolation, antialias=antialias
)[0, 0, ...].numpy()
return x_upsampled
def get_resize_mat(_old_size, _new_size):
mat = []
for i in range(np.prod(_old_size)):
basis_vec = np.zeros(_old_size)
basis_vec[np.unravel_index(i, _old_size)] = 1.0
mat.append(resize(basis_vec, _new_size).reshape(-1))
return np.stack(mat).T
resize_mat = get_resize_mat(old_size, new_size)
resize_mat_pinv = torch.tensor(
np.linalg.pinv(resize_mat.T), device=patch_embed.device
)
def resample_kernel(kernel):
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
return resampled_kernel.reshape(new_size)
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
orig_dtype = patch_embed.dtype
patch_embed = patch_embed.float()
patch_embed = v_resample_kernel(patch_embed)
patch_embed = patch_embed.to(orig_dtype)
return patch_embed
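# Illustrative example (assumed ViT-B shapes): resampling a B/8 patch-embedding kernel
# of shape (768, 3, 8, 8) via resample_patch_embed(weight, [16, 16]) yields a
# (768, 3, 16, 16) kernel, so a B/8 checkpoint can be loaded into a B/16 model and
# behave similarly on 2x larger input images.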
# Copied from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/siglip_vit.py
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
output_fmt: Format
dynamic_img_pad: torch.jit.Final[bool]
def __init__(
self,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
):
super().__init__()
self.patch_size = tuple(to_2tuple(patch_size))
self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
assert self.patch_size
if img_size is None:
return None, None, None
img_size = to_2tuple(img_size)
grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
num_patches = grid_size[0] * grid_size[1]
return img_size, grid_size, num_patches
def set_input_size(
self,
img_size: Optional[Union[int, Tuple[int, int]]] = None,
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
):
new_patch_size = None
if patch_size is not None:
new_patch_size = to_2tuple(patch_size)
if new_patch_size is not None and new_patch_size != self.patch_size:
with torch.no_grad():
new_proj = nn.Conv2d(
self.proj.in_channels,
self.proj.out_channels,
kernel_size=new_patch_size,
stride=new_patch_size,
bias=self.proj.bias is not None,
)
new_proj.weight.copy_(
resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)
)
if self.proj.bias is not None:
new_proj.bias.copy_(self.proj.bias)
self.proj = new_proj
self.patch_size = new_patch_size
img_size = img_size or self.img_size
if img_size != self.img_size or new_patch_size is not None:
self.img_size, self.grid_size, self.num_patches = self._init_img_size(
img_size
)
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
if as_scalar:
return max(self.patch_size)
else:
return self.patch_size
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
"""Get grid (feature) size for given image size taking account of dynamic padding.
NOTE: must be torchscript compatible so using fixed tuple indexing
"""
if self.dynamic_img_pad:
return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(
img_size[1] / self.patch_size[1]
)
else:
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
def forward(self, x):
B, C, H, W = x.shape
if self.img_size is not None:
if self.strict_img_size:
_assert(
H == self.img_size[0],
f"Input height ({H}) doesn't match model ({self.img_size[0]}).",
)
_assert(
W == self.img_size[1],
f"Input width ({W}) doesn't match model ({self.img_size[1]}).",
)
elif not self.dynamic_img_pad:
_assert(
H % self.patch_size[0] == 0,
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).",
)
_assert(
W % self.patch_size[1] == 0,
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).",
)
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks
NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
)
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
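# Illustrative note: with drop_prob=0.1 during training, each sample's residual branch
# is zeroed with probability 0.1 and otherwise scaled by 1/0.9 (when scale_by_keep=True),
# keeping the expected output equal to the eval-mode output.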
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob, 3):0.3f}"
class VisionTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
qk_norm: bool = False,
proj_drop: float = 0.0,
attn_drop: float = 0.0,
init_values: Optional[float] = None,
drop_path: float = 0.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.LayerNorm,
mlp_layer: nn.Module = Mlp,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
use_qkv_parallel=True,
use_context_forward=False,
softmax_in_single_precision=False,
dropout=attn_drop,
)
self.ls1 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = mlp_layer(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
)
self.ls2 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
LayerType = Union[str, Callable, Type[torch.nn.Module]]
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
"""
return_indices: torch.jit.Final[bool]
def __init__(
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
return_indices: bool = False,
):
super().__init__()
assert 0 <= prob < 1.0
self.prob = prob
self.num_prefix_tokens = (
num_prefix_tokens # exclude CLS token (or other prefix tokens)
)
self.ordered = ordered
self.return_indices = return_indices
def forward(
self, x
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
if not self.training or self.prob == 0.0:
if self.return_indices:
return x, None
return x
if self.num_prefix_tokens:
prefix_tokens, x = (
x[:, : self.num_prefix_tokens],
x[:, self.num_prefix_tokens :],
)
else:
prefix_tokens = None
B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1.0 - self.prob)))
keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[
:, :num_keep
]
if self.ordered:
# NOTE does not need to maintain patch order in typical transformer use,
# but possibly useful for debug / visualization
keep_indices = keep_indices.sort(dim=-1)[0]
x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
if prefix_tokens is not None:
x = torch.cat((prefix_tokens, x), dim=1)
if self.return_indices:
return x, keep_indices
return x
def resample_abs_pos_embed(
posemb: torch.Tensor,
new_size: List[int],
old_size: Optional[List[int]] = None,
num_prefix_tokens: int = 1,
interpolation: str = "bicubic",
antialias: bool = True,
verbose: bool = False,
):
# sort out sizes, assume square if old size not provided
num_pos_tokens = posemb.shape[1]
num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
return posemb
if old_size is None:
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
old_size = hw, hw
if num_prefix_tokens:
posemb_prefix, posemb = (
posemb[:, :num_prefix_tokens],
posemb[:, num_prefix_tokens:],
)
else:
posemb_prefix, posemb = None, posemb
# do the interpolation
embed_dim = posemb.shape[-1]
orig_dtype = posemb.dtype
posemb = posemb.float() # interpolate needs float32
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
posemb = F.interpolate(
posemb, size=new_size, mode=interpolation, antialias=antialias
)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
posemb = posemb.to(orig_dtype)
# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
posemb = torch.cat([posemb_prefix, posemb], dim=1)
if not torch.jit.is_scripting() and verbose:
logger.info(f"Resized position embedding: {old_size} to {new_size}.")
return posemb
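# NOTE: the free function below is assigned to AttentionPoolLatent.init_weights inside
# VisionTransformer.__init__ (global_pool == "map"), replacing the class's own
# trunc_normal_tf_-based initializer with nn.init.trunc_normal_.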
def init_weights(self):
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
trunc_normal_(self.latent, std=self.latent_dim**-0.5)
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, "init_weights"):
module.init_weights()
class VisionTransformer(nn.Module):
"""Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
dynamic_img_size: Final[bool]
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal["", "avg", "token", "map"] = "token",
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: Optional[float] = None,
class_token: bool = True,
no_embed_class: bool = False,
reg_tokens: int = 0,
pre_norm: bool = False,
fc_norm: Optional[bool] = None,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
patch_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
embed_layer: Callable = PatchEmbed,
_norm_layer: Optional[LayerType] = None,
_act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = VisionTransformerBlock,
mlp_layer: Type[nn.Module] = Mlp,
ignore_head: bool = False,
) -> None:
"""
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
global_pool: Type of global pooling for final sequence (default: 'token').
embed_dim: Transformer embedding dimension.
depth: Depth of transformer.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: Enable bias for qkv projections if True.
init_values: Layer-scale init values (layer-scale enabled if not None).
class_token: Use class token.
no_embed_class: Don't include position embeddings for class (or reg) tokens.
reg_tokens: Number of register tokens.
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
drop_rate: Head dropout rate.
pos_drop_rate: Position embedding dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
embed_layer: Patch embedding layer.
_norm_layer: Normalization layer.
_act_layer: MLP activation layer.
block_fn: Transformer block layer.
"""
super().__init__()
assert global_pool in ("", "avg", "token", "map")
assert class_token or global_pool != "token"
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
# act_layer = get_act_layer(act_layer) or nn.GELU
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = (
embed_dim # num_features for consistency with other models
)
self.num_prefix_tokens = 1 if class_token else 0
self.num_prefix_tokens += reg_tokens
self.num_reg_tokens = reg_tokens
self.has_class_token = class_token
self.no_embed_class = (
no_embed_class # don't embed prefix positions (includes reg)
)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.ignore_head = ignore_head
embed_args = {}
if dynamic_img_size:
# flatten deferred until after pos embed
embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
dynamic_img_pad=dynamic_img_pad,
**embed_args,
)
num_patches = self.patch_embed.num_patches
self.cls_token = (
nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
)
self.reg_token = (
nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
)
embed_len = (
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
)
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
patch_drop_rate,
num_prefix_tokens=self.num_prefix_tokens,
)
else:
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.Sequential(
*[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
init_values=init_values,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
mlp_layer=mlp_layer,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Classifier Head
if global_pool == "map":
AttentionPoolLatent.init_weights = init_weights
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
)
else:
self.attn_pool = None
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = (
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
)
if weight_init != "skip":
self.init_weights(weight_init)
def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
assert mode in ("jax", "jax_nlhb", "moco", "")
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
trunc_normal_(self.pos_embed, std=0.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
@torch.jit.ignore
def no_weight_decay(self) -> Set:
return {"pos_embed", "cls_token", "dist_token"}
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict:
return dict(
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
)
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
return self.head
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ("", "avg", "token", "map")
if global_pool == "map" and self.attn_pool is None:
assert (
False
), "Cannot currently add attention pooling in reset_classifier()."
elif global_pool != "map " and self.attn_pool is not None:
self.attn_pool = None # remove attention pooling
self.global_pool = global_pool
self.head = (
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
)
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
self.pos_embed,
[H, W],
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
)
x = x.view(B, -1, C)
else:
pos_embed = self.pos_embed
to_cat = []
if self.cls_token is not None:
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
if self.reg_token is not None:
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + pos_embed
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
x = x + pos_embed
return self.pos_drop(x)
def _intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
) -> List[torch.Tensor]:
outputs, num_blocks = [], len(self.blocks)
take_indices = set(
range(num_blocks - n, num_blocks) if isinstance(n, int) else n
)
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in take_indices:
outputs.append(x)
return outputs
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
x = self.blocks(x)
x = self.norm(x)
return x
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
if self.attn_pool is not None:
x = self.attn_pool(x)
elif self.global_pool == "avg":
x = x[:, self.num_prefix_tokens :].mean(dim=1)
elif self.global_pool:
x = x[:, 0] # class token
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
if not self.ignore_head:
x = self.forward_head(x)
return x
def model_name_to_cls(cls_name):
if "MlpProjector" in cls_name:
cls = MlpProjector
elif "CLIPVisionTower" in cls_name:
cls = CLIPVisionTower
elif "VQ" in cls_name:
cls = VQ_models[cls_name]
elif "vision_head" in cls_name:
cls = vision_head
else:
raise ValueError(f"class_name {cls_name} is invalid.")
return cls
class vision_head(torch.nn.Module):
def __init__(self, params):
super().__init__()
self.output_mlp_projector = torch.nn.Linear(
params["n_embed"], params["image_token_embed"]
)
self.vision_activation = torch.nn.GELU()
self.vision_head = torch.nn.Linear(
params["image_token_embed"], params["image_token_size"]
)
def forward(self, x):
x = self.output_mlp_projector(x)
x = self.vision_activation(x)
x = self.vision_head(x)
return x
SigLIP_MODEL_CONFIG = {
"siglip_so400m_patch14_384": {
"image_size": 336,
"patch_size": 14,
"width": 1152,
"layers": 27,
"heads": 16,
"mlp_ratio": 3.7362,
"global_pool": "map",
"use_checkpoint": False,
},
"siglip_so400m_patch14_224": {
"image_size": 224,
"patch_size": 14,
"width": 1152,
"layers": 27,
"heads": 16,
"mlp_ratio": 3.7362,
"global_pool": "map",
"use_checkpoint": False,
},
"siglip_large_patch16_384": {
"image_size": 384,
"patch_size": 16,
"width": 1024,
"layers": 24,
"heads": 16,
"mlp_ratio": 4,
"global_pool": "map",
"use_checkpoint": False,
},
}
def create_siglip_vit(
model_name: str = "siglip_so400m_patch14_384",
image_size: int = 384,
select_layer: int = -1,
ckpt_path: str = "",
**kwargs,
):
assert (
model_name in SigLIP_MODEL_CONFIG.keys()
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
if select_layer <= 0:
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
else:
layers = min(vision_cfg.layers, select_layer)
model = VisionTransformer(
img_size=image_size,
patch_size=vision_cfg.patch_size,
embed_dim=vision_cfg.width,
depth=layers,
num_heads=vision_cfg.heads,
mlp_ratio=vision_cfg.mlp_ratio,
class_token=vision_cfg.class_token,
global_pool=vision_cfg.global_pool,
ignore_head=kwargs.get("ignore_head", True),
weight_init=kwargs.get("weight_init", "skip"),
num_classes=0,
)
if ckpt_path:
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
incompatible_keys = model.load_state_dict(state_dict, strict=False)
        print(
            f"SigLIP-ViT restores from {ckpt_path},\n"
            f"\tincompatible_keys: {incompatible_keys}."
        )
return model
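# Illustrative usage (assumed call): create_siglip_vit("siglip_large_patch16_384",
# image_size=384) builds a 24-layer, width-1024 ViT per SigLIP_MODEL_CONFIG with "map"
# attention pooling; because ignore_head defaults to True here, forward() returns the
# token features from forward_features().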
class Normalize(torch.nn.Module):
"""Normalize a tensor image with mean and standard deviation.
This transform does not support PIL Image.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input
``torch.*Tensor`` i.e.,
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note::
This transform acts out of place, i.e., it does not mutate the input tensor.
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation in-place.
"""
def __init__(self, mean, std, inplace=False):
super().__init__()
# _log_api_usage_once(self)
self.mean = mean
self.std = std
self.inplace = inplace
    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.
        Returns:
            Tensor: Normalized Tensor image.
        """
        # Per-channel (mean, std) normalization as described in the class docstring;
        # torch.nn.functional.normalize would perform Lp normalization instead.
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        if self.inplace:
            return tensor.sub_(mean).div_(std)
        return (tensor - mean) / std
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
class CLIPVisionTower(nn.Module):
def __init__(
self,
model_name: str = "siglip_large_patch16_384",
image_size: Union[Tuple[int, int], int] = 336,
select_feature: str = "patch",
select_layer: int = -2,
select_layers: list = None,
ckpt_path: str = "",
pixel_mean: Optional[List[float]] = None,
pixel_std: Optional[List[float]] = None,
**kwargs,
):
super().__init__()
self.model_name = model_name
self.select_feature = select_feature
self.select_layer = select_layer
self.select_layers = select_layers
vision_tower_params = {
"model_name": model_name,
"image_size": image_size,
"ckpt_path": ckpt_path,
"select_layer": select_layer,
}
vision_tower_params.update(kwargs)
self.vision_tower, self.forward_kwargs = self.build_vision_tower(
vision_tower_params
)
if pixel_mean is not None and pixel_std is not None:
image_norm = Normalize(mean=pixel_mean, std=pixel_std)
else:
image_norm = None
self.image_norm = image_norm
@property
def device(self) -> torch.device:
return next(self.vision_tower.parameters()).device
@property
def dtype(self):
return next(self.vision_tower.parameters()).dtype
def build_vision_tower(self, vision_tower_params):
if self.model_name.startswith("siglip"):
self.select_feature = "same"
vision_tower = create_siglip_vit(**vision_tower_params)
forward_kwargs = dict()
elif self.model_name.startswith("sam"):
# vision_tower = create_sam_vit(**vision_tower_params)
forward_kwargs = dict()
else: # huggingface
from transformers import CLIPVisionModel
vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
forward_kwargs = dict(output_hidden_states=True)
return vision_tower, forward_kwargs
def feature_select(self, image_forward_outs):
if isinstance(image_forward_outs, torch.Tensor):
            # the output is already the features from self.select_layer
image_features = image_forward_outs
else:
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == "patch":
# if the output has cls_token
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
elif self.select_feature == "same":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
def forward(self, images):
"""
Args:
images (torch.Tensor): [b, 3, H, W]
Returns:
image_features (torch.Tensor): [b, n_patch, d]
"""
if self.image_norm is not None:
images = self.image_norm(images)
image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
image_features = self.feature_select(image_forward_outs)
return image_features
class MlpProjector(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
if cfg["projector_type"] == "identity":
modules = nn.Identity()
elif cfg["projector_type"] == "linear":
modules = nn.Linear(cfg["input_dim"], cfg["n_embed"])
elif cfg["projector_type"] == "mlp_gelu":
mlp_depth = cfg.get("depth", 1)
modules = [nn.Linear(cfg["input_dim"], cfg["n_embed"])]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
modules = nn.Sequential(*modules)
elif cfg["projector_type"] == "low_high_hybrid_split_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
self.high_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
self.low_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
modules = nn.Sequential(*modules)
else:
raise ValueError(f"Unknown projector type: {cfg['projector_type']}")
self.layers = modules
def forward(
self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
):
"""
Args:
            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of
                torch.Tensor, it comes from the hybrid vision encoder and x = (high_res_x, low_res_x);
                otherwise it is the feature from the single vision encoder.
Returns:
x (torch.Tensor): [b, s, c]
"""
if isinstance(x_or_tuple, tuple):
# self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
high_x, low_x = x_or_tuple
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.cat([high_x, low_x], dim=-1)
else:
x = x_or_tuple
return self.layers(x)
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: float = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
# use torch.scaled_dot_product_attention where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
if "TIMM_FUSED_ATTN" in os.environ:
_USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"])
else:
_USE_FUSED_ATTN = (
1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
)
# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False
def use_fused_attn(experimental: bool = False) -> bool:
# NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
if not _HAS_FUSED_ATTN or _EXPORTABLE:
return False
if experimental:
return _USE_FUSED_ATTN > 1
return _USE_FUSED_ATTN > 0
class AttentionPoolLatent(nn.Module):
"""Attention pooling w/ latent query"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
in_features: int,
out_features: int = None,
embed_dim: int = None,
num_heads: int = 8,
feat_size: Optional[int] = None,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
latent_len: int = 1,
latent_dim: int = None,
pos_embed: str = "",
pool_type: str = "token",
norm_layer: Optional[nn.Module] = None,
drop: float = 0.0,
):
super().__init__()
embed_dim = embed_dim or in_features
out_features = out_features or in_features
assert embed_dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.feat_size = feat_size
self.scale = self.head_dim**-0.5
self.pool = pool_type
self.fused_attn = use_fused_attn()
if pos_embed == "abs":
assert feat_size is not None
self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
else:
self.pos_embed = None
self.latent_dim = latent_dim or embed_dim
self.latent_len = latent_len
self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_drop = nn.Dropout(drop)
self.norm = (
norm_layer(out_features) if norm_layer is not None else nn.Identity()
)
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
self.init_weights()
def init_weights(self):
if self.pos_embed is not None:
trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
trunc_normal_tf_(self.latent, std=self.latent_dim**-0.5)
def forward(self, x):
B, N, C = x.shape
if self.pos_embed is not None:
# FIXME interpolate
x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
q_latent = self.latent.expand(B, -1, -1)
q = (
self.q(q_latent)
.reshape(B, self.latent_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv = (
self.kv(x)
.reshape(B, N, 2, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
k, v = kv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(q, k, v)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, self.latent_len, C)
x = self.proj(x)
x = self.proj_drop(x)
x = x + self.mlp(self.norm(x))
# optional pool if latent seq_len > 1 and pooled output is desired
if self.pool == "token":
x = x[:, 0]
elif self.pool == "avg":
            x = x.mean(1)
        return x
class Encoder(nn.Module):
def __init__(
self,
in_channels=3,
ch=128,
ch_mult=(1, 1, 2, 2, 4),
num_res_blocks=2,
norm_type="group",
dropout=0.0,
resamp_with_conv=True,
z_channels=256,
):
super().__init__()
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
# downsampling
in_ch_mult = (1,) + tuple(ch_mult)
self.conv_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
conv_block = nn.Module()
# res & attn
res_block = nn.ModuleList()
attn_block = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
res_block.append(
ResnetBlock(
block_in, block_out, dropout=dropout, norm_type=norm_type
)
)
block_in = block_out
if i_level == self.num_resolutions - 1:
attn_block.append(AttnBlock(block_in, norm_type))
conv_block.res = res_block
conv_block.attn = attn_block
# downsample
if i_level != self.num_resolutions - 1:
conv_block.downsample = Downsample(block_in, resamp_with_conv)
self.conv_blocks.append(conv_block)
# middle
self.mid = nn.ModuleList()
self.mid.append(
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
)
self.mid.append(AttnBlock(block_in, norm_type=norm_type))
self.mid.append(
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
)
# end
self.norm_out = Normalize(block_in, norm_type)
self.conv_out = nn.Conv2d(
block_in, z_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
h = self.conv_in(x)
# downsampling
for i_level, block in enumerate(self.conv_blocks):
for i_block in range(self.num_res_blocks):
h = block.res[i_block](h)
if len(block.attn) > 0:
h = block.attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = block.downsample(h)
# middle
for mid_block in self.mid:
h = mid_block(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
z_channels=256,
ch=128,
ch_mult=(1, 1, 2, 2, 4),
num_res_blocks=2,
norm_type="group",
dropout=0.0,
resamp_with_conv=True,
out_channels=3,
):
super().__init__()
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = ch * ch_mult[self.num_resolutions - 1]
# z to block_in
self.conv_in = nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.ModuleList()
self.mid.append(
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
)
self.mid.append(AttnBlock(block_in, norm_type=norm_type))
self.mid.append(
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
)
# upsampling
self.conv_blocks = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
conv_block = nn.Module()
# res & attn
res_block = nn.ModuleList()
attn_block = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
res_block.append(
ResnetBlock(
block_in, block_out, dropout=dropout, norm_type=norm_type
)
)
block_in = block_out
if i_level == self.num_resolutions - 1:
attn_block.append(AttnBlock(block_in, norm_type))
conv_block.res = res_block
conv_block.attn = attn_block
            # upsample
if i_level != 0:
conv_block.upsample = Upsample(block_in, resamp_with_conv)
self.conv_blocks.append(conv_block)
# end
self.norm_out = Normalize(block_in, norm_type)
self.conv_out = nn.Conv2d(
block_in, out_channels, kernel_size=3, stride=1, padding=1
)
@property
def last_layer(self):
return self.conv_out.weight
def forward(self, z):
# z to block_in
h = self.conv_in(z)
# middle
for mid_block in self.mid:
h = mid_block(h)
# upsampling
for i_level, block in enumerate(self.conv_blocks):
for i_block in range(self.num_res_blocks + 1):
h = block.res[i_block](h)
if len(block.attn) > 0:
h = block.attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = block.upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class VectorQuantizer(nn.Module):
def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.entropy_loss_ratio = entropy_loss_ratio
self.l2_norm = l2_norm
self.show_usage = show_usage
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
if self.l2_norm:
self.embedding.weight.data = F.normalize(
self.embedding.weight.data, p=2, dim=-1
)
if self.show_usage:
# self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
self.codebook_used = nn.Parameter(torch.zeros(65536))
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = torch.einsum("b c h w -> b h w c", z).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
if self.l2_norm:
z = F.normalize(z, p=2, dim=-1)
z_flattened = F.normalize(z_flattened, p=2, dim=-1)
embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
else:
embedding = self.embedding.weight
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(embedding**2, dim=1)
- 2
* torch.einsum(
"bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
)
)
min_encoding_indices = torch.argmin(d, dim=1)
z_q = embedding[min_encoding_indices].view(z.shape)
perplexity = None
min_encodings = None
vq_loss = None
commit_loss = None
entropy_loss = None
# compute loss for embedding
if self.training:
vq_loss = torch.mean((z_q - z.detach()) ** 2)
commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = torch.einsum("b h w c -> b c h w", z_q)
return (
z_q,
(vq_loss, commit_loss, entropy_loss),
(perplexity, min_encodings, min_encoding_indices),
)
def get_codebook_entry(self, indices, shape=None, channel_first=True):
# shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
if self.l2_norm:
embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
else:
embedding = self.embedding.weight
z_q = embedding[indices] # (b*h*w, c)
if shape is not None:
if channel_first:
z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
else:
z_q = z_q.view(shape)
return z_q
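# Note: with the ModelArgs defaults (codebook_size=16384, codebook_embed_dim=8),
# get_codebook_entry maps flat integer codes of shape (b*h*w,) to codebook vectors;
# passing shape=(b, 8, h, w) with channel_first=True returns a (b, 8, h, w) latent that
# VQModel.decode_code feeds to the decoder.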
class ResnetBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
norm_type="group",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, norm_type)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = Normalize(out_channels, norm_type)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels, norm_type="group"):
super().__init__()
self.norm = Normalize(in_channels, norm_type)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
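# NOTE: the factory below shadows the image-normalization `Normalize` module defined
# earlier in this file; inside the VQ encoder/decoder blocks, `Normalize` resolves to
# this GroupNorm / SyncBatchNorm constructor.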
def Normalize(in_channels, norm_type="group"):
assert norm_type in ["group", "batch"]
if norm_type == "group":
return nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
elif norm_type == "batch":
return nn.SyncBatchNorm(in_channels)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
        if x.dtype != torch.float32:
            # `nearest` interpolation may be unsupported or less accurate in reduced
            # precision; upsample in float32 and cast back to the original dtype.
            orig_dtype = x.dtype
            x = F.interpolate(
                x.to(torch.float32), scale_factor=2.0, mode="nearest"
            ).to(orig_dtype)
        else:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = F.avg_pool2d(x, kernel_size=2, stride=2)
return x
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
flat_affinity = affinity.reshape(-1, affinity.shape[-1])
flat_affinity /= temperature
probs = F.softmax(flat_affinity, dim=-1)
log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
if loss_type == "softmax":
target_probs = probs
else:
raise ValueError("Entropy loss {} not supported".format(loss_type))
avg_probs = torch.mean(target_probs, dim=0)
avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
loss = sample_entropy - avg_entropy
return loss
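# Note: the returned loss is sample_entropy - avg_entropy, i.e. it encourages each code
# assignment to be confident (low per-sample entropy) while keeping batch-averaged
# codebook usage spread out (high average entropy).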
class VQModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.encoder = Encoder(
ch_mult=config.encoder_ch_mult,
z_channels=config.z_channels,
dropout=config.dropout_p,
)
self.decoder = Decoder(
ch_mult=config.decoder_ch_mult,
z_channels=config.z_channels,
dropout=config.dropout_p,
)
self.quantize = VectorQuantizer(
config.codebook_size,
config.codebook_embed_dim,
config.commit_loss_beta,
config.entropy_loss_ratio,
config.codebook_l2_norm,
config.codebook_show_usage,
)
self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
self.post_quant_conv = nn.Conv2d(
config.codebook_embed_dim, config.z_channels, 1
)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b, shape=None, channel_first=True):
quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
dec = self.decode(quant_b)
return dec
def forward(self, input):
quant, diff, _ = self.encode(input)
dec = self.decode(quant)
return dec, diff
class MultiModalityPreTrainedModel(PreTrainedModel):
config_class = MultiModalityConfig
base_model_prefix = "multi_modality"
_no_split_modules = []
_skip_keys_device_placement = "past_key_values"
# Copied and adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/modeling_vlm.py
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
def __init__(
self,
config: MultiModalityConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(config)
vision_config = config.vision_config
vision_cls = model_name_to_cls(vision_config.cls)
self.vision_model = vision_cls(**vision_config.params)
aligner_config = config.aligner_config
aligner_cls = model_name_to_cls(aligner_config.cls)
self.aligner = aligner_cls(aligner_config.params)
gen_vision_config = config.gen_vision_config
gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
self.gen_vision_model = gen_vision_cls()
gen_aligner_config = config.gen_aligner_config
gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
gen_head_config = config.gen_head_config
gen_head_cls = model_name_to_cls(gen_head_config.cls)
self.gen_head = gen_head_cls(gen_head_config.params)
self.gen_embed = torch.nn.Embedding(
gen_vision_config.params["image_token_size"],
gen_vision_config.params["n_embed"],
)
language_config = config.language_config
self.language_model = LlamaForCausalLM(
language_config, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values
bs, n = pixel_values.shape[0:2]
pixel_values = pixel_values.to(
device=self.vision_model.device, dtype=self.vision_model.dtype
)
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.aligner(self.vision_model(images))
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
return images_embeds
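    # Note: pixel_values arrive as [b, n, 3, H, W]; each image is encoded by the vision
    # tower, projected by the aligner to the language-model hidden size, and flattened
    # to [b, n * T2, D] so general_mm_embed_routine can place the embeddings at the
    # image-token positions of the input sequence.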
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.model.embed_tokens
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
mm_data_embedding_func=self.get_image_feature,
)
return self.language_model(
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
get_embedding=False,
)
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
return self.gen_aligner(self.gen_embed(image_ids))
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
im_start_id = image_inputs.im_start_id
im_end_id = image_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return helper.pad_input_tokens(input_ids, image_inputs)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq~" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
# skip generation sub model
if "gen" in name:
continue
# adapt to VisionAttention
name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
if "vision_model.vision_tower" in name:
name = name.replace("attn.qkv", "attn.qkv_proj")
for param_name, weight_name, shard_id in stacked_params_mapping:
# replace the name and load with customized loader
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", None)
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
AutoModel.register(config_class=MultiModalityConfig, model_class=MultiModalityCausalLM)
EntryClass = [MultiModalityCausalLM]