359 lines
13 KiB
Python
359 lines
13 KiB
Python
from typing import Iterable, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, repeat
|
|
from torch import nn
|
|
|
|
from sglang.srt.configs.deepseekvl2 import (
|
|
DeepseekVL2Config,
|
|
DeepseekVL2MlpProjectorConfig,
|
|
)
|
|
from sglang.srt.layers.linear import ReplicatedLinear
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
|
|
|
|
|
class DeepseekVL2MlpProjector(nn.Module):
    """Project vision features to the embedding width given by ``config.n_embed``.

    The layer stack is selected by ``config.projector_type``:

    * ``"identity"``            - a single ``nn.Identity`` (no parameters).
    * ``"linear"``              - one ``ReplicatedLinear(input_dim -> n_embed)``.
    * ``"mlp_gelu"``            - ``depth`` linear layers separated by GELU.
    * ``"downsample_mlp_gelu"`` - neighbouring tokens in a
      ``downsample_ratio x downsample_ratio`` window are concatenated along
      the channel axis first, then passed through a GELU MLP whose hidden
      width is ``n_embed * mlp_ratio``.

    If ``config.token_pooling`` is truthy, an extra linear layer that merges
    every 2x2 window of tokens is created and applied in ``forward``.

    Args:
        config: Projector configuration (fields read: ``projector_type``,
            ``input_dim``, ``n_embed``, ``depth``, ``mlp_ratio``,
            ``downsample_ratio``, ``token_pooling``, depending on the type).
        quant_config: Optional quantization config forwarded to every
            ``ReplicatedLinear``.

    Raises:
        ValueError: If ``config.projector_type`` is not one of the types above.
    """

    def __init__(
        self,
        config: DeepseekVL2MlpProjectorConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.config = config

        if config.projector_type == "identity":
            # The identity module must live in ``self.layers`` because
            # ``forward`` iterates over that attribute; previously it was only
            # bound to a throw-away local, which made ``forward`` fail.
            self.layers = nn.ModuleList([nn.Identity()])

        elif config.projector_type == "linear":
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                ]
            )

        elif config.projector_type == "mlp_gelu":
            mlp_depth = config.depth
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                ]
            )
            # Each remaining level adds an activation followed by a linear layer.
            for _ in range(1, mlp_depth):
                self.layers.append(nn.GELU())
                self.layers.append(
                    ReplicatedLinear(
                        config.n_embed,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                )

        elif config.projector_type == "downsample_mlp_gelu":
            mlp_depth = config.depth
            mlp_ratio = config.mlp_ratio
            # The first layer consumes the channel-concatenated window, hence
            # input_dim * downsample_ratio ** 2 input features.
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim
                        * config.downsample_ratio
                        * config.downsample_ratio,
                        config.n_embed * mlp_ratio,
                        quant_config=quant_config,
                    )
                ]
            )
            # Hidden layers keep the widened (n_embed * mlp_ratio) dimension.
            for _ in range(1, mlp_depth - 1):
                self.layers.append(nn.GELU())
                self.layers.append(
                    ReplicatedLinear(
                        config.n_embed * mlp_ratio,
                        config.n_embed * mlp_ratio,
                        quant_config=quant_config,
                    )
                )
            # Final layer narrows back to n_embed.
            self.layers.append(nn.GELU())
            self.layers.append(
                ReplicatedLinear(
                    config.n_embed * mlp_ratio,
                    config.n_embed,
                    quant_config=quant_config,
                )
            )

        else:
            raise ValueError(f"Unknown projector type: {config.projector_type}")

        if config.token_pooling:
            # Merges a 2x2 window (4 * input_dim features) back to input_dim.
            self.token_pooling_layer = ReplicatedLinear(
                config.input_dim * 4, config.input_dim, quant_config=quant_config
            )

    def forward(self, x):
        """Apply optional token regrouping and then the projector layers.

        Args:
            x: Input features. The pooling and downsample branches unpack
                ``x.shape`` into three values and treat the middle one as a
                square grid of tokens (its integer square root is used as both
                height and width).

        Returns:
            The output of the last layer in ``self.layers``; when a layer
            returns a tuple only its first element is carried forward.
        """
        # NOTE(review): the two branches are mutually exclusive, so with
        # token_pooling enabled the downsample regrouping below is skipped
        # even for "downsample_mlp_gelu" - confirm this is intended.
        if self.config.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)

            # Non-overlapping 2x2 windows over the two spatial axes:
            # (B, C, h_patches, w_patches, 2, 2).
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            patches = patches.contiguous().view(
                batch_size, channels, h_patches * w_patches, -1
            )
            # Move the window index ahead of channels, then flatten each
            # window into one feature vector of length channels * 4.
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            # The linear layer returns a tuple; element 0 is the output.
            x = self.token_pooling_layer(patches)[0]

        elif self.config.projector_type == "downsample_mlp_gelu":
            bs, hw, input_dim = x.shape
            h = w = int((hw) ** 0.5)

            # Compute the padding needed to make the grid side a multiple of
            # the downsample ratio.
            if h % self.config.downsample_ratio:
                pad = self.config.downsample_ratio - h % self.config.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                # Pad pairs run from the last dim backwards: channels (0, 0),
                # width (0, pad), height (0, pad).
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # Concatenate each ratio x ratio window along the channel axis.
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(
                x,
                kernel_size=self.config.downsample_ratio,
                stride=self.config.downsample_ratio,
                padding=0,
            )  # B, C * ratio * ratio, number of windows
            x = x.permute(0, 2, 1)

        for layer in self.layers:
            x = layer(x)
            # Layers that return a tuple carry their output in element 0.
            if isinstance(x, tuple):
                x = x[0]
        return x
