sglang0.4.5.post1/python/sglang/srt/configs/janus_pro.py

629 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
import numpy as np
import PIL
import torch
from PIL.Image import Image
from transformers import (
BaseImageProcessor,
BatchFeature,
LlamaConfig,
LlamaTokenizerFast,
PretrainedConfig,
ProcessorMixin,
)
from transformers.image_utils import to_numpy_array
from sglang.srt.configs.utils import register_image_processor, register_processor
from sglang.srt.mm_utils import expand2square
class DictToObject(dict):
    """A ``dict`` whose entries are also exposed as instance attributes.

    Nested plain-``dict`` values are wrapped recursively so they can be reached
    with attribute access as well (``obj.a.b``). Item access (``obj["a"]``)
    still returns the value exactly as it was passed in.

    Args:
        dictionary: Mapping used both to populate the dict and to set the
            attributes. Keys must be strings, since they are handed to
            ``setattr``.
    """

    def __init__(self, dictionary):
        # ``super(self)`` is invalid (its first argument must be a type) and
        # raised TypeError on every construction; use the zero-argument form.
        super().__init__(dictionary)
        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)
class VisionConfig(PretrainedConfig):
    """Sub-config tagged with ``model_type = "vision"``.

    Reads two optional keyword arguments:

    * ``cls``: a class name, or an object whose ``__name__`` is stored instead.
    * ``params``: an arbitrary mapping, stored as given.
    """

    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        target = kwargs.get("cls", "")
        # Keep only the name when something other than a string was passed.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
class GenAlignerConfig(PretrainedConfig):
    """Sub-config tagged with ``model_type = "gen_aligner"``.

    Reads two optional keyword arguments:

    * ``cls``: a class name, or an object whose ``__name__`` is stored instead.
    * ``params``: an arbitrary mapping, stored as given.
    """

    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        target = kwargs.get("cls", "")
        # Keep only the name when something other than a string was passed.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
class GenHeadConfig(PretrainedConfig):
    """Sub-config tagged with ``model_type = "gen_head"``.

    Reads two optional keyword arguments:

    * ``cls``: a class name, or an object whose ``__name__`` is stored instead.
    * ``params``: an arbitrary mapping, stored as given.
    """

    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        target = kwargs.get("cls", "")
        # Keep only the name when something other than a string was passed.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
class AlignerConfig(PretrainedConfig):
    """Sub-config tagged with ``model_type = "aligner"``.

    Reads two optional keyword arguments:

    * ``cls``: a class name, or an object whose ``__name__`` is stored instead.
    * ``params``: an arbitrary mapping, stored as given.
    """

    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        target = kwargs.get("cls", "")
        # Keep only the name when something other than a string was passed.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
class GenVisionConfig(PretrainedConfig):
    """Sub-config tagged with ``model_type = "gen_vision"``.

    Reads two optional keyword arguments:

    * ``cls``: a class name, or an object whose ``__name__`` is stored instead.
    * ``params``: an arbitrary mapping, stored as given.
    """

    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        target = kwargs.get("cls", "")
        # Keep only the name when something other than a string was passed.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
@dataclass
class SigLIPVisionCfg:
    """Default hyperparameter values for a SigLIP vision model.

    NOTE(review): these fields are not read anywhere in the visible part of
    this file; the per-field notes below are inferred from the names only and
    should be confirmed against the model code that consumes this config.
    """

    # presumably the embedding width of the vision transformer (TODO confirm)
    width: int = 1152
    # either a single layer count or a 4-tuple of counts
    layers: Union[Tuple[int, int, int, int], int] = 27
    heads: int = 16
    patch_size: int = 14
    # either one side length or a 2-tuple (order of the pair is not shown here)
    image_size: Union[Tuple[int, int], int] = 336
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
class MultiModalityConfig(PretrainedConfig):
    """Top-level config bundling six sub-configs under ``model_type = "multi_modality"``.

    Each sub-config is read from the keyword argument of the same name
    (``vision_config``, ``aligner_config``, ``gen_vision_config``,
    ``gen_aligner_config``, ``gen_head_config``, ``language_config``).

    For every one of them the value may be:

    * missing or ``None``: the sub-config class is built with no arguments;
    * a mapping: expanded as keyword arguments into the sub-config class;
    * an instance of the sub-config class: stored as is.

    Previously only ``language_config`` accepted a ready-made instance; the
    other five raised ``TypeError`` on one. All six are now handled alike.
    """

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Order matches the original attribute assignment order.
        sub_config_types = (
            ("vision_config", VisionConfig),
            ("aligner_config", AlignerConfig),
            ("gen_vision_config", GenVisionConfig),
            ("gen_aligner_config", GenAlignerConfig),
            ("gen_head_config", GenHeadConfig),
            ("language_config", LlamaConfig),
        )
        for name, config_cls in sub_config_types:
            value = kwargs.get(name)
            if value is None:
                value = {}
            if not isinstance(value, config_cls):
                value = config_cls(**value)
            setattr(self, name, value)
