import math
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image, ImageOps
from transformers import (
    AutoProcessor,
    LlamaTokenizerFast,
    PretrainedConfig,
    ProcessorMixin,
)


def select_best_resolution(image_size, candidate_resolutions):
    """Pick the candidate resolution that best fits an image (used for cropping).

    Each candidate ``(width, height)`` is scored by how many pixels survive an
    aspect-preserving resize of the image into it, capped at the original pixel
    count. The candidate with the largest effective resolution wins; ties are
    broken by the smallest amount of unused canvas area.

    Args:
        image_size: ``(width, height)`` of the original image.
        candidate_resolutions: iterable of ``(width, height)`` candidates.

    Returns:
        The best ``(width, height)`` tuple, or ``None`` when no candidate has a
        positive effective resolution (e.g. an empty candidate list).
    """
    original_width, original_height = image_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in candidate_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(
            original_height * scale
        )
        # Upscaling cannot add information, so cap at the original pixel count.
        effective_resolution = min(
            downscaled_width * downscaled_height, original_width * original_height
        )
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution
            and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


class DictOutput(object):
    """Mixin giving output objects dict-style access (``keys()``, ``obj[key]``)."""

    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    """Tensors produced by ``DeepseekVLV2Processor.process_one`` for one request."""

    input_ids: torch.LongTensor
    target_ids: torch.LongTensor
    images: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_spatial_crop: torch.LongTensor

    def __len__(self):
        return len(self.input_ids)


