# sglang0.4.5.post1/python/sglang/srt/models/deepseek_vl2.py
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

from sglang.srt.configs.deepseekvl2 import (
    DeepseekVL2Config,
    DeepseekVL2MlpProjectorConfig,
)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM


class DeepseekVL2MlpProjector(nn.Module):
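    """Projects vision-encoder features into the language-model embedding space.

    Supports the projector types handled below: "identity", "linear",
    "mlp_gelu", and "downsample_mlp_gelu" (which concatenates each
    downsample_ratio x downsample_ratio group of tokens before the MLP).
    """
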
    def __init__(
        self,
        config: DeepseekVL2MlpProjectorConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        if config.projector_type == "identity":
            # Wrap in a ModuleList so forward() can iterate self.layers
            # uniformly across projector types.
            self.layers = nn.ModuleList([nn.Identity()])
        elif config.projector_type == "linear":
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                ]
            )
        elif config.projector_type == "mlp_gelu":
            mlp_depth = config.depth
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                ]
            )
            for _ in range(1, mlp_depth):
                self.layers.append(nn.GELU())
                self.layers.append(
                    ReplicatedLinear(
                        config.n_embed,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                )
        elif config.projector_type == "downsample_mlp_gelu":
            mlp_depth = config.depth
            mlp_ratio = config.mlp_ratio
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim
                        * config.downsample_ratio
                        * config.downsample_ratio,
                        config.n_embed * mlp_ratio,
                        quant_config=quant_config,
                    )
                ]
            )
            for _ in range(1, mlp_depth - 1):
                self.layers.append(nn.GELU())
                self.layers.append(
                    ReplicatedLinear(
                        config.n_embed * mlp_ratio,
                        config.n_embed * mlp_ratio,
                        quant_config=quant_config,
                    )
                )
            self.layers.append(nn.GELU())
            self.layers.append(
                ReplicatedLinear(
                    config.n_embed * mlp_ratio,
                    config.n_embed,
                    quant_config=quant_config,
                )
            )
        else:
            raise ValueError(f"Unknown projector type: {config.projector_type}")

        if config.token_pooling:
            self.token_pooling_layer = ReplicatedLinear(
                config.input_dim * 4, config.input_dim, quant_config=quant_config
            )

    def forward(self, x):
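        """Project features of shape [batch, tokens, input_dim] to the LM width."""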
        if self.config.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            # Gather non-overlapping 2x2 patches, flattening each into a
            # single token with 4x the channels.
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            patches = patches.contiguous().view(
                batch_size, channels, h_patches * w_patches, -1
            )
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
            # ReplicatedLinear returns (output, bias); keep only the output.
            x = self.token_pooling_layer(patches)[0]
        elif self.config.projector_type == "downsample_mlp_gelu":
            bs, hw, input_dim = x.shape
            h = w = int(hw**0.5)
            # Pad the token grid so it divides evenly by downsample_ratio.
            if h % self.config.downsample_ratio:
                pad = self.config.downsample_ratio - h % self.config.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
"""4 to 1 concat"""
x = x.permute(0, 3, 1, 2) # B, C, H, W
x = F.unfold(
x,
kernel_size=self.config.downsample_ratio,
stride=self.config.downsample_ratio,
padding=0,
) # B, C*4, HW // 4
x = x.permute(0, 2, 1)
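            # Illustrative shape walkthrough (assuming downsample_ratio=2 and
            # a 24x24 token grid, i.e. hw=576):
            #   input        [B, 576, input_dim]
            #   reshape      [B, 24, 24, input_dim]
            #   F.unfold     [B, input_dim * 4, 144]
            #   permute      [B, 144, input_dim * 4]
            #   after MLP    [B, 144, n_embed]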

        for layer in self.layers:
            x = layer(x)
            # ReplicatedLinear layers return (output, bias).
            if isinstance(x, tuple):
                x = x[0]
        return x


class DeepseekVL2ForCausalLM(nn.Module):
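    """DeepSeek-VL2: a SigLIP vision tower, an MLP projector, and a DeepseekV2
    language model, with tiled (global + local) image features."""
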
    def __init__(
        self,
        config: DeepseekVL2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = self._init_vision_module(vision_config, quant_config)

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = DeepseekVL2MlpProjector(projector_config, quant_config)
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos
        embed_std = 1 / torch.sqrt(
            torch.tensor(projector_config.n_embed, dtype=torch.float32)
        )
        if self.tile_tag == "2D":
            # Learned separator embeddings, initialized with std = 1 / sqrt(n_embed).
            self.image_newline = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
            self.view_seperator = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
        else:
            raise ValueError(f"tile_tag should be 2D, but got {self.tile_tag}")

        # ----------- language model ------------
        language_config = config.language_config
        self.language_model = DeepseekV2ForCausalLM(language_config)

    def _init_vision_module(
        self, vision_config, quant_config: Optional[QuantizationConfig]
    ) -> nn.Module:
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            raise ImportError("Please install timm") from err
        model = timm.create_model(
            "vit_so400m_patch14_siglip_384.webli",
            pretrained=False,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=True,
        )
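        # pretrained=False: the SigLIP weights come from the DeepSeek-VL2
        # checkpoint via load_weights below. dynamic_img_size/dynamic_img_pad
        # let the encoder accept the variable tile resolutions produced by
        # the image processor.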
        model = model.to(dtype=torch.get_default_dtype())
        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ):
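        """Embed input_ids, splice projected image features into the image
        token positions during extend (prefill), then run the language model."""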
        input_embeds = self.language_model.model.embed_tokens(input_ids)
        if (
            forward_batch.forward_mode.is_extend()
            and forward_batch.contains_image_inputs()
        ):
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
            for idx, image in enumerate(forward_batch.mm_inputs):
                if image is None:
                    continue
                start_idx = extend_start_loc_cpu[idx]
                end_idx = start_idx + extend_seq_lens_cpu[idx]
                images_emb_mask = image.images_emb_mask.to(device="cuda")
                image_features = self.get_image_feature(image)
                # Overwrite the placeholder image-token embeddings with the
                # projected vision features.
                input_embeds[start_idx:end_idx] = input_embeds[
                    start_idx:end_idx
                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)

        outputs = self.language_model.forward(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
        )
        return outputs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
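        # Note: stacked_params_mapping is not consumed in this method; the
        # language tower's own load_weights applies its stacked-weight
        # mapping when it is called below.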
        params_dict = dict(self.named_parameters())
        weights = list(weights)
        for name, loaded_weight in weights:
            if "language" in name:
                # Strip the "language." prefix and delegate to the
                # DeepseekV2 loader.
                name = name.replace("language.", "")
                self.language_model.load_weights([(name, loaded_weight)])
            else:
                # Vision tower and projector weights load 1:1 by name.
                param = params_dict[name]
                weights_loader = getattr(param, "weight_loader", default_weight_loader)
                weights_loader(param, loaded_weight)

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
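        # Returns input_ids unchanged; image placeholder tokens are assumed
        # to be inserted upstream by the image processor.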
        return input_ids

    def get_image_feature(self, image_input: MultimodalInputs):
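        """Encode and project image tiles, then lay them out as token rows.

        Each image contributes one global view plus a grid of local tiles;
        an image_newline embedding ends every row and a view_seperator
        embedding splits the global view from the local tiles.
        """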
        pixel_values = image_input.pixel_values.type(
            next(self.vision.parameters()).dtype
        ).to(device=next(self.vision.parameters()).device)
        image_feature = self.vision.forward_features(pixel_values)
        images_embeds = self.projector(image_feature)
        _, hw, n_dim = images_embeds.shape
        h = w = int(hw**0.5)
        tile_index = 0
        images_in_this_batch = []
        images_spatial_crop = image_input.image_spatial_crop
        for jdx in range(images_spatial_crop.shape[1]):
            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[
                tile_index + 1 : tile_index + 1 + num_tiles_in_image
            ]
            tile_index += num_tiles_in_image + 1
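            # Illustrative layout (assuming a 2x2 local tiling and h=w=24):
            #   global view -> 24 rows of 24 tokens + 1 newline each = 600 tokens
            #   local tiles -> 48 rows of 48 tokens + 1 newline each = 2352 tokens
            #   plus one view_seperator token between the two views.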
            # format global and local features
            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global], dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(
                local_features,
                "(th tw) (h w) d -> (th h) (tw w) d",
                th=num_height_tiles,
                tw=num_width_tiles,
                h=h,
                w=w,
            )

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(
                self.image_newline,
                "d -> (th h) 1 d",
                th=num_height_tiles,
                h=h,
            )

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local], dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            if self.global_view_pos == "head":
                global_local_features = torch.cat(
                    [
                        global_features,
                        self.view_seperator[None, :],
                        local_features,
                    ]
                )
            else:
                global_local_features = torch.cat(
                    [
                        local_features,
                        self.view_seperator[None, :],
                        global_features,
                    ]
                )

            images_in_this_batch.append(global_local_features)

        return torch.cat(images_in_this_batch, dim=0)


# Registered with sglang's model loader via the module-level EntryClass hook.
EntryClass = DeepseekVL2ForCausalLM