|
|
|
|
|
|
# todo
|
|
class DeepseekVL2ForCausalLM(nn.Module):
    """Vision-language wrapper: vision encoder + projector + DeepseekV2 LM.

    Image tiles are encoded by ``self.vision``, projected by
    ``self.projector``, laid out with learned newline / view-separator
    embeddings and scattered into the token embeddings before the language
    model runs.

    Args:
        config: Model configuration (fields read: ``vision_config``,
            ``projector_config``, ``tile_tag``, ``global_view_pos``,
            ``language_config``).
        quant_config: Optional quantization config; forwarded to the
            projector and to ``_init_vision_module``.

    Raises:
        ValueError: If ``config.tile_tag`` is not ``"2D"``.
    """

    def __init__(
        self,
        config: DeepseekVL2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = self._init_vision_module(vision_config, quant_config)

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = DeepseekVL2MlpProjector(projector_config, quant_config)

        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # Special embeddings are initialised with std 1 / sqrt(n_embed).
        embed_std = 1 / torch.sqrt(
            torch.tensor(projector_config.n_embed, dtype=torch.float32)
        )
        if self.tile_tag == "2D":
            # Appended at the end of every row of image tokens.
            self.image_newline = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
            # Placed between the global view and the local tiles. The
            # spelling "seperator" is kept on purpose: it is the registered
            # parameter name and therefore part of the state-dict keys.
            self.view_seperator = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
        else:
            raise ValueError(f"tile tag should be 2D, but got {self.tile_tag}")

        # ----------- language model ------------
        language_config = config.language_config
        # NOTE(review): quant_config is not forwarded to the language model
        # here - confirm whether that is intentional.
        self.language_model = DeepseekV2ForCausalLM(language_config)

    def _init_vision_module(
        self, vision_config, quant_config: Optional[QuantizationConfig]
    ) -> nn.Module:
        """Build the timm vision backbone in the current default dtype.

        ``vision_config`` and ``quant_config`` are accepted but not used yet;
        the architecture name is fixed below.

        Raises:
            ImportError: If ``timm`` cannot be imported. The original import
                failure is attached as ``__cause__``.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            # Chain from the caught exception (not the ImportError class) so
            # the real reason the import failed stays in the traceback.
            raise ImportError("Please install timm") from err

        model = timm.create_model(
            "vit_so400m_patch14_siglip_384.webli",
            pretrained=False,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=True,
        )

        model = model.to(dtype=torch.get_default_dtype())
        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ):
        """Embed tokens, splice in image features, and run the language model.

        Image features are only computed when the batch is in extend mode and
        contains image inputs. For each request with multimodal input, the
        positions selected by its ``images_emb_mask`` are overwritten with the
        features returned by ``get_image_feature``.

        Returns:
            Whatever ``self.language_model.forward`` returns.
        """
        input_embeds = self.language_model.model.embed_tokens(input_ids)
        if (
            forward_batch.forward_mode.is_extend()
            and forward_batch.contains_image_inputs()
        ):
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
            for idx, image in enumerate(forward_batch.mm_inputs):
                if image is None:
                    continue
                # Slice of the flattened batch that belongs to request idx.
                start_idx = extend_start_loc_cpu[idx]
                end_idx = start_idx + extend_seq_lens_cpu[idx]
                # Follow the embeddings' device instead of hard-coding "cuda".
                images_emb_mask = image.images_emb_mask.to(
                    device=input_embeds.device
                )
                image_features = self.get_image_feature(image)
                input_embeds[start_idx:end_idx] = input_embeds[
                    start_idx:end_idx
                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)

        outputs = self.language_model.forward(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
        )

        return outputs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the language model and local params.

        Entries whose name contains ``"language"`` have ``"language."``
        stripped and are delegated to ``self.language_model.load_weights``.
        Every other entry is looked up in this module's named parameters and
        loaded with the parameter's ``weight_loader`` attribute, falling back
        to ``default_weight_loader``.

        Raises:
            KeyError: If a non-language weight name is not a parameter of
                this module.
        """
        params_dict = dict(self.named_parameters())
        # Iterate the incoming iterable directly; it does not need to be
        # materialised into a list first.
        for name, loaded_weight in weights:
            if "language" in name:
                name = name.replace("language.", "")
                self.language_model.load_weights([(name, loaded_weight)])
            else:
                param = params_dict[name]
                weights_loader = getattr(param, "weight_loader", default_weight_loader)
                weights_loader(param, loaded_weight)

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        """Return ``input_ids`` unchanged; no padding is applied here."""
        return input_ids

    def get_image_feature(self, image_input: MultimodalInputs):
        """Encode image tiles and lay them out as one sequence of embeddings.

        ``image_input.pixel_values`` is cast to the vision module's dtype and
        device, encoded, and projected. For every image listed in
        ``image_input.image_spatial_crop[0]`` (iteration stops at the first
        entry with a zero tile count) one global view and
        ``num_width_tiles * num_height_tiles`` local tiles are consumed, a
        newline embedding is appended to each row, and global / local parts
        are joined with the view separator in the order given by
        ``self.global_view_pos``.

        Returns:
            A 2-D tensor ``[total_tokens, D]`` with all images concatenated.
        """
        pixel_values = image_input.pixel_values.type(
            next(self.vision.parameters()).dtype
        ).to(device=next(self.vision.parameters()).device)
        image_feature = self.vision.forward_features(pixel_values)
        images_embeds = self.projector(image_feature)
        _, hw, n_dim = images_embeds.shape
        # Each tile is treated as a square h x w grid of tokens.
        h = w = int(hw**0.5)
        tile_index = 0
        images_in_this_batch = []
        images_spatial_crop = image_input.image_spatial_crop
        for jdx in range(images_spatial_crop.shape[1]):
            # Convert to plain ints so slicing and the einops axis lengths
            # below receive Python integers rather than 0-dim tensors.
            num_width_tiles, num_height_tiles = (
                int(count) for count in images_spatial_crop[0, jdx]
            )
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[
                tile_index + 1 : tile_index + 1 + num_tiles_in_image
            ]
            # Skip past the global view plus all local tiles of this image.
            tile_index += num_tiles_in_image + 1

            # format global and local features
            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global], dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(
                local_features,
                "(th tw) (h w) d -> (th h) (tw w) d",
                th=num_height_tiles,
                tw=num_width_tiles,
                h=h,
                w=w,
            )

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(
                self.image_newline,
                "d -> (th h) 1 d",
                th=num_height_tiles,
                h=h,
            )

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local], dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            if self.global_view_pos == "head":
                global_local_features = torch.cat(
                    [
                        global_features,
                        self.view_seperator[None, :],
                        local_features,
                    ]
                )
            else:
                global_local_features = torch.cat(
                    [
                        local_features,
                        self.view_seperator[None, :],
                        global_features,
                    ]
                )

            images_in_this_batch.append(global_local_features)

        return torch.cat(images_in_this_batch, dim=0)
|
|
|
|
|
|
# Module-level alias for the model class defined in this file.
# NOTE(review): presumably the name the model loader looks up to discover
# this implementation - confirm against the loader.
EntryClass = DeepseekVL2ForCausalLM
|