# Adapted from:
# https://github.com/vllm-project/vllm/blob/7193774b1ff8603ad5bf4598e5efba0d9a39b436/vllm/model_executor/models/mllama.py
"""PyTorch Mllama model."""
import math
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
    _prepare_aspect_ratio_attention_mask,
)

import sglang.srt.distributed.parallel_state as ps
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
from sglang.srt.utils import add_prefix

class ColumnParallelConv2dPatch(torch.nn.Module):
    """Conv2D Patching layer with model parallelism.
    Column parallel over unfolded input.
    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel.
        stride (default 1): Stride for convolution.
        bias (default False): Use bias in Conv2d.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
        self._linear = ColumnParallelLinear(
            in_channels * kernel_size[0] * kernel_size[1],
            out_channels,
            bias=bias,
        )
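    # NOTE: Unfold with stride == kernel_size extracts non-overlapping patches, so
    # the linear projection below is equivalent to a strided Conv2d whose output
    # channels are sharded (column-parallel) across tensor-parallel ranks.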
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._unfold(x)
        x = x.permute(0, 2, 1)
        x, _ = self._linear(x)
        return x


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):

    def __init__(self, config: config_mllama.MllamaVisionConfig, is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        self.embedding = nn.Embedding(
            self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size
        )
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        embeddings = self.embedding(aspect_ratio_ids)
        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)

        if self.is_gated:
            embeddings = embeddings * self.gate.tanh()

        hidden_state = hidden_state + embeddings
        return hidden_state
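# Precomputed (gated) positional embeddings: a single per-patch embedding is blended
# with an aspect-ratio/tile-specific embedding through a learned tanh gate, i.e.
# (1 - tanh(gate)) * position_embedding + tanh(gate) * tile_position_embedding.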
class MllamaPrecomputedPositionEmbedding(nn.Module):
    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # position embedding
        position_embedding = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * position_embedding)

        # tile position embedding
        self.tile_embedding = nn.Embedding(
            self.max_aspect_ratio_id + 1,
            self.max_num_tiles * self.num_patches * self.hidden_size,
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        # position embeddings
        gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
        hidden_state = hidden_state + gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size
        )

        # precomputed tile position embeddings
        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
        batch_size = hidden_state.shape[0]
        tile_position_embedding = tile_position_embedding.reshape(
            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
        )
        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
        hidden_state = hidden_state + gated_tile_position_embedding

        return hidden_state


class MllamaVisionMLP(nn.Module):
    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states
class MllamaVisionEncoderLayer(nn.Module):
    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        is_gated: bool = False,
        prefix: str = "",
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = VisionAttention(
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size,
            use_qkv_parallel=True,
            quant_config=None,
            dropout=0.0,
            use_context_forward=False,
            softmax_in_single_precision=False,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))

        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(
            self.hidden_size, eps=config.norm_eps
        )

        # Gate parameters only exist in the gated (global) encoder layers.
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state
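# A stack of vision encoder layers. `output_hidden_states` selects which layers'
# hidden states are collected as intermediate outputs; MllamaVisionModel later
# concatenates them onto the final hidden state.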
class MllamaVisionEncoder(nn.Module):
    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        num_layers=32,
        is_gated=False,
        output_hidden_states=None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                MllamaVisionEncoderLayer(
                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
                )
                for i in range(num_layers)
            ]
        )
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        encoder_states = ()

        for i, encoder_layer in enumerate(self.layers):
            if i in self.output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )

        if len(self.layers) - 1 in self.output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return hidden_states, encoder_states
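# Vision tower: a local transformer (with intermediate-layer taps) followed by a
# gated global transformer that mixes information across tiles. The tapped
# intermediate states are concatenated along the feature dimension of the output,
# matching config.vision_output_dim.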
class MllamaVisionModel(nn.Module):
    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            config, is_gated=True
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            config, is_gated=True
        )

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        self.transformer = MllamaVisionEncoder(
            config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices,
            prefix=add_prefix("transformer", prefix),
        )
        self.global_transformer = MllamaVisionEncoder(
            config,
            config.num_global_layers,
            is_gated=True,
            prefix=add_prefix("global_transformer", prefix),
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state
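    # forward: patch-embed each tile, add pre-tile embeddings, prepend the CLS token,
    # add gated positional embeddings, run the local and global encoders, then
    # concatenate the tapped intermediate states onto the final hidden state.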
    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        aspect_ratio_mask: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
            pixel_values.shape
        )

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.layernorm_pre.weight.dtype)
        )
        hidden_state = patch_embeds
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)
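        # Pad the patch dimension up to a multiple of 8 before running attention
        # (presumably for kernel/layout alignment); the padded rows are sliced off
        # again after the encoders via `slice_index`.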
        # Compute the number of tokens to pad
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0,
            0,
            0,
            num_padding_patches,
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        attention_mask = aspect_ratio_mask.reshape(
            batch_size * num_concurrent_media, -1
        )
        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=attention_mask,
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=self.layernorm_pre.weight.dtype,
        )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        output = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        hidden_state, intermediate_hidden_states = output[0], output[1]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )[0]
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state


class MllamaTextRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
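# Text-side cross-attention: queries come from the text hidden states, while keys
# and values are projected from the vision encoder output (cross_attention_states).
# Q and K are RMS-normalized per head, and RadixAttention manages the KV cache with
# is_cross_attention=True.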
class MllamaTextCrossAttention(nn.Module):
    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_id: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.model_parallel_size = get_tensor_model_parallel_world_size()
        self.num_heads = self.config.num_attention_heads
        self.num_local_heads = self.num_heads // self.model_parallel_size
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_local_key_value_heads = (
            self.num_key_value_heads // self.model_parallel_size
        )
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.layer_id = layer_id
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has a precision issue,
        # so use huggingface's implementation instead.
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = RadixAttention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
            layer_id=layer_id,
            is_cross_attention=True,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cross_attention_states: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv_dec, _ = self.qkv_proj(hidden_states)
        q, _, _ = qkv_dec.split(
            [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
        )
        if cross_attention_states is None:
            k = None
            v = None
        else:
            qkv_enc, _ = self.qkv_proj(cross_attention_states)
            _, k, v = qkv_enc.split(
                [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
            )
            k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
            k = self.k_norm(k)
        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        output = self.attn(q, k, v, forward_batch)
        out, _ = self.o_proj(output)
        return out
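# Decoder block used on the cross-attention layers: attention and MLP outputs are
# scaled by learned tanh gates, and `full_text_row_masked_out_mask` zeroes the rows
# of sequences that carry no image tokens so cross-attention leaves them unchanged.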
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention
    and feedforward."""

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("cross_attn", prefix),
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            forward_batch=forward_batch,
        )
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
        return hidden_states
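# Text backbone: standard LlamaDecoderLayer blocks interleaved with
# MllamaCrossAttentionDecoderLayer at the indices in config.cross_attention_layers.
# The embedding table is sized vocab_size + 8, presumably so that special token ids
# beyond config.vocab_size (e.g. the image token) still map to valid rows.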
class MllamaTextModel(nn.Module):
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ):
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size + 8,
            config.hidden_size,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_id in range(config.num_hidden_layers):
            if layer_id in self.cross_attention_layers:
                layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        config,
                        layer_id,
                        quant_config=quant_config,
                        prefix=add_prefix(f"layers.{layer_id}", prefix),
                    )
                )
            else:
                # TODO: force LlamaDecoderLayer to config.attention_bias=False
                layers.append(
                    LlamaDecoderLayer(
                        config,
                        quant_config=quant_config,
                        layer_id=layer_id,
                        prefix=add_prefix(f"layers.{layer_id}", prefix),
                    )
                )

        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for _, decoder_layer in enumerate(self.layers):
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if not skip_cross_attention:
                    hidden_states = decoder_layer(
                        hidden_states=hidden_states,
                        cross_attention_states=cross_attention_states,
                        cross_attention_mask=cross_attention_mask,
                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                        forward_batch=forward_batch,
                    )
            elif isinstance(decoder_layer, LlamaDecoderLayer):
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    forward_batch=forward_batch,
                    residual=None,
                )
                hidden_states = hidden_states + residual
            else:
                raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}")
        hidden_states = self.norm(hidden_states)
        return hidden_states
class MllamaForCausalLM(nn.Module):
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer",
        "MllamaSelfAttentionDecoderLayer",
    ]

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
        return hidden_states
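# Top-level multimodal wrapper: pad_input_ids reserves one placeholder token per
# vision patch, the vision model plus multi_modal_projector produce
# cross_attention_states in the text hidden size, and the language model
# cross-attends to them during prefill (cached encoder states are reused at decode).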
class MllamaForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: config_mllama.MllamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        self.image_size = config.vision_config.image_size

        self.vision_model = MllamaVisionModel(
            config.vision_config, prefix=add_prefix("vision_model", prefix)
        )
        self.language_model = MllamaForCausalLM(
            config.text_config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.text_config)
        self.capture_mode = False

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        pixel_values = image_inputs.pixel_values
        pad_values = image_inputs.pad_values

        num_concurrent_media, num_tiles = pixel_values.shape[1:3]
        num_patches = self.vision_model.num_patches
        image_len = num_concurrent_media * num_tiles * num_patches
        image_inputs.num_image_tokens = image_len

        pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

        return pad_ids[:image_len] + input_ids
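    # Batch the uncached image inputs of this forward pass into dense tensors,
    # padding every request up to the maximum image and tile count in the batch.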
    def _batch_image_inputs(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode() or all(forward_batch.encoder_cached):
            return None, None, None, None

        # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
        max_num_images = max_num_tiles = bs = 0
        for i, im in enumerate(forward_batch.mm_inputs):
            if not forward_batch.encoder_cached[i] and im is not None:
                max_num_images = max(max_num_images, im.pixel_values.shape[1])
                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
                bs += 1

        if max_num_images * max_num_tiles * bs == 0:
            return None, None, None, None

        with forward_batch.out_cache_loc.device:
            batched_images = torch.zeros(
                bs,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
            )
            batched_ar_ids = torch.ones(
                bs, max_num_images, dtype=torch.int64, device="cuda"
            )
            batched_ar_mask = torch.zeros(
                bs, max_num_images, max_num_tiles, dtype=torch.int64
            )
            i = 0
            encoder_lens_need = []
            for k, im in enumerate(forward_batch.mm_inputs):
                if forward_batch.encoder_cached[k] or im is None:
                    continue

                encoder_lens_need.append(forward_batch.encoder_lens[k])
                for j in range(im.pixel_values.shape[1]):
                    img = im.pixel_values[0, j]
                    num_tiles = img.shape[0]
                    batched_images[i, j, :num_tiles] = img
                    batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
                    batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
                i += 1

        return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
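    # Flatten the padded per-request vision outputs into one packed
    # [sum(encoder_lens_need), hidden] tensor, keeping only each request's real
    # encoder tokens.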
    def flat_encoder_result(
        self, cross_attention_states: torch.Tensor, encoder_lens_need: List[int]
    ):
        # NOTE: not all encoders need computation, some are cached
        head_dim = cross_attention_states.shape[-1]
        total_encoder_len = sum(encoder_lens_need)
        cross_attention_states_flat = torch.zeros(
            total_encoder_len,
            head_dim,
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype,
        )

        i = start_pos = 0
        for encoder_len in encoder_lens_need:
            if encoder_len == 0:
                continue
            end_pos = start_pos + encoder_len
            cross_attention_states_flat[start_pos:end_pos] = cross_attention_states[i][
                :encoder_len
            ]
            i += 1
            start_pos += encoder_len

        return cross_attention_states_flat
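    # Build a [num_tokens, 1] mask that is False for text rows belonging to
    # sequences without any image (encoder_len == 0), so those rows are left
    # untouched by the cross-attention blocks.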
    def get_full_text_row_masked_out_mask(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode():
            full_text_row_masked_out_mask = forward_batch.encoder_lens != 0
        else:
            full_text_row_masked_out_mask = torch.ones(
                forward_batch.extend_seq_lens.sum(), dtype=torch.bool
            )
            start_pos = 0

            for seq_len, encoder_len in zip(
                forward_batch.seq_lens.tolist(), forward_batch.encoder_lens_cpu
            ):
                if encoder_len == 0:
                    full_text_row_masked_out_mask[start_pos : start_pos + seq_len] = (
                        False
                    )
                # advance by the number of text rows of this sequence
                start_pos += seq_len

        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            forward_batch.seq_lens.device
        )

        return full_text_row_masked_out_mask.reshape(-1, 1)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
            self._batch_image_inputs(forward_batch)
        )

        # TODO: support multi-image by this mask
        cross_attention_mask = None
        cross_attention_states = None

        if self.capture_mode:
            # NOTE: when doing cuda graph capture, we do not want to skip cross attention.
            # Make it a constant value to avoid cuda graph capture issues.
            skip_cross_attention = False
        else:
            # NOTE: we do not need image_inputs during prefill
            assert len(forward_batch.encoder_lens) == len(forward_batch.seq_lens)
            assert len(forward_batch.encoder_lens_cpu) == len(forward_batch.seq_lens)
            skip_cross_attention = forward_batch.encoder_lens.max() == 0

        if not skip_cross_attention:
            full_text_row_masked_out_mask = self.get_full_text_row_masked_out_mask(
                forward_batch
            )
        else:
            full_text_row_masked_out_mask = None

        if batched_images is not None:
            # NOTE: llama's reference implementation runs the vision model on CPU
            cross_attention_states = self.vision_model(
                batched_images, batched_ar_ids, batched_ar_mask
            )
            cross_attention_states = self.multi_modal_projector(cross_attention_states)

            bs, _, _, _, image_token_dim = cross_attention_states.shape
            cross_attention_states = cross_attention_states.view(
                bs, -1, image_token_dim
            )

            cross_attention_states = self.flat_encoder_result(
                cross_attention_states, encoder_lens_need
            )

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
        return self.logits_processor(
            input_ids, hidden_states, self.language_model.lm_head, forward_batch
        )
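    # Weight loading: q/k/v and gate/up projections are fused into stacked
    # parameters, the vision patch embedding is reshaped for the unfold-based
    # linear layer, and vision attention output projections are renamed to match
    # VisionAttention (`o_proj` -> `proj`).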
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params = set()
        for name, loaded_weight in weights:
            if "patch_embedding.weight" in name:
                name = name.replace(
                    "patch_embedding.weight", "patch_embedding._linear.weight"
                )
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "vision_model" in name:
                    # adapt to VisionAttention
                    name = name.replace("self_attn.o_proj", "self_attn.proj")

                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = MllamaForConditionalGeneration