# Adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models

from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import numpy as np
import PIL
import torch
from PIL.Image import Image
from transformers import (
    BaseImageProcessor,
    BatchFeature,
    LlamaConfig,
    LlamaTokenizerFast,
    PretrainedConfig,
    ProcessorMixin,
)
from transformers.image_utils import to_numpy_array

from sglang.srt.configs.utils import register_image_processor, register_processor
from sglang.srt.mm_utils import expand2square


class DictToObject(dict):
    def __init__(self, dictionary):
        super().__init__(dictionary)

        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)


class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenAlignerConfig(PretrainedConfig):
    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenHeadConfig(PretrainedConfig):
    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenVisionConfig(PretrainedConfig):
    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


@dataclass
class SigLIPVisionCfg:
    width: int = 1152
    layers: Union[Tuple[int, int, int, int], int] = 27
    heads: int = 16
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False


class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionConfig(**vision_config)

        aligner_config = kwargs.get("aligner_config", {})
        self.aligner_config = AlignerConfig(**aligner_config)

        gen_vision_config = kwargs.get("gen_vision_config", {})
        self.gen_vision_config = GenVisionConfig(**gen_vision_config)

        gen_aligner_config = kwargs.get("gen_aligner_config", {})
        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)

        gen_head_config = kwargs.get("gen_head_config", {})
        self.gen_head_config = GenHeadConfig(**gen_head_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)


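# Illustrative note (not from the original code): MultiModalityConfig builds each
# sub-config from the nested dicts of a Janus-style config.json, e.g.
#
#   cfg = MultiModalityConfig(
#       vision_config={"cls": "CLIPVisionTower", "params": {...}},  # names are examples
#       language_config={"hidden_size": 4096},
#   )
#   cfg.language_config  # -> LlamaConfig(hidden_size=4096, ...)

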
class VLMImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """

        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """

        width, height = pil_img.size
        max_size = max(width, height)

        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            # print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        def resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        ):
            if isinstance(size, int):
                w, h = pil_img.size
                if (w <= h and w == size) or (h <= w and h == size):
                    return pil_img
                if w < h:
                    ow = size
                    oh = int(size * h / w)
                else:
                    oh = size
                    ow = int(size * w / h)
                size = (ow, oh)
            else:
                size = (size[1], size[0])

            return pil_img.resize(
                size, resample=interpolation, reducing_gap=None if antialias else 3.0
            )

        pil_img = resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        )

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        if not isinstance(images, list):
            images = [images]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        images = [image[:3, ...] for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]
        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        return [3, self.image_size, self.image_size]


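# Usage sketch (illustration only, not part of the adapted Janus code; the 384px
# size and the file name are assumptions):
#
#   from PIL import Image as PILImage
#   processor = VLMImageProcessor(image_size=384)
#   img = PILImage.open("example.jpg").convert("RGB")
#   batch = processor(img, return_tensors="pt")
#   # batch.pixel_values has shape [1, 3, 384, 384]: the long side is scaled to
#   # 384, the image is padded to a square with self.background_color, then
#   # rescaled to [0, 1] and normalized with image_mean / image_std.

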
class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor


# FIXME: the official processor has to live here, since the image_processor module
# is not imported in all threads, hence AutoProcessor registration would not be
# effective in some cases
class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<|▁pad▁|>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            # print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag

        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self) -> int:
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(
                self.image_start_id * torch.ones((1,), dtype=torch.long)
            )
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1,), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

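    # Illustrative note (not from the original code): with add_special_token=False
    # and the default num_image_tokens=576, a prompt tokenized as
    #   [t1, t2, <image_placeholder>, t3]
    # is expanded by add_image_token into
    #   [t1, t2, <begin_of_image>, <image_placeholder> * 576, <end_of_image>, t3]
    # i.e. the single placeholder id is replaced by 1 + 576 + 1 = 578 ids, and
    # num_image_tokens records 576 for that image.
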
    def process_one(
        self,
        prompt: str = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (VLChatProcessorOutput): the output of the processor,
                - sft_format (str): the formatted prompt
                - input_ids (torch.LongTensor): [N + image tokens]
                - pixel_values (torch.FloatTensor): [n_images, 3, H, W]
                - num_image_tokens (torch.IntTensor): the number of image tokens per image
        """

        sft_format = prompt
        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares


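# Illustrative note (not from the original code): batchify() left-pads, so for two
# sequences of lengths 5 and 3 with pad id P the batch is laid out as
#   input_ids      = [[t1, t2, t3, t4, t5],
#                     [ P,  P, s1, s2, s3]]
#   attention_mask = [[ 1,  1,  1,  1,  1],
#                     [ 0,  0,  1,  1,  1]]
# images_seq_mask flags the positions that hold the image token id, and
# images_emb_mask flags, per image slot, which of the num_image_tokens embeddings
# are real (True) versus padding (False).

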
class VLMImageProcessorConfig(PretrainedConfig):
    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)


register_processor(MultiModalityConfig, VLChatProcessor)
register_image_processor(MultiModalityConfig, VLMImageProcessor)
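
# End-to-end usage sketch (illustration only; the checkpoint name, image size and
# prompt below are assumptions, not something this module prescribes):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/Janus-Pro-7B")
#   image_processor = VLMImageProcessor(image_size=384)
#   processor = VLChatProcessor(image_processor, tokenizer)
#
#   # pil_image is any RGB PIL.Image loaded elsewhere
#   prompt = "<image_placeholder>\nDescribe this image."
#   batch = processor(prompt=prompt, images=[pil_image], force_batchify=True)
#   # batch.input_ids:      [1, N + 578]  (the placeholder expands to boi + 576 + eoi)
#   # batch.pixel_values:   [1, 1, 3, 384, 384]
#   # batch.attention_mask: [1, N + 578], left-padded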