class VLMImageProcessor(BaseImageProcessor):
    """Image processor that produces the ``pixel_values`` model input.

    Each image is resized so that its longer side equals ``image_size``,
    passed through ``expand2square`` with ``background_color``, converted to a
    channels-first array, rescaled by ``rescale_factor`` and, when
    ``do_normalize`` is set, normalized with ``image_mean`` / ``image_std``.

    Args:
        image_size: Target length of the longer image side after resizing.
        min_size: Lower bound applied to each resized side.
        image_mean: Per-channel mean used for normalization. Also determines
            ``background_color`` (``int(mean * 255)`` per channel); when
            ``None`` the background is ``(127, 127, 127)``.
        image_std: Per-channel standard deviation used for normalization.
        rescale_factor: Multiplier applied to the pixel values.
        do_normalize: Whether ``preprocess`` applies normalization.
        **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """Resize an image, square it with ``expand2square`` and return it channels-first.

        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]

        Raises:
            ValueError: If the input image or the computed target size has a
                non-positive side.
        """
        width, height = pil_img.size
        # Validate before dividing: a zero-sized image used to fail with
        # ZeroDivisionError below instead of reaching the intended ValueError.
        if width <= 0 or height <= 0:
            raise ValueError("Invalid size!")

        max_size = max(width, height)
        # Scale so the longer side becomes image_size, never below min_size.
        new_height = max(int(height / max_size * self.image_size), self.min_size)
        new_width = max(int(width / max_size * self.image_size), self.min_size)
        if new_height <= 0 or new_width <= 0:
            raise ValueError("Invalid size!")

        # PIL expects (width, height).
        pil_img = pil_img.resize(
            (new_width, new_height),
            resample=PIL.Image.Resampling.BICUBIC,
            reducing_gap=None,
        )
        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)
        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))
        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        """Turn one image or a list of images into a ``BatchFeature``.

        Args:
            images: A single image or a list of images accepted by ``resize``.
            return_tensors: Tensor type passed to ``BatchFeature``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            BatchFeature: Holds the processed images under ``"pixel_values"``.
        """
        if not isinstance(images, list):
            images = [images]

        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        # keep only the first three channels
        images = [image[:3, ...] for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        """Shape of one processed image: ``[3, image_size, image_size]``."""
        return [3, self.image_size, self.image_size]
class DictOutput:
    """Base class giving an object mapping-style access to its instance attributes."""

    def keys(self):
        """Return a view over the names of the instance attributes."""
        return vars(self).keys()

    def __getitem__(self, item):
        """Look up the instance attribute named ``item`` (KeyError if absent)."""
        return vars(self)[item]

    def __setitem__(self, key, value):
        """Store ``value`` as the instance attribute named ``key``."""
        vars(self)[key] = value
@dataclass
class VLChatProcessorOutput(DictOutput):
    """Result of processing a single prompt, with dict-style access via ``DictOutput``.

    ``len(output)`` is the length of ``input_ids``.
    """

    # the prompt string that was tokenized
    sft_format: str
    # token ids after the image tokens have been inserted
    input_ids: torch.Tensor
    # output of the image processor for the supplied images
    pixel_values: torch.Tensor
    # one entry per image placeholder found in the prompt
    num_image_tokens: torch.IntTensor

    def __len__(self):
        token_ids = self.input_ids
        return len(token_ids)
@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    """Batched processor output; shapes below are as built by ``VLChatProcessor.batchify``."""

    # one prompt string per sample
    sft_format: List[str]
    # (batch, max_seq_len), left-padded with the pad token id
    input_ids: torch.Tensor
    # (batch, max_n_images, 3, image_size, image_size), zero-filled where unused
    pixel_values: torch.Tensor
    # (batch, max_seq_len), 1 on real tokens and 0 on padding
    attention_mask: torch.Tensor
    # (batch, max_seq_len), True where input_ids equals the image token id
    images_seq_mask: torch.BoolTensor
    # (batch, max_n_images, num_image_tokens), True for the slots used by each image
    images_emb_mask: torch.BoolTensor
# FIXME: the official processor had to be placed here because the image_processor module is not imported in all threads,
# so the AutoProcessor registration would not be effective in some cases.
class VLChatProcessor(ProcessorMixin):
    """Processor combining a tokenizer and an image processor for chat prompts.

    A prompt is tokenized, every occurrence of the image placeholder token is
    expanded into ``<image_start> + num_image_tokens * <image> + <image_end>``,
    the images are run through the image processor, and the result is
    optionally batched with left padding.

    Args:
        image_processor: Image processor used for the ``images`` argument.
        tokenizer: Tokenizer used to encode prompts. ``image_tag`` is added to
            it as an additional special token when not already in its vocab.
        image_tag: Placeholder token marking an image position in the prompt.
        image_start_tag: Token inserted before each run of image tokens.
        image_end_tag: Token inserted after each run of image tokens.
        pad_tag: Token whose id fills the padding positions in ``batchify``.
        num_image_tokens: Number of image tokens inserted per placeholder.
        add_special_token: When True the original placeholder token is kept in
            front of the inserted tokens; when False it is replaced by them.
        sft_format: Stored on the instance; not otherwise used in this class.
        mask_prompt: Stored on the instance; not otherwise used in this class.
        ignore_id: Stored on the instance; not otherwise used in this class.
        **kwargs: Forwarded to ``ProcessorMixin.__init__``.
    """

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<▁pad▁>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        # Make sure the placeholder maps to a single token id.
        if self.tokenizer.vocab.get(image_tag) is None:
            special_tokens_dict = {"additional_special_tokens": [image_tag]}
            self.tokenizer.add_special_tokens(special_tokens_dict)

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag
        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        # Previously accepted but silently dropped; keep it like its siblings.
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

    @property
    def image_token(self):
        """The image placeholder tag."""
        return self.image_tag

    @property
    def image_id(self) -> int:
        """Vocabulary id of the image placeholder tag (None if absent)."""
        return self.tokenizer.vocab.get(self.image_tag)

    @property
    def image_start_id(self):
        """Vocabulary id of the image start tag (None if absent)."""
        return self.tokenizer.vocab.get(self.image_start_tag)

    @property
    def image_end_id(self):
        """Vocabulary id of the image end tag (None if absent)."""
        return self.tokenizer.vocab.get(self.image_end_tag)

    @property
    def image_start_token(self):
        """The image start tag."""
        return self.image_start_tag

    @property
    def image_end_token(self):
        """The image end tag."""
        return self.image_end_tag

    @property
    def pad_id(self):
        """Vocabulary id of the pad tag (None if absent)."""
        return self.tokenizer.vocab.get(self.pad_tag)

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """Expand every image placeholder position into the full image token run.

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """
        input_slices = []
        start = 0
        for index in image_indices:
            # With add_special_token the placeholder itself stays in the text
            # slice; otherwise the slice stops right before it.
            end = index + 1 if self.add_special_token else index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # begin-of-image, image tokens, end-of-image
            input_slices.append(
                torch.full((1,), self.image_start_id, dtype=torch.long)
            )
            input_slices.append(
                torch.full((self.num_image_tokens,), self.image_id, dtype=torch.long)
            )
            input_slices.append(torch.full((1,), self.image_end_id, dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """Tokenize one prompt, insert the image tokens and process the images.

        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images; ``None`` is treated
                as an empty list so that text-only prompts work with the
                default argument;
            **kwargs: accepted and ignored.

        Returns:
            VLChatProcessorOutput with
                - sft_format (str): the prompt that was tokenized
                - input_ids (torch.LongTensor): [N + image tokens]
                - pixel_values: output of the image processor
                - num_image_tokens (torch.IntTensor): [n_images]
        """
        sft_format = prompt

        # tokenize
        input_ids = torch.LongTensor(self.tokenizer.encode(sft_format))

        # positions of the image placeholders, as plain ints
        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
        image_indices = image_token_mask.nonzero().flatten().tolist()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # The default images=None used to be wrapped into [None] by the image
        # processor and crash there; batchify already supports samples with
        # zero images, so normalize None to "no images".
        if images is None:
            images = []

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        return VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """Process one prompt and, by default, wrap it into a batch of one.

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): forwarded to ``process_one``, which
                currently ignores it;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs: accepted and ignored.

        Returns:
            BatchedVLChatProcessorOutput when ``force_batchify`` is True,
            otherwise the VLChatProcessorOutput from ``process_one``.
        """
        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """Left-pad and stack single-sample outputs into one batch.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.

        Raises:
            ValueError: If ``prepare_list`` is empty.
        """
        if not prepare_list:
            # Same exception type the bare max() below would have raised.
            raise ValueError("prepare_list must contain at least one item")

        batch_size = len(prepare_list)
        n_images = [len(prepare.num_image_tokens) for prepare in prepare_list]
        seq_lens = [len(prepare) for prepare in prepare_list]

        input_token_max_len = max(seq_lens)
        # Always reserve at least one image slot so the tensors are non-empty.
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME (carried over from the original without explanation)
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        sft_format = []
        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = seq_lens[i]
            n_image = n_images[i]

            # left-padding: use an explicit start offset, because ``-seq_len:``
            # selects the whole row when seq_len is 0.
            pad_len = input_token_max_len - seq_len
            batched_attention_mask[i, pad_len:] = 1
            batched_input_ids[i, pad_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, pad_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        return BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )
class VLMImageProcessorConfig(PretrainedConfig):
    """Config tagged with ``model_type = "deepseek_vlm"`` that stores image-processing settings.

    The constructor arguments are stored as attributes of the same name before
    the remaining keyword arguments are handed to ``PretrainedConfig``.
    """

    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        # Own fields first, in declaration order, then the base initializer.
        self.image_size, self.min_size = image_size, min_size
        self.image_mean, self.image_std = image_mean, image_std
        self.rescale_factor, self.do_normalize = rescale_factor, do_normalize

        super().__init__(**kwargs)
# Associate the processor classes from this module with MultiModalityConfig via
# the helpers imported from sglang.srt.configs.utils (presumably they feed the
# transformers Auto* registries — confirm in that module). Runs at import time.
register_processor(MultiModalityConfig, VLChatProcessor)
register_image_processor(MultiModalityConfig, VLMImageProcessor)