# Adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models

from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import numpy as np
import PIL
import torch
from PIL.Image import Image
from transformers import (
    BaseImageProcessor,
    BatchFeature,
    LlamaConfig,
    LlamaTokenizerFast,
    PretrainedConfig,
    ProcessorMixin,
)
from transformers.image_utils import to_numpy_array

from sglang.srt.configs.utils import register_image_processor, register_processor
from sglang.srt.mm_utils import expand2square


class DictToObject(dict):
    def __init__(self, dictionary):
        super().__init__(dictionary)

        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)


class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenAlignerConfig(PretrainedConfig):
    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenHeadConfig(PretrainedConfig):
    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenVisionConfig(PretrainedConfig):
    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


@dataclass
class SigLIPVisionCfg:
    width: int = 1152
    layers: Union[Tuple[int, int, int, int], int] = 27
    heads: int = 16
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False


class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionConfig(**vision_config)

        aligner_config = kwargs.get("aligner_config", {})
        self.aligner_config = AlignerConfig(**aligner_config)

        gen_vision_config = kwargs.get("gen_vision_config", {})
        self.gen_vision_config = GenVisionConfig(**gen_vision_config)

        gen_aligner_config = kwargs.get("gen_aligner_config", {})
        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)

        gen_head_config = kwargs.get("gen_head_config", {})
        self.gen_head_config = GenHeadConfig(**gen_head_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)
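# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream Janus code): MultiModalityConfig
# builds each sub-config from plain nested dicts, so a checkpoint's config.json
# can be mirrored with a structure like the one below. All field values here
# are assumptions for demonstration only, not values defined in this module.
# ---------------------------------------------------------------------------
def _example_multi_modality_config() -> MultiModalityConfig:
    return MultiModalityConfig(
        vision_config={"cls": "CLIPVisionTower", "params": {"image_size": 384}},
        aligner_config={"cls": "MlpProjector", "params": {"depth": 2}},
        gen_vision_config={"cls": "VQ-16", "params": {"image_token_size": 16384}},
        gen_aligner_config={"cls": "MlpProjector", "params": {"depth": 2}},
        gen_head_config={"cls": "vision_head", "params": {"image_token_embed": 2048}},
        language_config={"hidden_size": 2048, "num_hidden_layers": 24},
    )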
class VLMImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """
        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """

        width, height = pil_img.size
        max_size = max(width, height)

        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            # print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        def resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        ):
            if isinstance(size, int):
                w, h = pil_img.size
                if (w <= h and w == size) or (h <= w and h == size):
                    return pil_img
                if w < h:
                    ow = size
                    oh = int(size * h / w)
                else:
                    oh = size
                    ow = int(size * w / h)
                size = (ow, oh)
            else:
                size = (size[1], size[0])

            return pil_img.resize(
                size, resample=interpolation, reducing_gap=None if antialias else 3.0
            )

        pil_img = resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        )
        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        if not isinstance(images, list):
            images = [images]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        images = [image[:3, ...] for image in images]
        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        return [3, self.image_size, self.image_size]


class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor


# FIXME: the official processor has to be placed here, since the image_processor
# module would not be imported in all threads, hence AutoProcessor registration
# would not be effective in some cases
class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<|▁pad▁|>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            # print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag

        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self) -> int:
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """
        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index
            # original text tokens
            input_slices.append(input_ids[start:end])

            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(
                self.image_start_id * torch.ones((1), dtype=torch.long)
            )
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor(
            [self.num_image_tokens] * len(image_indices)
        )

        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        sft_format = prompt
        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
""" batch_size = len(prepare_list) sft_format = [] n_images = [] seq_lens = [] for prepare in prepare_list: n_images.append(len(prepare.num_image_tokens)) seq_lens.append(len(prepare)) input_token_max_len = max(seq_lens) max_n_images = max(1, max(n_images)) batched_input_ids = torch.full( (batch_size, input_token_max_len), self.pad_id ).long() # FIXME batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long() batched_pixel_values = torch.zeros( (batch_size, max_n_images, *self.image_processor.default_shape) ).float() batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool() batched_images_emb_mask = torch.zeros( (batch_size, max_n_images, self.num_image_tokens) ).bool() for i, prepare in enumerate(prepare_list): input_ids = prepare.input_ids seq_len = len(prepare) n_image = len(prepare.num_image_tokens) # left-padding batched_attention_mask[i, -seq_len:] = 1 batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids) batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id if n_image > 0: batched_pixel_values[i, :n_image] = prepare.pixel_values for j, n_image_tokens in enumerate(prepare.num_image_tokens): batched_images_emb_mask[i, j, :n_image_tokens] = True sft_format.append(prepare.sft_format) batched_prepares = BatchedVLChatProcessorOutput( input_ids=batched_input_ids, attention_mask=batched_attention_mask, pixel_values=batched_pixel_values, images_seq_mask=batched_images_seq_mask, images_emb_mask=batched_images_emb_mask, sft_format=sft_format, ) return batched_prepares class VLMImageProcessorConfig(PretrainedConfig): model_type = "deepseek_vlm" image_size: int min_size: int image_mean: Union[Tuple[float, float, float], List[float]] image_std: Union[Tuple[float, float, float], List[float]] rescale_factor: float do_normalize: bool def __init__( self, image_size: int, min_size: int = 14, image_mean: Union[Tuple[float, float, float], List[float]] = ( 0.48145466, 0.4578275, 0.40821073, ), image_std: Union[Tuple[float, float, float], List[float]] = ( 0.26862954, 0.26130258, 0.27577711, ), rescale_factor: float = 1.0 / 255.0, do_normalize: bool = True, **kwargs, ): self.image_size = image_size self.min_size = min_size self.image_mean = image_mean self.image_std = image_std self.rescale_factor = rescale_factor self.do_normalize = do_normalize super().__init__(**kwargs) register_processor(MultiModalityConfig, VLChatProcessor) register_image_processor(MultiModalityConfig, VLMImageProcessor)