# Adapted from:
# https://github.com/vllm-project/vllm/blob/7193774b1ff8603ad5bf4598e5efba0d9a39b436/vllm/model_executor/models/mllama.py
"""PyTorch Mllama model."""
import math
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
    _prepare_aspect_ratio_attention_mask,
)

import sglang.srt.distributed.parallel_state as ps
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
from sglang.srt.utils import add_prefix


class ColumnParallelConv2dPatch(torch.nn.Module):
    """Conv2D Patching layer with model parallelism.

    Column parallel over unfolded input.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel.
        stride (default 1): Stride for convolution.
        bias (default False): Use bias in Conv2d.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
        self._linear = ColumnParallelLinear(
            in_channels * kernel_size[0] * kernel_size[1],
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._unfold(x)
        x = x.permute(0, 2, 1)
        x, _ = self._linear(x)
        return x


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    def __init__(self, config: config_mllama.MllamaVisionConfig, is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        self.embedding = nn.Embedding(
            self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size
        )
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        embeddings = self.embedding(aspect_ratio_ids)
        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)

        if self.is_gated:
            embeddings = embeddings * self.gate.tanh()

        hidden_state = hidden_state + embeddings
        return hidden_state


class MllamaPrecomputedPositionEmbedding(nn.Module):
    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # position embedding
        position_embedding = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * position_embedding)

        # tile position embedding
        self.tile_embedding = nn.Embedding(
            self.max_aspect_ratio_id + 1,
            self.max_num_tiles * self.num_patches * self.hidden_size,
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        # position embeddings
        gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
        hidden_state = hidden_state + gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size
        )

        # precomputed tile position embeddings
        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
        batch_size = hidden_state.shape[0]
        tile_position_embedding = tile_position_embedding.reshape(
            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
        )
        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
        hidden_state = hidden_state + gated_tile_position_embedding

        return hidden_state


class MllamaVisionMLP(nn.Module):
    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class MllamaVisionEncoderLayer(nn.Module):
    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        is_gated: bool = False,
        prefix: str = "",
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = VisionAttention(
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size,
            use_qkv_parallel=True,
            quant_config=None,
            dropout=0.0,
            use_context_forward=False,
            softmax_in_single_precision=False,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))

        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(
            self.hidden_size, eps=config.norm_eps
        )

        # there used to be an if else here, no code path
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state


class MllamaVisionEncoder(nn.Module):
    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        num_layers=32,
        is_gated=False,
        output_hidden_states=None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                MllamaVisionEncoderLayer(
                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
                )
                for i in range(num_layers)
            ]
        )
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        encoder_states = ()

        for i, encoder_layer in enumerate(self.layers):
            if i in self.output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )

        if len(self.layers) - 1 in self.output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return hidden_states, encoder_states


class MllamaVisionModel(nn.Module):
    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            config, is_gated=True
        )
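        # Tile-position information is injected twice in forward(): once before
        # the local per-tile encoder (pre_tile_positional_embedding) and again
        # before the global cross-tile encoder (post_tile_positional_embedding).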
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            config, is_gated=True
        )

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        self.transformer = MllamaVisionEncoder(
            config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices,
            prefix=add_prefix("transformer", prefix),
        )
        self.global_transformer = MllamaVisionEncoder(
            config,
            config.num_global_layers,
            is_gated=True,
            prefix=add_prefix("global_transformer", prefix),
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        aspect_ratio_mask: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
            pixel_values.shape
        )

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.layernorm_pre.weight.dtype)
        )
        hidden_state = patch_embeds
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0,
            0,
            0,
            num_padding_patches,
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        attention_mask = aspect_ratio_mask.reshape(
            batch_size * num_concurrent_media, -1
        )
        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=attention_mask,
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=self.layernorm_pre.weight.dtype,
        )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        output = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        hidden_state, intermediate_hidden_states = output[0], output[1]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
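        # Tiles are flattened into one long sequence here so the gated global
        # encoder can attend across all tiles of the same image at once.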
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )[0]
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state


class MllamaTextRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class MllamaTextCrossAttention(nn.Module):
    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_id: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.model_parallel_size = get_tensor_model_parallel_world_size()
        self.num_heads = self.config.num_attention_heads
        self.num_local_heads = self.num_heads // self.model_parallel_size
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_local_key_value_heads = (
            self.num_key_value_heads // self.model_parallel_size
        )
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.layer_id = layer_id
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = RadixAttention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
            layer_id=layer_id,
            is_cross_attention=True,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cross_attention_states: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv_dec, _ = self.qkv_proj(hidden_states)
        q, _, _ = qkv_dec.split(
            [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
        )
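        # Queries come from the decoder hidden states. Keys/values are only
        # recomputed when fresh vision features are passed in; otherwise they
        # stay None and the cross-attention KV previously written to the cache
        # is expected to be reused.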
        if cross_attention_states is None:
            k = None
            v = None
        else:
            qkv_enc, _ = self.qkv_proj(cross_attention_states)
            _, k, v = qkv_enc.split(
                [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
            )
            k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
            k = self.k_norm(k)

        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        output = self.attn(q, k, v, forward_batch)
        out, _ = self.o_proj(output)
        return out


class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention and feedforward."""

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("cross_attn", prefix),
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            forward_batch=forward_batch,
        )
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

        return hidden_states


class MllamaTextModel(nn.Module):
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ):
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size + 8,
            config.hidden_size,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_id in range(config.num_hidden_layers):
            if layer_id in self.cross_attention_layers:
                layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        config,
                        layer_id,
                        quant_config=quant_config,
                        prefix=add_prefix(f"layers.{layer_id}", prefix),
                    )
                )
            else:
                # TODO: force LlamaDecoderLayer to config.attention_bias=False
                layers.append(
                    LlamaDecoderLayer(
                        config,
                        quant_config=quant_config,
                        layer_id=layer_id,
                        prefix=add_prefix(f"layers.{layer_id}", prefix),
                    )
                )

        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

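    # Self-attention layers run on every pass; the interleaved cross-attention
    # layers (at the indices in config.cross_attention_layers) are bypassed
    # entirely when skip_cross_attention is set for the batch.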
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for _, decoder_layer in enumerate(self.layers):
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if not skip_cross_attention:
                    hidden_states = decoder_layer(
                        hidden_states=hidden_states,
                        cross_attention_states=cross_attention_states,
                        cross_attention_mask=cross_attention_mask,
                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                        forward_batch=forward_batch,
                    )
            elif isinstance(decoder_layer, LlamaDecoderLayer):
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    forward_batch=forward_batch,
                    residual=None,
                )
                hidden_states = hidden_states + residual
            else:
                raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}")

        hidden_states = self.norm(hidden_states)
        return hidden_states


class MllamaForCausalLM(nn.Module):
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer",
        "MllamaSelfAttentionDecoderLayer",
    ]

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
        return hidden_states


class MllamaForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: config_mllama.MllamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        self.image_size = config.vision_config.image_size

        self.vision_model = MllamaVisionModel(
            config.vision_config, prefix=add_prefix("vision_model", prefix)
        )
        self.language_model = MllamaForCausalLM(
            config.text_config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.text_config)
        self.capture_mode = False

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
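        # Prepend one placeholder token per vision patch across all images and
        # tiles (image_len in total); these positions mark the image (encoder)
        # portion of the request.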
        pixel_values = image_inputs.pixel_values
        pad_values = image_inputs.pad_values

        num_concurrent_media, num_tiles = pixel_values.shape[1:3]
        num_patches = self.vision_model.num_patches
        image_len = num_concurrent_media * num_tiles * num_patches
        image_inputs.num_image_tokens = image_len

        pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

        return pad_ids[:image_len] + input_ids

    def _batch_image_inputs(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode() or all(forward_batch.encoder_cached):
            return None, None, None, None

        # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
        max_num_images = max_num_tiles = bs = 0
        for i, im in enumerate(forward_batch.mm_inputs):
            if not forward_batch.encoder_cached[i] and im is not None:
                max_num_images = max(max_num_images, im.pixel_values.shape[1])
                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
                bs += 1

        if max_num_images * max_num_tiles * bs == 0:
            return None, None, None, None

        with forward_batch.out_cache_loc.device:
            batched_images = torch.zeros(
                bs,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
            )
            batched_ar_ids = torch.ones(
                bs, max_num_images, dtype=torch.int64, device="cuda"
            )
            batched_ar_mask = torch.zeros(
                bs, max_num_images, max_num_tiles, dtype=torch.int64
            )

            i = 0
            encoder_lens_need = []
            for k, im in enumerate(forward_batch.mm_inputs):
                if forward_batch.encoder_cached[k] or im is None:
                    continue

                encoder_lens_need.append(forward_batch.encoder_lens[k])
                for j in range(im.pixel_values.shape[1]):
                    img = im.pixel_values[0, j]
                    num_tiles = img.shape[0]
                    batched_images[i, j, :num_tiles] = img
                    batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
                    batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
                i += 1

        return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need

    def flat_encoder_result(
        self, cross_attention_states: torch.Tensor, encoder_lens_need: List[int]
    ):
        # NOTE: not all encoders need computation, some are cached
        head_dim = cross_attention_states.shape[-1]
        total_encoder_len = sum(encoder_lens_need)

        cross_attention_states_flat = torch.zeros(
            total_encoder_len,
            head_dim,
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype,
        )
        i = start_pos = 0
        for encoder_len in encoder_lens_need:
            if encoder_len == 0:
                continue
            end_pos = start_pos + encoder_len
            cross_attention_states_flat[start_pos:end_pos] = cross_attention_states[i][
                :encoder_len
            ]
            i += 1
            start_pos += encoder_len

        return cross_attention_states_flat

    def get_full_text_row_masked_out_mask(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode():
            full_text_row_masked_out_mask = forward_batch.encoder_lens != 0
        else:
            full_text_row_masked_out_mask = torch.ones(
                forward_batch.extend_seq_lens.sum(), dtype=torch.bool
            )
            start_pos = 0

            for seq_len, encoder_len in zip(
                forward_batch.seq_lens.tolist(), forward_batch.encoder_lens_cpu
            ):
                if encoder_len == 0:
                    full_text_row_masked_out_mask[start_pos : start_pos + seq_len] = (
                        False
                    )
                start_pos += encoder_len

        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            forward_batch.seq_lens.device
        )

        return full_text_row_masked_out_mask.reshape(-1, 1)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
            self._batch_image_inputs(forward_batch)
        )

        # TODO: support multi-image by this mask
        cross_attention_mask = None
        cross_attention_states = None
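        # skip_cross_attention is a per-batch decision: if no request in the
        # batch carries encoder (image) tokens, every cross-attention layer is
        # bypassed.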
        if self.capture_mode:
            # NOTE: when doing cuda graph capture, we do not want to skip cross attention.
            # Make it a constant value to avoid cuda graph capture issues.
            skip_cross_attention = False
        else:
            # NOTE: we do not need image_inputs during prefill
            assert len(forward_batch.encoder_lens) == len(forward_batch.seq_lens)
            assert len(forward_batch.encoder_lens_cpu) == len(forward_batch.seq_lens)
            skip_cross_attention = forward_batch.encoder_lens.max() == 0

        if not skip_cross_attention:
            full_text_row_masked_out_mask = self.get_full_text_row_masked_out_mask(
                forward_batch
            )
        else:
            full_text_row_masked_out_mask = None

        if batched_images is not None:
            # NOTE: llama's reference implementation runs vision model on CPU
            cross_attention_states = self.vision_model(
                batched_images, batched_ar_ids, batched_ar_mask
            )
            cross_attention_states = self.multi_modal_projector(cross_attention_states)

            bs, _, _, _, image_token_dim = cross_attention_states.shape
            cross_attention_states = cross_attention_states.view(
                bs, -1, image_token_dim
            )

            cross_attention_states = self.flat_encoder_result(
                cross_attention_states, encoder_lens_need
            )

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
        return self.logits_processor(
            input_ids, hidden_states, self.language_model.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params = set()
        for name, loaded_weight in weights:
            if "patch_embedding.weight" in name:
                name = name.replace(
                    "patch_embedding.weight", "patch_embedding._linear.weight"
                )
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "vision_model" in name:
                    # adapt to VisionAttention
                    name = name.replace("self_attn.o_proj", "self_attn.proj")
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = MllamaForConditionalGeneration