class ImageTransform(object):
    """Convert a PIL image to a tensor via ``ToTensor`` and, optionally, ``Normalize``."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True,
    ):
        self.mean = mean
        self.std = std
        self.normalize = normalize

        # only load torchvision.transforms when needed
        try:
            import torchvision.transforms as T

            # FIXME: add version check for gguf
        except ImportError as err:
            raise ImportError(
                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
            ) from err

        transform_pipelines = [T.ToTensor()]

        if normalize:
            transform_pipelines.append(T.Normalize(mean, std))

        self.transform = T.Compose(transform_pipelines)

    def __call__(self, pil_img: Image.Image):
        x = self.transform(pil_img)
        return x


class DeepseekVLV2Processor(ProcessorMixin):
    """Processor that tokenizes text and turns each image into tiles for DeepSeek-VL2.

    Every occurrence of ``image_token`` in the text is replaced by a run of
    ``image_token_id`` placeholders. The run's length matches the global view
    plus the local tiles chosen for the corresponding image.
    """

    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
    attributes = ["tokenizer"]

    def __init__(
        self,
        tokenizer: LlamaTokenizerFast,
        candidate_resolutions: Tuple[Tuple[int, int]],
        patch_size: int,
        downsample_ratio: int,
        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        normalize: bool = True,
        image_token: str = "<image>",
        pad_token: str = "<|▁pad▁|>",
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        """
        Args:
            tokenizer: the text tokenizer; special tokens are added to it in place.
            candidate_resolutions: ``(width, height)`` canvases for local tiling.
                The width of the first candidate is used as the square tile size.
            patch_size: vision patch size in pixels.
            downsample_ratio: token downsampling factor applied per tile side.
            image_mean / image_std / normalize: forwarded to ``ImageTransform``.
            image_token: placeholder marking where an image goes in the text.
                It must be non-empty, since the text is split on it.
            pad_token: pad token added when the tokenizer has none.
            mask_prompt: if True, all prompt tokens are labelled ``ignore_id``.
            ignore_id: label value for positions excluded from the loss.
        """
        self.candidate_resolutions = candidate_resolutions
        self.image_size = candidate_resolutions[0][0]
        self.patch_size = patch_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.normalize = normalize
        self.downsample_ratio = downsample_ratio

        self.image_transform = ImageTransform(
            mean=image_mean, std=image_std, normalize=normalize
        )
        self.tokenizer = tokenizer
        # Must be set: the padding side makes a difference in batch inference.
        self.tokenizer.padding_side = "left"

        # Add the pad_token as a special token so that 'tokenizer.pad_token'
        # and 'tokenizer.pad_token_id' are usable.
        if tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": pad_token})

        # add image token
        image_token_id = self.tokenizer.vocab.get(image_token)
        if image_token_id is None:
            special_tokens = [image_token]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
        self.image_token_id = self.tokenizer.vocab.get(image_token)

        # add five special tokens for grounding-related tasks
        # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
        special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)

        # add special tokens for SFT data
        special_tokens = ["<|User|>", "<|Assistant|>"]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)

        self.image_token = image_token
        self.pad_token = pad_token
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            tokenizer,
            **kwargs,
        )

    def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
        """Tokenize ``messages`` with its images; plays the role of the former
        ``format_messages_v2`` and ``get_images_info``.

        Args:
            messages: text containing ``self.image_token`` placeholders.
            pil_images: images consumed in order by the placeholders.
            max_req_input_len: forwarded to ``tokenize_with_images``.

        Returns:
            ``(tokenized_data, masked_tokenized_data, images_list,
            images_seq_mask, images_spatial_crop)``.
        """
        tokenized_data = []
        masked_tokenized_data = []  # labels
        images_list = []
        images_seq_mask = []
        images_spatial_crop = []

        image_token_cnt = messages.count(self.image_token)
        tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
            messages,
            pil_images[:image_token_cnt],
            bos=False,
            eos=True,
            # Local tiling is only used when there are at most two images.
            cropping=len(pil_images) <= 2,
            max_req_input_len=max_req_input_len,
        )

        tokenized_data += tokenized_str
        if self.mask_prompt:
            masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
        else:
            masked_tokenized_data += tokenized_str
        images_list += images
        images_seq_mask += seq_mask
        images_spatial_crop += spatial_crop

        assert len(tokenized_data) == len(
            images_seq_mask
        ), f"format_messages_v2: tokenized_str's length {len(tokenized_data)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

        return (
            tokenized_data,
            masked_tokenized_data,
            images_list,
            images_seq_mask,
            images_spatial_crop,
        )

    @property
    def bos_id(self):
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        return self.tokenizer.eos_token_id

    @property
    def pad_id(self):
        return self.tokenizer.pad_token_id

    def encode(self, text: str, bos: bool = True, eos: bool = False):
        """Encode ``text`` without tokenizer specials, optionally wrapping in BOS/EOS ids."""
        t = self.tokenizer.encode(text, add_special_tokens=False)

        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]

        return t

    def decode(self, t: List[int], **kwargs) -> str:
        return self.tokenizer.decode(t, **kwargs)

    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image.Image] = None,
        apply_sft_format: bool = False,
        inference_mode: bool = True,
        system_prompt: str = "",
        max_req_input_len: int = -1,
        **kwargs,
    ):
        """Tokenize one request and pack it into a ``VLChatProcessorOutput``.

        Args:
            prompt (str): an already-formatted prompt. It is used as the text
                when ``conversations`` is None.
            conversations: the text to tokenize. Despite the annotation, it is
                consumed as a single string containing ``self.image_token``
                placeholders (``.count``/``.split`` are called on it).
            images (List[Image.Image]): images consumed in order by the
                placeholders; ``None`` means no images.
            apply_sft_format (bool): accepted for API compatibility; unused here.
            inference_mode (bool): if True, drop the trailing EOS token.
            system_prompt (str): accepted for API compatibility; unused here.
            max_req_input_len (int): forwarded to ``tokenize_with_images``;
                -1 disables truncation.
            **kwargs: ignored.

        Returns:
            VLChatProcessorOutput with
              - input_ids (torch.LongTensor): text and image-placeholder ids.
              - target_ids (torch.LongTensor): labels, ``ignore_id`` at masked
                and image positions.
              - images (torch.Tensor): transformed tiles stacked along dim 0, or
                one zero tile of shape (1, 3, image_size, image_size) when
                there are no images.
              - images_seq_mask (torch.BoolTensor): True at image placeholders.
              - images_spatial_crop (torch.LongTensor): one
                ``[num_width_tiles, num_height_tiles]`` row per image (a single
                zero row when there are no images).

        Raises:
            AssertionError: if both ``prompt`` and ``conversations`` are given.
            ValueError: if neither ``prompt`` nor ``conversations`` is given.
        """
        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."

        # ``conversations`` takes precedence; a prompt-only call uses ``prompt``.
        text = conversations if conversations is not None else prompt
        if text is None:
            raise ValueError("either prompt or conversations must be provided.")
        if images is None:
            images = []

        (
            tokenized_str,
            masked_tokenized_str,
            images_list,
            images_seq_mask,
            images_spatial_crop,
        ) = self.format_messages_v2(text, images, max_req_input_len)

        assert (
            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
        ), (
            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
        )

        input_ids = torch.LongTensor(tokenized_str)
        target_ids = torch.LongTensor(masked_tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
            self.ignore_id
        )
        input_ids[input_ids < 0] = self.pad_id

        if inference_mode:
            # The sequence was built with eos=True; generation must not see it.
            assert input_ids[-1] == self.eos_id
            input_ids = input_ids[:-1]
            target_ids = target_ids[:-1]
            images_seq_mask = images_seq_mask[:-1]

        if len(images_list) == 0:
            images = torch.zeros((1, 3, self.image_size, self.image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
        else:
            images = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)

        prepare = VLChatProcessorOutput(
            input_ids=input_ids,
            target_ids=target_ids,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image.Image] = None,
        apply_sft_format: bool = False,
        inference_mode: bool = True,
        system_prompt: str = "",
        max_req_input_len: int = -1,
        **kwargs,
    ):
        """Keyword-only alias for ``process_one`` (extra ``kwargs`` are not forwarded)."""
        prepare = self.process_one(
            prompt=prompt,
            conversations=conversations,
            images=images,
            apply_sft_format=apply_sft_format,
            inference_mode=inference_mode,
            system_prompt=system_prompt,
            max_req_input_len=max_req_input_len,
        )

        return prepare

    def find_all_indices(self, messages, target_value):
        """Return every index ``i`` with ``messages[i] == target_value``."""
        return [
            index for index, item in enumerate(messages) if item == target_value
        ]

    def tokenize_with_images(
        self,
        conversation: str,
        images: List[Image.Image],
        bos: bool = True,
        eos: bool = True,
        cropping: bool = True,
        max_req_input_len: int = -1,
    ):
        """Tokenize text with <image> tags.

        The text is split on ``self.image_token``. Each split is paired with the
        next image (``zip``, so surplus splits or images are dropped). The image
        becomes one global view padded to ``image_size`` plus local tiles cut
        from a padded canvas of the best candidate resolution. A matching run of
        ``image_token_id`` placeholders is emitted for each image.

        Returns:
            ``(tokenized_str, images_list, images_seq_mask, images_spatial_crop)``.
        """
        images_list, images_seq_mask, images_spatial_crop = [], [], []
        text_splits = conversation.split(self.image_token)
        tokenized_str = []
        for text_sep, image in zip(text_splits, images):
            # encode text_sep
            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            # select best resolution for anyres
            if cropping:
                best_width, best_height = select_best_resolution(
                    image.size, self.candidate_resolutions
                )
            else:
                best_width, best_height = self.image_size, self.image_size

            pad_color = tuple(int(x * 255) for x in self.image_transform.mean)

            # process the global view
            global_view = ImageOps.pad(
                image,
                (self.image_size, self.image_size),
                color=pad_color,
            )
            images_list.append(self.image_transform(global_view))

            # process the local views, row-major over the padded canvas
            local_view = ImageOps.pad(
                image,
                (best_width, best_height),
                color=pad_color,
            )
            for i in range(0, best_height, self.image_size):
                for j in range(0, best_width, self.image_size):
                    images_list.append(
                        self.image_transform(
                            local_view.crop(
                                (j, i, j + self.image_size, i + self.image_size)
                            )
                        )
                    )

            # record height / width crop num
            num_width_tiles, num_height_tiles = (
                best_width // self.image_size,
                best_height // self.image_size,
            )
            images_spatial_crop.append([num_width_tiles, num_height_tiles])

            # add image tokens
            h = w = math.ceil(
                (self.image_size // self.patch_size) / self.downsample_ratio
            )
            # global view tokens: h * (w + 1); the +1 is a line separator per row
            tokenized_image = [self.image_token_id] * h * (w + 1)
            # add a separator between the global and local views
            tokenized_image += [self.image_token_id]
            # local view tokens: (num_height_tiles * h) * (num_width_tiles * w + 1)
            tokenized_image += (
                [self.image_token_id]
                * (num_height_tiles * h)
                * (num_width_tiles * w + 1)
            )

            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)

        # process the last text split
        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
        # Deal with video: limit to the request length.
        # NOTE(review): only the token/mask lists are truncated here; images_list
        # and images_spatial_crop are left intact, and ``rest`` can go negative
        # (slicing then drops from the end). The 1024-token margin is unexplained.
        if max_req_input_len > -1:
            if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
                rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
                tokenized_str = tokenized_str[:rest]
                images_seq_mask = images_seq_mask[:rest]
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        # add the bos and eos tokens
        if bos:
            tokenized_str = [self.bos_id] + tokenized_str
            images_seq_mask = [False] + images_seq_mask
        if eos:
            tokenized_str = tokenized_str + [self.eos_id]
            images_seq_mask = images_seq_mask + [False]

        assert len(tokenized_str) == len(
            images_seq_mask
        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

        return tokenized_str, images_list, images_seq_mask, images_spatial_crop


class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
    """Vision encoder configuration.

    ``weight_init``, ``deterministic`` and ``num_recomputing_layers`` are only
    class-level defaults; ``__init__`` does not accept them explicitly.
    """

    model_type: str = "vision"

    model_name: str = "siglip_large_patch16_384"
    image_size: int = 384
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: int = 384,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 24,
        heads: int = 16,
        mlp_ratio: int = 4,
        global_pool: str = "map",
        ignore_head: bool = True,
        class_token: bool = False,
        num_classes: int = 0,
        use_checkpoint: bool = False,
        **kwargs,
    ):
        self.model_name = model_name
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.global_pool = global_pool
        self.ignore_head = ignore_head
        self.class_token = class_token
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint

        super().__init__(**kwargs)


class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
    """Vision-to-language projector configuration (``token_pooling`` is class-level only)."""

    model_type = "mlp_projector"
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs,
    ):
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.downsample_ratio = downsample_ratio

        super().__init__(**kwargs)


class DeepseekV2Config(PretrainedConfig):
    """Language-model configuration for the DeepSeek-V2 backbone."""

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
        n_routed_experts=None,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        # NOTE(review): "gready" (sic) is kept verbatim; model code elsewhere
        # presumably compares against this exact spelling — confirm before changing.
        topk_method="gready",
        n_group=None,
        topk_group=None,
        num_experts_per_tok=None,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_mla=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = float(rms_norm_eps)
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_mla = use_mla

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


def _build_sub_config(value, config_cls):
    """Return ``value`` if it is already a ``config_cls``; otherwise build one
    from a dict of kwargs (``None`` means all defaults)."""
    if isinstance(value, config_cls):
        return value
    return config_cls(**(value or {}))


class DeepseekVL2Config(PretrainedConfig):
    """Top-level DeepSeek-VL2 configuration bundling vision, projector and
    language sub-configs. Each sub-config may be given as a dict of kwargs or
    as an already-built config instance."""

    model_type = "deepseek_vl_v2"
    vision_config: DeepseekVL2VisionEncoderConfig
    projector_config: DeepseekVL2MlpProjectorConfig
    language_config: DeepseekV2Config

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)

    def __init__(
        self,
        # NOTE(review): this default differs from the class-level "2D"; kept as-is
        # because checkpoints presumably supply tile_tag explicitly — confirm.
        tile_tag: str = "tile_tag",
        global_view_pos: str = "head",
        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vision_config = _build_sub_config(
            kwargs.get("vision_config", {}), DeepseekVL2VisionEncoderConfig
        )
        self.projector_config = _build_sub_config(
            kwargs.get("projector_config", {}), DeepseekVL2MlpProjectorConfig
        )
        self.language_config = _build_sub_config(
            kwargs.get("language_config", {}), DeepseekV2Config
        )

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
        self.architectures = ["DeepseekVL2ForCausalLM"]


AutoProcessor.register(DeepseekVL2Config, DeepseekVLV2Processor)