import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, NamedTuple, Optional

import torch
import torch.nn as nn

from ..ops.misc import Conv2dNormActivation, MLP
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

__all__ = [
    "VisionTransformer",
    "ViT_B_16_Weights",
    "ViT_B_32_Weights",
    "ViT_L_16_Weights",
    "ViT_L_32_Weights",
    "ViT_H_14_Weights",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
    "vit_h_14",
]


class ConvStemConfig(NamedTuple):
    out_channels: int
    kernel_size: int
    stride: int
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU


class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiheadAttention() by default
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))
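
# Illustrative sketch (not part of the torchvision API): the Encoder consumes and produces
# tensors of shape (batch_size, seq_length, hidden_dim). The helper name and every
# hyperparameter below are assumptions chosen only to demonstrate the expected shapes.
def _encoder_shape_example() -> None:
    encoder = Encoder(
        seq_length=197,  # (224 // 16) ** 2 patches + 1 class token
        num_layers=2,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        dropout=0.0,
        attention_dropout=0.0,
    )
    tokens = torch.rand(4, 197, 768)  # (batch_size, seq_length, hidden_dim)
    out = encoder(tokens)
    assert out.shape == tokens.shape  # the Encoder preserves the token shape
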
class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[list[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x
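
# Illustrative sketch (not part of the torchvision API): builds a deliberately tiny
# VisionTransformer and runs a forward pass. The helper name and every hyperparameter
# below are assumptions chosen for brevity, not a recommended configuration.
def _vision_transformer_forward_example() -> None:
    model = VisionTransformer(
        image_size=32,
        patch_size=8,
        num_layers=2,
        num_heads=2,
        hidden_dim=16,
        mlp_dim=32,
        num_classes=10,
    )
    images = torch.rand(4, 3, 32, 32)  # (batch_size, channels, height, width)
    logits = model(images)
    assert logits.shape == (4, 10)  # one logit per class for each image
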
def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0])
    image_size = kwargs.pop("image_size", 224)

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META: dict[str, Any] = {
    "categories": _IMAGENET_CATEGORIES,
}

_COMMON_SWAG_META = {
    **_COMMON_META,
    "recipe": "https://github.com/facebookresearch/SWAG",
    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
}


class ViT_B_16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 86567656,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.072,
                    "acc@5": 95.318,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights were trained from scratch by using a modified version of
                `DeIT `_'s training recipe.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 86859496,
            "min_size": (384, 384),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.304,
                    "acc@5": 97.650,
                }
            },
            "_ops": 55.484,
            "_file_size": 331.398,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG `_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 86567656,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 96.180,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights are composed of the original frozen `SWAG `_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
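
# Illustrative sketch (not part of the torchvision API): shows the usual way the weight enums
# above are consumed together with the ``vit_b_16`` builder defined later in this file. The
# helper name and the random stand-in image are assumptions, and calling it downloads the
# checkpoint.
def _vit_b_16_pretrained_example() -> str:
    weights = ViT_B_16_Weights.IMAGENET1K_V1
    preprocess = weights.transforms()  # the ImageClassification preset configured above
    model = vit_b_16(weights=weights)
    model.eval()
    batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)  # random stand-in for a real image
    with torch.no_grad():
        probabilities = model(batch).softmax(dim=-1)
    # Map the top prediction back to a human-readable ImageNet category name.
    return weights.meta["categories"][int(probabilities.argmax())]
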
""", }, ) IMAGENET1K_SWAG_LINEAR_V1 = Weights( url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC, ), meta={ **_COMMON_SWAG_META, "recipe": "https://github.com/pytorch/vision/pull/5793", "num_params": 86567656, "min_size": (224, 224), "_metrics": { "ImageNet-1K": { "acc@1": 81.886, "acc@5": 96.180, } }, "_ops": 17.564, "_file_size": 330.285, "_docs": """ These weights are composed of the original frozen `SWAG `_ trunk weights and a linear classifier learnt on top of them trained on ImageNet-1K data. """, }, ) DEFAULT = IMAGENET1K_V1 class ViT_B_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 88224232, "min_size": (224, 224), "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32", "_metrics": { "ImageNet-1K": { "acc@1": 75.912, "acc@5": 92.466, } }, "_ops": 4.409, "_file_size": 336.604, "_docs": """ These weights were trained from scratch by using a modified version of `DeIT `_'s training recipe. """, }, ) DEFAULT = IMAGENET1K_V1 class ViT_L_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", transforms=partial(ImageClassification, crop_size=224, resize_size=242), meta={ **_COMMON_META, "num_params": 304326632, "min_size": (224, 224), "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16", "_metrics": { "ImageNet-1K": { "acc@1": 79.662, "acc@5": 94.638, } }, "_ops": 61.555, "_file_size": 1161.023, "_docs": """ These weights were trained from scratch by using a modified version of TorchVision's `new training recipe `_. """, }, ) IMAGENET1K_SWAG_E2E_V1 = Weights( url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth", transforms=partial( ImageClassification, crop_size=512, resize_size=512, interpolation=InterpolationMode.BICUBIC, ), meta={ **_COMMON_SWAG_META, "num_params": 305174504, "min_size": (512, 512), "_metrics": { "ImageNet-1K": { "acc@1": 88.064, "acc@5": 98.512, } }, "_ops": 361.986, "_file_size": 1164.258, "_docs": """ These weights are learnt via transfer learning by end-to-end fine-tuning the original `SWAG `_ weights on ImageNet-1K data. """, }, ) IMAGENET1K_SWAG_LINEAR_V1 = Weights( url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC, ), meta={ **_COMMON_SWAG_META, "recipe": "https://github.com/pytorch/vision/pull/5793", "num_params": 304326632, "min_size": (224, 224), "_metrics": { "ImageNet-1K": { "acc@1": 85.146, "acc@5": 97.422, } }, "_ops": 61.555, "_file_size": 1161.023, "_docs": """ These weights are composed of the original frozen `SWAG `_ trunk weights and a linear classifier learnt on top of them trained on ImageNet-1K data. 
""", }, ) DEFAULT = IMAGENET1K_V1 class ViT_L_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 306535400, "min_size": (224, 224), "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32", "_metrics": { "ImageNet-1K": { "acc@1": 76.972, "acc@5": 93.07, } }, "_ops": 15.378, "_file_size": 1169.449, "_docs": """ These weights were trained from scratch by using a modified version of `DeIT `_'s training recipe. """, }, ) DEFAULT = IMAGENET1K_V1 class ViT_H_14_Weights(WeightsEnum): IMAGENET1K_SWAG_E2E_V1 = Weights( url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth", transforms=partial( ImageClassification, crop_size=518, resize_size=518, interpolation=InterpolationMode.BICUBIC, ), meta={ **_COMMON_SWAG_META, "num_params": 633470440, "min_size": (518, 518), "_metrics": { "ImageNet-1K": { "acc@1": 88.552, "acc@5": 98.694, } }, "_ops": 1016.717, "_file_size": 2416.643, "_docs": """ These weights are learnt via transfer learning by end-to-end fine-tuning the original `SWAG `_ weights on ImageNet-1K data. """, }, ) IMAGENET1K_SWAG_LINEAR_V1 = Weights( url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC, ), meta={ **_COMMON_SWAG_META, "recipe": "https://github.com/pytorch/vision/pull/5793", "num_params": 632045800, "min_size": (224, 224), "_metrics": { "ImageNet-1K": { "acc@1": 85.708, "acc@5": 97.730, } }, "_ops": 167.295, "_file_size": 2411.209, "_docs": """ These weights are composed of the original frozen `SWAG `_ trunk weights and a linear classifier learnt on top of them trained on ImageNet-1K data. """, }, ) DEFAULT = IMAGENET1K_SWAG_E2E_V1 @register_model() @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_16 architecture from `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. Args: weights (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.ViT_B_16_Weights` below for more details and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` base class. Please refer to the `source code `_ for more details about this class. .. autoclass:: torchvision.models.ViT_B_16_Weights :members: """ weights = ViT_B_16_Weights.verify(weights) return _vision_transformer( patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, weights=weights, progress=progress, **kwargs, ) @register_model() @handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_32 architecture from `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. Args: weights (:class:`~torchvision.models.ViT_B_32_Weights`, optional): The pretrained weights to use. 
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_32 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_32_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_B_32_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code `_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_B_32_Weights
        :members:
    """
    weights = ViT_B_32_Weights.verify(weights)

    return _vision_transformer(
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        weights=weights,
        progress=progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_16 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_16_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_L_16_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code `_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_L_16_Weights
        :members:
    """
    weights = ViT_L_16_Weights.verify(weights)

    return _vision_transformer(
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        weights=weights,
        progress=progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_32 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_32_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_L_32_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code `_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_L_32_Weights
        :members:
    """
    weights = ViT_L_32_Weights.verify(weights)

    return _vision_transformer(
        patch_size=32,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        weights=weights,
        progress=progress,
        **kwargs,
    )
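
# Note: unlike the builders above, vit_h_14 below maps the legacy ``pretrained`` flag to ``None``
# because ViT_H_14_Weights has no plain IMAGENET1K_V1 entry; its pretrained checkpoints are the
# SWAG weights only.
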
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_h_14 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_.

    Args:
        weights (:class:`~torchvision.models.ViT_H_14_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_H_14_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code `_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_H_14_Weights
        :members:
    """
    weights = ViT_H_14_Weights.verify(weights)

    return _vision_transformer(
        patch_size=14,
        num_layers=32,
        num_heads=16,
        hidden_dim=1280,
        mlp_dim=5120,
        weights=weights,
        progress=progress,
        **kwargs,
    )
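
# Illustrative sketch (not part of the torchvision API): interpolate_embeddings() below can adapt
# a checkpoint trained at one resolution to a model built for another. The helper name and the
# 384x384 target are assumptions, and calling it downloads the ViT-B/16 checkpoint.
def _interpolate_embeddings_example() -> VisionTransformer:
    state = ViT_B_16_Weights.IMAGENET1K_V1.get_state_dict(progress=True, check_hash=True)
    # Resample the positional embeddings from the 224x224 grid to the 384x384 grid.
    new_state = interpolate_embeddings(image_size=384, patch_size=16, model_state=state)
    model = VisionTransformer(
        image_size=384,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
    )
    model.load_state_dict(new_state)
    return model
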
def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, the state of the classifier heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the positions embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        if seq_length_1d * seq_length_1d != seq_length:
            raise ValueError(
                f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d } and seq_length = {seq_length}"
            )

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state