# Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # Copied and Adapted from: # https://github.com/deepseek-ai/Janus import collections import math import os from dataclasses import field from enum import Enum from functools import partial from itertools import repeat from typing import ( Callable, Final, Iterable, Literal, Optional, Sequence, Set, Tuple, Type, Union, ) import torch import torch.nn.functional as F from einops import rearrange from torch import Tensor, _assert, nn from torch.nn.init import trunc_normal_ from transformers import AutoModel, PreTrainedModel from sglang.srt.configs.janus_pro import * from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, general_mm_embed_routine, ) from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM from sglang.utils import logger ################################################################################# # VQ Model Configs # ################################################################################# # Copied from: # https://github.com/deepseek-ai/Janus/tree/main/janus/models/vq_model.py @dataclass class ModelArgs: codebook_size: int = 16384 codebook_embed_dim: int = 8 codebook_l2_norm: bool = True codebook_show_usage: bool = True commit_loss_beta: float = 0.25 entropy_loss_ratio: float = 0.0 encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) z_channels: int = 256 dropout_p: float = 0.0 def named_apply( fn: Callable, module: nn.Module, name="", depth_first: bool = True, include_root: bool = False, ) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply( fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True, ) if depth_first and include_root: fn(module=module, name=name) return module def VQ_16(**kwargs): return VQModel( ModelArgs( encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs ) ) VQ_models = {"VQ-16": VQ_16} import collections.abc # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on 
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): logger.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal if tensor.dtype in [torch.float16, torch.bfloat16]: # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu og_dtype = tensor.dtype tensor = tensor.to(torch.float32) tensor.erfinv_() tensor = tensor.to(og_dtype) else: tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range if tensor.dtype == torch.float16: # The `clamp_` op is not (yet?) defined in float16+cpu tensor = tensor.to(torch.float32) tensor.clamp_(min=a, max=b) else: tensor.clamp_(min=a, max=b) def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0, ): """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsquently scaled and shifted by the mean and std args. Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) to_2tuple = _ntuple(2) class Format(str, Enum): NCHW = "NCHW" NHWC = "NHWC" NCL = "NCL" NLC = "NLC" def nchw_to(x: torch.Tensor, fmt: Format): if fmt == Format.NHWC: x = x.permute(0, 2, 3, 1) elif fmt == Format.NLC: x = x.flatten(2).transpose(1, 2) elif fmt == Format.NCL: x = x.flatten(2) return x def resample_patch_embed( patch_embed, new_size: List[int], interpolation: str = "bicubic", antialias: bool = True, verbose: bool = False, ): """Resample the weights of the patch embedding kernel to target resolution. We resample the patch embedding kernel by approximately inverting the effect of patch resizing. Code based on: https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py With this resizing, we can for example load a B/8 filter into a B/16 model and, on 2x larger input image, the result will match. Args: patch_embed: original parameter to be resized. new_size (tuple(int, int): target shape (height, width)-only. 
interpolation (str): interpolation for resize antialias (bool): use anti-aliasing filter in resize verbose (bool): log operation Returns: Resized patch embedding kernel. """ import numpy as np try: from torch import vmap except ImportError: from torch.func import vmap assert len(patch_embed.shape) == 4, "Four dimensions expected" assert len(new_size) == 2, "New shape should only be hw" old_size = patch_embed.shape[-2:] if tuple(old_size) == tuple(new_size): return patch_embed if verbose: logger.info( f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation." ) def resize(x_np, _new_size): x_tf = torch.Tensor(x_np)[None, None, ...] x_upsampled = F.interpolate( x_tf, size=_new_size, mode=interpolation, antialias=antialias )[0, 0, ...].numpy() return x_upsampled def get_resize_mat(_old_size, _new_size): mat = [] for i in range(np.prod(_old_size)): basis_vec = np.zeros(_old_size) basis_vec[np.unravel_index(i, _old_size)] = 1.0 mat.append(resize(basis_vec, _new_size).reshape(-1)) return np.stack(mat).T resize_mat = get_resize_mat(old_size, new_size) resize_mat_pinv = torch.tensor( np.linalg.pinv(resize_mat.T), device=patch_embed.device ) def resample_kernel(kernel): resampled_kernel = resize_mat_pinv @ kernel.reshape(-1) return resampled_kernel.reshape(new_size) v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1) orig_dtype = patch_embed.dtype patch_embed = patch_embed.float() patch_embed = v_resample_kernel(patch_embed) patch_embed = patch_embed.to(orig_dtype) return patch_embed # Copied from: # https://github.com/deepseek-ai/Janus/tree/main/janus/models/siglip_vit.py class PatchEmbed(nn.Module): """2D Image to Patch Embedding""" output_fmt: Format dynamic_img_pad: torch.jit.Final[bool] def __init__( self, img_size: Optional[int] = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: Optional[Callable] = None, flatten: bool = True, output_fmt: Optional[str] = None, bias: bool = True, strict_img_size: bool = True, dynamic_img_pad: bool = False, ): super().__init__() self.patch_size = tuple(to_2tuple(patch_size)) self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size) if output_fmt is not None: self.flatten = False self.output_fmt = Format(output_fmt) else: # flatten spatial dim and transpose to channels last, kept for bwd compat self.flatten = flatten self.output_fmt = Format.NCHW self.strict_img_size = strict_img_size self.dynamic_img_pad = dynamic_img_pad self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias ) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def _init_img_size(self, img_size: Union[int, Tuple[int, int]]): assert self.patch_size if img_size is None: return None, None, None img_size = to_2tuple(img_size) grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)]) num_patches = grid_size[0] * grid_size[1] return img_size, grid_size, num_patches def set_input_size( self, img_size: Optional[Union[int, Tuple[int, int]]] = None, patch_size: Optional[Union[int, Tuple[int, int]]] = None, ): new_patch_size = None if patch_size is not None: new_patch_size = to_2tuple(patch_size) if new_patch_size is not None and new_patch_size != self.patch_size: with torch.no_grad(): new_proj = nn.Conv2d( self.proj.in_channels, self.proj.out_channels, kernel_size=new_patch_size, stride=new_patch_size, bias=self.proj.bias is not None, ) new_proj.weight.copy_( resample_patch_embed(self.proj.weight, new_patch_size, verbose=True) ) if 
self.proj.bias is not None: new_proj.bias.copy_(self.proj.bias) self.proj = new_proj self.patch_size = new_patch_size img_size = img_size or self.img_size if img_size != self.img_size or new_patch_size is not None: self.img_size, self.grid_size, self.num_patches = self._init_img_size( img_size ) def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: if as_scalar: return max(self.patch_size) else: return self.patch_size def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: """Get grid (feature) size for given image size taking account of dynamic padding. NOTE: must be torchscript compatible so using fixed tuple indexing """ if self.dynamic_img_pad: return math.ceil(img_size[0] / self.patch_size[0]), math.ceil( img_size[1] / self.patch_size[1] ) else: return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] def forward(self, x): B, C, H, W = x.shape if self.img_size is not None: if self.strict_img_size: _assert( H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).", ) _assert( W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).", ) elif not self.dynamic_img_pad: _assert( H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).", ) _assert( W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).", ) if self.dynamic_img_pad: pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] x = F.pad(x, (0, pad_w, 0, pad_h)) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC elif self.output_fmt != Format.NCHW: x = nchw_to(x, self.output_fmt) x = self.norm(x) return x class Mlp(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected. """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0.0, use_conv=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.norm = ( norm_layer(hidden_features) if norm_layer is not None else nn.Identity() ) self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.norm(x) x = self.fc2(x) x = self.drop2(x) return x def drop_path( x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True ): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. 
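Illustrative sketch (hypothetical tensors): with ``drop_prob=0.25`` and
    ``scale_by_keep=True``, roughly a quarter of the samples in a batch have
    their residual branch zeroed, and the survivors are rescaled by
    ``1 / 0.75`` so the expected value is preserved:

        x = torch.ones(8, 16)                       # 8 samples in the batch
        out = drop_path(x, drop_prob=0.25, training=True)
        # out has shape (8, 16); each row is either all zeros or all 1 / 0.75

    With ``training=False`` (inference) the input is returned unchanged.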
""" if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * ( x.ndim - 1 ) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0 and scale_by_keep: random_tensor.div_(keep_prob) return x * random_tensor class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep def forward(self, x): return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) def extra_repr(self): return f"drop_prob={round(self.drop_prob, 3):0.3f}" class VisionTransformerBlock(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_norm: bool = False, proj_drop: float = 0.0, attn_drop: float = 0.0, init_values: Optional[float] = None, drop_path: float = 0.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.LayerNorm, mlp_layer: nn.Module = Mlp, ) -> None: super().__init__() self.norm1 = norm_layer(dim) self.attn = VisionAttention( embed_dim=dim, num_heads=num_heads, projection_size=dim, use_qkv_parallel=True, use_context_forward=False, softmax_in_single_precision=False, dropout=attn_drop, ) self.ls1 = ( LayerScale(dim, init_values=init_values) if init_values else nn.Identity() ) self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) self.mlp = mlp_layer( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, ) self.ls2 = ( LayerScale(dim, init_values=init_values) if init_values else nn.Identity() ) self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x LayerType = Union[str, Callable, Type[torch.nn.Module]] class PatchDropout(nn.Module): """ https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220 """ return_indices: torch.jit.Final[bool] def __init__( self, prob: float = 0.5, num_prefix_tokens: int = 1, ordered: bool = False, return_indices: bool = False, ): super().__init__() assert 0 <= prob < 1.0 self.prob = prob self.num_prefix_tokens = ( num_prefix_tokens # exclude CLS token (or other prefix tokens) ) self.ordered = ordered self.return_indices = return_indices def forward( self, x ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: if not self.training or self.prob == 0.0: if self.return_indices: return x, None return x if self.num_prefix_tokens: prefix_tokens, x = ( x[:, : self.num_prefix_tokens], x[:, self.num_prefix_tokens :], ) else: prefix_tokens = None B = x.shape[0] L = x.shape[1] num_keep = max(1, int(L * (1.0 - self.prob))) keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[ :, :num_keep ] if self.ordered: # NOTE does not need to maintain patch order in typical transformer use, # but possibly useful for debug / visualization keep_indices = keep_indices.sort(dim=-1)[0] x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) if prefix_tokens is not None: x = torch.cat((prefix_tokens, x), dim=1) if self.return_indices: return x, keep_indices return x def resample_abs_pos_embed( posemb: torch.Tensor, new_size: List[int], old_size: 
Optional[List[int]] = None, num_prefix_tokens: int = 1, interpolation: str = "bicubic", antialias: bool = True, verbose: bool = False, ): # sort out sizes, assume square if old size not provided num_pos_tokens = posemb.shape[1] num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: return posemb if old_size is None: hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) old_size = hw, hw if num_prefix_tokens: posemb_prefix, posemb = ( posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:], ) else: posemb_prefix, posemb = None, posemb # do the interpolation embed_dim = posemb.shape[-1] orig_dtype = posemb.dtype posemb = posemb.float() # interpolate needs float32 posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) posemb = F.interpolate( posemb, size=new_size, mode=interpolation, antialias=antialias ) posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) posemb = posemb.to(orig_dtype) # add back extra (class, etc) prefix tokens if posemb_prefix is not None: posemb = torch.cat([posemb_prefix, posemb], dim=1) if not torch.jit.is_scripting() and verbose: logger.info(f"Resized position embedding: {old_size} to {new_size}.") return posemb def init_weights(self): if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) trunc_normal_(self.latent, std=self.latent_dim**-0.5) def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, "init_weights"): module.init_weights() class VisionTransformer(nn.Module): """Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ dynamic_img_size: Final[bool] def __init__( self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: Literal["", "avg", "token", "map"] = "token", embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_norm: bool = False, init_values: Optional[float] = None, class_token: bool = True, no_embed_class: bool = False, reg_tokens: int = 0, pre_norm: bool = False, fc_norm: Optional[bool] = None, dynamic_img_size: bool = False, dynamic_img_pad: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", embed_layer: Callable = PatchEmbed, _norm_layer: Optional[LayerType] = None, _act_layer: Optional[LayerType] = None, block_fn: Type[nn.Module] = VisionTransformerBlock, mlp_layer: Type[nn.Module] = Mlp, ignore_head: bool = False, ) -> None: """ Args: img_size: Input image size. patch_size: Patch size. in_chans: Number of image input channels. num_classes: Mumber of classes for classification head. global_pool: Type of global pooling for final sequence (default: 'token'). embed_dim: Transformer embedding dimension. depth: Depth of transformer. num_heads: Number of attention heads. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: Enable bias for qkv projections if True. 
init_values: Layer-scale init values (layer-scale enabled if not None). class_token: Use class token. no_embed_class: Don't include position embeddings for class (or reg) tokens. reg_tokens: Number of register tokens. fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. drop_rate: Head dropout rate. pos_drop_rate: Position embedding dropout rate. attn_drop_rate: Attention dropout rate. drop_path_rate: Stochastic depth rate. weight_init: Weight initialization scheme. embed_layer: Patch embedding layer. _norm_layer: Normalization layer. _act_layer: MLP activation layer. block_fn: Transformer block layer. """ super().__init__() assert global_pool in ("", "avg", "token", "map") assert class_token or global_pool != "token" use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) # act_layer = get_act_layer(act_layer) or nn.GELU norm_layer = partial(nn.LayerNorm, eps=1e-6) act_layer = nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = ( embed_dim # num_features for consistency with other models ) self.num_prefix_tokens = 1 if class_token else 0 self.num_prefix_tokens += reg_tokens self.num_reg_tokens = reg_tokens self.has_class_token = class_token self.no_embed_class = ( no_embed_class # don't embed prefix positions (includes reg) ) self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False self.ignore_head = ignore_head embed_args = {} if dynamic_img_size: # flatten deferred until after pos embed embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) dynamic_img_pad=dynamic_img_pad, **embed_args, ) num_patches = self.patch_embed.num_patches self.cls_token = ( nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None ) self.reg_token = ( nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None ) embed_len = ( num_patches if no_embed_class else num_patches + self.num_prefix_tokens ) self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: self.patch_drop = PatchDropout( patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens, ) else: self.patch_drop = nn.Identity() self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, depth) ] # stochastic depth decay rule self.blocks = nn.Sequential( *[ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm, init_values=init_values, proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, mlp_layer=mlp_layer, ) for i in range(depth) ] ) self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head if global_pool == "map": AttentionPoolLatent.init_weights = init_weights self.attn_pool = AttentionPoolLatent( self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, ) else: self.attn_pool = None self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) if weight_init != "skip": self.init_weights(weight_init) def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: assert mode in ("jax", "jax_nlhb", "moco", "") # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 trunc_normal_(self.pos_embed, std=0.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) named_apply(init_weights_vit_timm, self) @torch.jit.ignore def no_weight_decay(self) -> Set: return {"pos_embed", "cls_token", "dist_token"} @torch.jit.ignore def group_matcher(self, coarse: bool = False) -> Dict: return dict( stem=r"^cls_token|pos_embed|patch_embed", # stem and embed blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], ) @torch.jit.ignore def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool=None) -> None: self.num_classes = num_classes if global_pool is not None: assert global_pool in ("", "avg", "token", "map") if global_pool == "map" and self.attn_pool is None: assert ( False ), "Cannot currently add attention pooling in reset_classifier()." 
elif global_pool != "map " and self.attn_pool is not None: self.attn_pool = None # remove attention pooling self.global_pool = global_pool self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: if self.dynamic_img_size: B, H, W, C = x.shape pos_embed = resample_abs_pos_embed( self.pos_embed, [H, W], num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) x = x.view(B, -1, C) else: pos_embed = self.pos_embed to_cat = [] if self.cls_token is not None: to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) if self.reg_token is not None: to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) if self.no_embed_class: # deit-3, updated JAX (big vision) # position embedding does not overlap with class token, add then concat x = x + pos_embed if to_cat: x = torch.cat(to_cat + [x], dim=1) else: # original timm, JAX, and deit vit impl # pos_embed has entry for class token, concat then add if to_cat: x = torch.cat(to_cat + [x], dim=1) x = x + pos_embed return self.pos_drop(x) def _intermediate_layers( self, x: torch.Tensor, n: Union[int, Sequence] = 1, ) -> List[torch.Tensor]: outputs, num_blocks = [], len(self.blocks) take_indices = set( range(num_blocks - n, num_blocks) if isinstance(n, int) else n ) # forward pass x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) for i, blk in enumerate(self.blocks): x = blk(x) if i in take_indices: outputs.append(x) return outputs def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) x = self.blocks(x) x = self.norm(x) return x def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: if self.attn_pool is not None: x = self.attn_pool(x) elif self.global_pool == "avg": x = x[:, self.num_prefix_tokens :].mean(dim=1) elif self.global_pool: x = x[:, 0] # class token x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) if not self.ignore_head: x = self.forward_head(x) return x def model_name_to_cls(cls_name): if "MlpProjector" in cls_name: cls = MlpProjector elif "CLIPVisionTower" in cls_name: cls = CLIPVisionTower elif "VQ" in cls_name: cls = VQ_models[cls_name] elif "vision_head" in cls_name: cls = vision_head else: raise ValueError(f"class_name {cls_name} is invalid.") return cls class vision_head(torch.nn.Module): def __init__(self, params): super().__init__() self.output_mlp_projector = torch.nn.Linear( params["n_embed"], params["image_token_embed"] ) self.vision_activation = torch.nn.GELU() self.vision_head = torch.nn.Linear( params["image_token_embed"], params["image_token_size"] ) def forward(self, x): x = self.output_mlp_projector(x) x = self.vision_activation(x) x = self.vision_head(x) return x SigLIP_MODEL_CONFIG = { "siglip_so400m_patch14_384": { "image_size": 336, "patch_size": 14, "width": 1152, "layers": 27, "heads": 16, "mlp_ratio": 3.7362, "global_pool": "map", "use_checkpoint": False, }, "siglip_so400m_patch14_224": { "image_size": 224, "patch_size": 14, "width": 1152, "layers": 27, "heads": 16, "mlp_ratio": 3.7362, "global_pool": "map", "use_checkpoint": False, }, "siglip_large_patch16_384": { "image_size": 384, "patch_size": 16, "width": 1024, "layers": 24, "heads": 16, "mlp_ratio": 4, "global_pool": "map", "use_checkpoint": False, }, } def create_siglip_vit( 
model_name: str = "siglip_so400m_patch14_384", image_size: int = 384, select_layer: int = -1, ckpt_path: str = "", **kwargs, ): assert ( model_name in SigLIP_MODEL_CONFIG.keys() ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) if select_layer <= 0: layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) else: layers = min(vision_cfg.layers, select_layer) model = VisionTransformer( img_size=image_size, patch_size=vision_cfg.patch_size, embed_dim=vision_cfg.width, depth=layers, num_heads=vision_cfg.heads, mlp_ratio=vision_cfg.mlp_ratio, class_token=vision_cfg.class_token, global_pool=vision_cfg.global_pool, ignore_head=kwargs.get("ignore_head", True), weight_init=kwargs.get("weight_init", "skip"), num_classes=0, ) if ckpt_path: state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) incompatible_keys = model.load_state_dict(state_dict, strict=False) print( f"SigLIP-ViT restores from {ckpt_path},\n" f"\tincompatible_keys:', {incompatible_keys}." ) return model class Normalize(torch.nn.Module): """Normalize a tensor image with mean and standard deviation. This transform does not support PIL Image. Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` channels, this transform will normalize each channel of the input ``torch.*Tensor`` i.e., ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` .. note:: This transform acts out of place, i.e., it does not mutate the input tensor. Args: mean (sequence): Sequence of means for each channel. std (sequence): Sequence of standard deviations for each channel. inplace(bool,optional): Bool to make this operation in-place. """ def __init__(self, mean, std, inplace=False): super().__init__() # _log_api_usage_once(self) self.mean = mean self.std = std self.inplace = inplace def forward(self, tensor: Tensor) -> Tensor: """ Args: tensor (Tensor): Tensor image to be normalized. Returns: Tensor: Normalized Tensor image. 
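For example (hypothetical values): with ``mean = (0.5, 0.5, 0.5)`` and
            ``std = (0.5, 0.5, 0.5)``, a pixel value of ``0.75`` in any channel maps to
            ``(0.75 - 0.5) / 0.5 = 0.5``, so inputs in ``[0, 1]`` are rescaled to ``[-1, 1]``.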
""" return F.normalize(tensor, self.mean, self.std, self.inplace) def __repr__(self) -> str: return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})" class CLIPVisionTower(nn.Module): def __init__( self, model_name: str = "siglip_large_patch16_384", image_size: Union[Tuple[int, int], int] = 336, select_feature: str = "patch", select_layer: int = -2, select_layers: list = None, ckpt_path: str = "", pixel_mean: Optional[List[float]] = None, pixel_std: Optional[List[float]] = None, **kwargs, ): super().__init__() self.model_name = model_name self.select_feature = select_feature self.select_layer = select_layer self.select_layers = select_layers vision_tower_params = { "model_name": model_name, "image_size": image_size, "ckpt_path": ckpt_path, "select_layer": select_layer, } vision_tower_params.update(kwargs) self.vision_tower, self.forward_kwargs = self.build_vision_tower( vision_tower_params ) if pixel_mean is not None and pixel_std is not None: image_norm = Normalize(mean=pixel_mean, std=pixel_std) else: image_norm = None self.image_norm = image_norm @property def device(self) -> torch.device: return next(self.vision_tower.parameters()).device @property def dtype(self): return next(self.vision_tower.parameters()).dtype def build_vision_tower(self, vision_tower_params): if self.model_name.startswith("siglip"): self.select_feature = "same" vision_tower = create_siglip_vit(**vision_tower_params) forward_kwargs = dict() elif self.model_name.startswith("sam"): # vision_tower = create_sam_vit(**vision_tower_params) forward_kwargs = dict() else: # huggingface from transformers import CLIPVisionModel vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params) forward_kwargs = dict(output_hidden_states=True) return vision_tower, forward_kwargs def feature_select(self, image_forward_outs): if isinstance(image_forward_outs, torch.Tensor): # the output has been the self.select_layer"s features image_features = image_forward_outs else: image_features = image_forward_outs.hidden_states[self.select_layer] if self.select_feature == "patch": # if the output has cls_token image_features = image_features[:, 1:] elif self.select_feature == "cls_patch": image_features = image_features elif self.select_feature == "same": image_features = image_features else: raise ValueError(f"Unexpected select feature: {self.select_feature}") return image_features def forward(self, images): """ Args: images (torch.Tensor): [b, 3, H, W] Returns: image_features (torch.Tensor): [b, n_patch, d] """ if self.image_norm is not None: images = self.image_norm(images) image_forward_outs = self.vision_tower(images, **self.forward_kwargs) image_features = self.feature_select(image_forward_outs) return image_features class MlpProjector(nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg if cfg["projector_type"] == "identity": modules = nn.Identity() elif cfg["projector_type"] == "linear": modules = nn.Linear(cfg["input_dim"], cfg["n_embed"]) elif cfg["projector_type"] == "mlp_gelu": mlp_depth = cfg.get("depth", 1) modules = [nn.Linear(cfg["input_dim"], cfg["n_embed"])] for _ in range(1, mlp_depth): modules.append(nn.GELU()) modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"])) modules = nn.Sequential(*modules) elif cfg["projector_type"] == "low_high_hybrid_split_mlp_gelu": mlp_depth = cfg.get("depth", 1) self.high_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2) self.low_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2) modules = [] for _ in range(1, mlp_depth): 
modules.append(nn.GELU()) modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"])) modules = nn.Sequential(*modules) else: raise ValueError(f"Unknown projector type: {cfg['projector_type']}") self.layers = modules def forward( self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] ): """ Args: x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor, then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x); otherwise it is the feature from the single vision encoder. Returns: x (torch.Tensor): [b, s, c] """ if isinstance(x_or_tuple, tuple): # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu": high_x, low_x = x_or_tuple high_x = self.high_up_proj(high_x) low_x = self.low_up_proj(low_x) x = torch.cat([high_x, low_x], dim=-1) else: x = x_or_tuple return self.layers(x) class LayerScale(nn.Module): def __init__( self, dim: int, init_values: float = 1e-5, inplace: bool = False, ) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma # use torch.scaled_dot_product_attention where possible _HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention") if "TIMM_FUSED_ATTN" in os.environ: _USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"]) else: _USE_FUSED_ATTN = ( 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use) ) # Set to True if exporting a model with Same padding via ONNX _EXPORTABLE = False def use_fused_attn(experimental: bool = False) -> bool: # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0 if not _HAS_FUSED_ATTN or _EXPORTABLE: return False if experimental: return _USE_FUSED_ATTN > 1 return _USE_FUSED_ATTN > 0 class AttentionPoolLatent(nn.Module): """Attention pooling w/ latent query""" fused_attn: torch.jit.Final[bool] def __init__( self, in_features: int, out_features: int = None, embed_dim: int = None, num_heads: int = 8, feat_size: Optional[int] = None, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_norm: bool = False, latent_len: int = 1, latent_dim: int = None, pos_embed: str = "", pool_type: str = "token", norm_layer: Optional[nn.Module] = None, drop: float = 0.0, ): super().__init__() embed_dim = embed_dim or in_features out_features = out_features or in_features assert embed_dim % num_heads == 0 self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.feat_size = feat_size self.scale = self.head_dim**-0.5 self.pool = pool_type self.fused_attn = use_fused_attn() if pos_embed == "abs": assert feat_size is not None self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features)) else: self.pos_embed = None self.latent_dim = latent_dim or embed_dim self.latent_len = latent_len self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.proj = nn.Linear(embed_dim, embed_dim) self.proj_drop = nn.Dropout(drop) self.norm = ( norm_layer(out_features) if norm_layer is not None else nn.Identity() ) self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio)) self.init_weights() def init_weights(self): if self.pos_embed is not None: trunc_normal_tf_(self.pos_embed, 
std=self.pos_embed.shape[1] ** -0.5) trunc_normal_tf_(self.latent, std=self.latent_dim**-0.5) def forward(self, x): B, N, C = x.shape if self.pos_embed is not None: # FIXME interpolate x = x + self.pos_embed.unsqueeze(0).to(x.dtype) q_latent = self.latent.expand(B, -1, -1) q = ( self.q(q_latent) .reshape(B, self.latent_len, self.num_heads, self.head_dim) .transpose(1, 2) ) kv = ( self.kv(x) .reshape(B, N, 2, self.num_heads, self.head_dim) .permute(2, 0, 3, 1, 4) ) k, v = kv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) if self.fused_attn: x = F.scaled_dot_product_attention(q, k, v) else: q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 2).reshape(B, self.latent_len, C) x = self.proj(x) x = self.proj_drop(x) x = x + self.mlp(self.norm(x)) # optional pool if latent seq_len > 1 and pooled output is desired if self.pool == "token": x = x[:, 0] elif self.pool == "avg": x = x.mean(1) class Encoder(nn.Module): def __init__( self, in_channels=3, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, norm_type="group", dropout=0.0, resamp_with_conv=True, z_channels=256, ): super().__init__() self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) # downsampling in_ch_mult = (1,) + tuple(ch_mult) self.conv_blocks = nn.ModuleList() for i_level in range(self.num_resolutions): conv_block = nn.Module() # res & attn res_block = nn.ModuleList() attn_block = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks): res_block.append( ResnetBlock( block_in, block_out, dropout=dropout, norm_type=norm_type ) ) block_in = block_out if i_level == self.num_resolutions - 1: attn_block.append(AttnBlock(block_in, norm_type)) conv_block.res = res_block conv_block.attn = attn_block # downsample if i_level != self.num_resolutions - 1: conv_block.downsample = Downsample(block_in, resamp_with_conv) self.conv_blocks.append(conv_block) # middle self.mid = nn.ModuleList() self.mid.append( ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) ) self.mid.append(AttnBlock(block_in, norm_type=norm_type)) self.mid.append( ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) ) # end self.norm_out = Normalize(block_in, norm_type) self.conv_out = nn.Conv2d( block_in, z_channels, kernel_size=3, stride=1, padding=1 ) def forward(self, x): h = self.conv_in(x) # downsampling for i_level, block in enumerate(self.conv_blocks): for i_block in range(self.num_res_blocks): h = block.res[i_block](h) if len(block.attn) > 0: h = block.attn[i_block](h) if i_level != self.num_resolutions - 1: h = block.downsample(h) # middle for mid_block in self.mid: h = mid_block(h) # end h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class Decoder(nn.Module): def __init__( self, z_channels=256, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, norm_type="group", dropout=0.0, resamp_with_conv=True, out_channels=3, ): super().__init__() self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks block_in = ch * ch_mult[self.num_resolutions - 1] # z to block_in self.conv_in = nn.Conv2d( z_channels, block_in, kernel_size=3, stride=1, padding=1 ) # middle self.mid = nn.ModuleList() self.mid.append( ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) ) self.mid.append(AttnBlock(block_in, norm_type=norm_type)) self.mid.append( ResnetBlock(block_in, block_in, 
dropout=dropout, norm_type=norm_type) ) # upsampling self.conv_blocks = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): conv_block = nn.Module() # res & attn res_block = nn.ModuleList() attn_block = nn.ModuleList() block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks + 1): res_block.append( ResnetBlock( block_in, block_out, dropout=dropout, norm_type=norm_type ) ) block_in = block_out if i_level == self.num_resolutions - 1: attn_block.append(AttnBlock(block_in, norm_type)) conv_block.res = res_block conv_block.attn = attn_block # downsample if i_level != 0: conv_block.upsample = Upsample(block_in, resamp_with_conv) self.conv_blocks.append(conv_block) # end self.norm_out = Normalize(block_in, norm_type) self.conv_out = nn.Conv2d( block_in, out_channels, kernel_size=3, stride=1, padding=1 ) @property def last_layer(self): return self.conv_out.weight def forward(self, z): # z to block_in h = self.conv_in(z) # middle for mid_block in self.mid: h = mid_block(h) # upsampling for i_level, block in enumerate(self.conv_blocks): for i_block in range(self.num_res_blocks + 1): h = block.res[i_block](h) if len(block.attn) > 0: h = block.attn[i_block](h) if i_level != self.num_resolutions - 1: h = block.upsample(h) # end h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class VectorQuantizer(nn.Module): def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage): super().__init__() self.n_e = n_e self.e_dim = e_dim self.beta = beta self.entropy_loss_ratio = entropy_loss_ratio self.l2_norm = l2_norm self.show_usage = show_usage self.embedding = nn.Embedding(self.n_e, self.e_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) if self.l2_norm: self.embedding.weight.data = F.normalize( self.embedding.weight.data, p=2, dim=-1 ) if self.show_usage: # self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536))) self.codebook_used = nn.Parameter(torch.zeros(65536)) def forward(self, z): # reshape z -> (batch, height, width, channel) and flatten z = torch.einsum("b c h w -> b h w c", z).contiguous() z_flattened = z.view(-1, self.e_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z if self.l2_norm: z = F.normalize(z, p=2, dim=-1) z_flattened = F.normalize(z_flattened, p=2, dim=-1) embedding = F.normalize(self.embedding.weight, p=2, dim=-1) else: embedding = self.embedding.weight d = ( torch.sum(z_flattened**2, dim=1, keepdim=True) + torch.sum(embedding**2, dim=1) - 2 * torch.einsum( "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding) ) ) min_encoding_indices = torch.argmin(d, dim=1) z_q = embedding[min_encoding_indices].view(z.shape) perplexity = None min_encodings = None vq_loss = None commit_loss = None entropy_loss = None # compute loss for embedding if self.training: vq_loss = torch.mean((z_q - z.detach()) ** 2) commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d) # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape z_q = torch.einsum("b h w c -> b c h w", z_q) return ( z_q, (vq_loss, commit_loss, entropy_loss), (perplexity, min_encodings, min_encoding_indices), ) def get_codebook_entry(self, indices, shape=None, channel_first=True): # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel) if self.l2_norm: embedding = F.normalize(self.embedding.weight, p=2, dim=-1) else: embedding = self.embedding.weight z_q = 
embedding[indices] # (b*h*w, c) if shape is not None: if channel_first: z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) # reshape back to match original input shape z_q = z_q.permute(0, 3, 1, 2).contiguous() else: z_q = z_q.view(shape) return z_q class ResnetBlock(nn.Module): def __init__( self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type="group", ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels, norm_type) self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) self.norm2 = Normalize(out_channels, norm_type) self.dropout = nn.Dropout(dropout) self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) else: self.nin_shortcut = nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h class AttnBlock(nn.Module): def __init__(self, in_channels, norm_type="group"): super().__init__() self.norm = Normalize(in_channels, norm_type) self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x): h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h * w) q = q.permute(0, 2, 1) # b,hw,c k = k.reshape(b, c, h * w) # b,c,hw w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c) ** (-0.5)) w_ = F.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h * w) w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) return x + h_ def nonlinearity(x): # swish return x * torch.sigmoid(x) def Normalize(in_channels, norm_type="group"): assert norm_type in ["group", "batch"] if norm_type == "group": return nn.GroupNorm( num_groups=32, num_channels=in_channels, eps=1e-6, affine=True ) elif norm_type == "batch": return nn.SyncBatchNorm(in_channels) class Upsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = nn.Conv2d( in_channels, in_channels, kernel_size=3, stride=1, padding=1 ) def forward(self, x): if x.dtype != torch.float32: x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to( torch.bfloat16 ) else: x = F.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x class Downsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves self.conv = nn.Conv2d( 
in_channels, in_channels, kernel_size=3, stride=2, padding=0 ) def forward(self, x): if self.with_conv: pad = (0, 1, 0, 1) x = F.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = F.avg_pool2d(x, kernel_size=2, stride=2) return x def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): flat_affinity = affinity.reshape(-1, affinity.shape[-1]) flat_affinity /= temperature probs = F.softmax(flat_affinity, dim=-1) log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) if loss_type == "softmax": target_probs = probs else: raise ValueError("Entropy loss {} not supported".format(loss_type)) avg_probs = torch.mean(target_probs, dim=0) avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1)) loss = sample_entropy - avg_entropy return loss class VQModel(nn.Module): def __init__(self, config: ModelArgs): super().__init__() self.config = config self.encoder = Encoder( ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p, ) self.decoder = Decoder( ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p, ) self.quantize = VectorQuantizer( config.codebook_size, config.codebook_embed_dim, config.commit_loss_beta, config.entropy_loss_ratio, config.codebook_l2_norm, config.codebook_show_usage, ) self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) self.post_quant_conv = nn.Conv2d( config.codebook_embed_dim, config.z_channels, 1 ) def encode(self, x): h = self.encoder(x) h = self.quant_conv(h) quant, emb_loss, info = self.quantize(h) return quant, emb_loss, info def decode(self, quant): quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec def decode_code(self, code_b, shape=None, channel_first=True): quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) dec = self.decode(quant_b) return dec def forward(self, input): quant, diff, _ = self.encode(input) dec = self.decode(quant) return dec, diff class MultiModalityPreTrainedModel(PreTrainedModel): config_class = MultiModalityConfig base_model_prefix = "multi_modality" _no_split_modules = [] _skip_keys_device_placement = "past_key_values" # Copied and adapted from: # https://github.com/deepseek-ai/Janus/tree/main/janus/models/modeling_vlm.py class MultiModalityCausalLM(MultiModalityPreTrainedModel): def __init__( self, config: MultiModalityConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__(config) vision_config = config.vision_config vision_cls = model_name_to_cls(vision_config.cls) self.vision_model = vision_cls(**vision_config.params) aligner_config = config.aligner_config aligner_cls = model_name_to_cls(aligner_config.cls) self.aligner = aligner_cls(aligner_config.params) gen_vision_config = config.gen_vision_config gen_vision_cls = model_name_to_cls(gen_vision_config.cls) self.gen_vision_model = gen_vision_cls() gen_aligner_config = config.gen_aligner_config gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls) self.gen_aligner = gen_aligner_cls(gen_aligner_config.params) gen_head_config = config.gen_head_config gen_head_cls = model_name_to_cls(gen_head_config.cls) self.gen_head = gen_head_cls(gen_head_config.params) self.gen_embed = torch.nn.Embedding( gen_vision_config.params["image_token_size"], gen_vision_config.params["n_embed"], ) language_config = config.language_config self.language_model = LlamaForCausalLM( language_config, quant_config=quant_config ) self.logits_processor = 
LogitsProcessor(config) def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor: pixel_values = image_input.pixel_values bs, n = pixel_values.shape[0:2] pixel_values = pixel_values.to( device=self.vision_model.device, dtype=self.vision_model.dtype ) images = rearrange(pixel_values, "b n c h w -> (b n) c h w") # [b x n, T2, D] images_embeds = self.aligner(self.vision_model(images)) # [b x n, T2, D] -> [b, n x T2, D] images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) return images_embeds def get_input_embeddings(self) -> nn.Embedding: return self.language_model.model.embed_tokens @torch.no_grad() def forward( self, input_ids: torch.LongTensor, positions: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: inputs_embeds = general_mm_embed_routine( input_ids=input_ids, forward_batch=forward_batch, embed_tokens=self.get_input_embeddings(), mm_data_embedding_func=self.get_image_feature, ) return self.language_model( input_ids=None, positions=positions, forward_batch=forward_batch, input_embeds=inputs_embeds, get_embedding=False, ) def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): return self.gen_aligner(self.gen_embed(image_ids)) def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): im_start_id = image_inputs.im_start_id im_end_id = image_inputs.im_end_id media_token_pairs = [(im_start_id, im_end_id)] helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) return helper.pad_input_tokens(input_ids, image_inputs) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq~" in name or "projector" in name: continue if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue if name.startswith("model.vision_tower") and name not in params_dict: continue # skip generation sub model if "gen" in name: continue # adapt to VisionAttention name = name.replace(r"self_attn.out_proj", r"self_attn.proj") if "vision_model.vision_tower" in name: name = name.replace("attn.qkv", "attn.qkv_proj") for param_name, weight_name, shard_id in stacked_params_mapping: # replace the name and load with customized loader if weight_name not in name: continue name = name.replace(weight_name, param_name) # # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", None) weight_loader(param, loaded_weight, shard_id) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) AutoModel.register(config_class=MultiModalityConfig, model_class=MultiModalityCausalLM) EntryClass = [MultiModalityCausalLM]
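
# Illustrative sketch of the tensor reshaping done in
# ``MultiModalityCausalLM.get_image_feature`` above. The sizes below are
# assumptions chosen for readability, not Janus-Pro's actual dimensions, and
# the snippet is kept in comments so nothing executes at import time:
#
#     import torch
#     from einops import rearrange
#
#     bs, n = 2, 3                                   # batch size, images per sample
#     t, d = 576, 2048                               # hypothetical tokens per image, hidden size
#     pixel_values = torch.rand(bs, n, 3, 384, 384)
#
#     # fold the per-sample image axis into the batch before the vision tower
#     images = rearrange(pixel_values, "b n c h w -> (b n) c h w")     # [6, 3, 384, 384]
#
#     # stand-in for ``self.aligner(self.vision_model(images))``
#     images_embeds = torch.rand(bs * n, t, d)                          # [6, 576, 2048]
#
#     # unfold back so each sample carries the concatenated tokens of its n images
#     images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)  # [2, 1728, 2048]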