# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only LLaVa model compatible with HuggingFace weights."""

import math
import re
from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import (
    CLIPVisionConfig,
    CLIPVisionModel,
    LlavaConfig,
    MistralConfig,
    Qwen2Config,
    SiglipVisionModel,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix


class LlavaBaseForCausalLM(nn.Module):
    """Shared LLaVA logic: prompt padding, image encoding, forward, weight loading.

    Subclasses are expected to set ``self.config``, ``self.vision_tower``,
    ``self.multi_modal_projector`` and ``self.language_model`` in ``__init__``.
    Most other attributes read here (``image_size``, ``patch_size``,
    ``image_feature_len``, ``mm_patch_merge_type``, ...) are assigned by
    ``load_weights``, so it must run before the other methods are used.
    """

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        """Expand each image placeholder token into a run of pad values.

        For every entry of ``image_inputs.image_sizes`` the first occurrence of
        ``self.config.image_token_index`` in ``input_ids`` is replaced by
        ``pad_values[image_idx]`` repeated once per expected image feature.
        When no placeholder is found the offset falls back to 0, which means
        the token at index 0 is the one that gets replaced.

        Side effects: sets ``image_inputs.image_pad_len`` (pad length per image)
        and ``image_inputs.image_offsets`` (start offset per image).

        Args:
            input_ids: Prompt token ids; not mutated, a new list is returned.
            image_inputs: Multimodal inputs providing ``image_sizes``,
                ``pad_values`` and ``modalities``.

        Returns:
            The new token id list with placeholders expanded.
        """
        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values

        # hardcode for spatial_unpad + anyres
        if image_inputs.modalities is not None and (
            "multi-images" in image_inputs.modalities
            or "video" in image_inputs.modalities
        ):
            image_aspect_ratio = "pad"
        else:
            image_aspect_ratio = "anyres"

        offset_list = []
        image_inputs.image_pad_len = []
        for image_idx, image_s in enumerate(image_sizes):
            if len(image_sizes) > 16:
                # 2x2 pooling with stride 2
                new_image_feature_len = (
                    math.ceil(self.image_size / self.patch_size / 2) ** 2
                )
            else:
                new_image_feature_len = self.image_feature_len  # multiimage

            height = width = self.num_patches_per_side
            if "anyres" in image_aspect_ratio:
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_s,
                    self.image_grid_pinpoints,
                    self.vision_tower.config.image_size,
                )
                h = num_patch_height * height
                w = num_patch_width * width
                new_h, new_w = unpad_image_shape(h, w, image_s)

                if "anyres_max" in self.config.image_aspect_ratio:
                    matched_anyres_max_num_patches = re.match(
                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
                    )
                    # Only rescale when the patch budget could be parsed;
                    # max_num_patches is undefined otherwise.
                    if matched_anyres_max_num_patches:
                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
                        # times = math.sqrt(h * w / (max_num_patches * unit**2))
                        times = math.sqrt(
                            new_h * new_w / (max_num_patches * self.image_feature_len)
                        )
                        if times > 1.1:
                            new_h = int(new_h // times)
                            new_w = int(new_w // times)
                # The "+ 1" per row mirrors the image_newline column that
                # forward() appends in the "unpad" merge path.
                new_image_feature_len += new_h * (new_w + 1)

            try:
                offset = input_ids.index(self.config.image_token_index)
            except ValueError:
                offset = 0
            # old_len + pad_len - 1, because we need to remove image_token_id
            input_ids = (
                input_ids[:offset]
                + [pad_values[image_idx]] * new_image_feature_len
                + input_ids[offset + 1 :]
            )
            offset_list.append(offset)
            image_inputs.image_pad_len.append(new_image_feature_len)

        image_inputs.image_offsets = offset_list
        return input_ids

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and projector on ``pixel_values``.

        The hidden state at index ``self.vision_feature_layer`` is selected.
        For the "default" and "patch" strategies the first token along dim 1 is
        dropped (presumably the CLS token); for "full" all tokens are kept.

        Raises:
            ValueError: If ``self.vision_feature_select_strategy`` is not one
                of "default", "patch" or "full".
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will
        # save all the hidden states.

        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        strategy = self.vision_feature_select_strategy
        if strategy in ("default", "patch"):
            selected_image_feature = selected_image_feature[:, 1:]
        elif strategy != "full":
            # Report the value that was actually tested, not a config attribute
            # that may be absent or hold a different value.
            raise ValueError(f"Unexpected select feature strategy: {strategy}")

        return self.multi_modal_projector(selected_image_feature)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the language model, splicing in vision embeddings during extend.

        In extend mode the text tokens are embedded, images that still overlap
        the un-cached part of a request are encoded, and their features are
        written over the corresponding placeholder positions of the embeddings.
        In decode mode the call is forwarded to the language model unchanged.

        NOTE: ``input_ids`` is clamped in place in extend mode. For any forward
        mode that is neither extend nor decode this method falls through and
        returns ``None``.
        """
        image_inputs = forward_batch.mm_inputs

        if forward_batch.forward_mode.is_extend():
            # Clamp input ids. This is because the input_ids for the image tokens are
            # filled with the hash values of the image for the prefix matching in the
            # radix attention. These values are useless because their embeddings will
            # be replaced by vision embeddings anyway.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Got List[List[str]] extend it to List[str]
            # The length of the List should be equal to batch size
            modalities_list = []
            max_image_offset = []
            for im in image_inputs:
                if im and im.modalities is not None:
                    modalities_list.extend(im.modalities)
                if im and im.image_offsets:
                    max_image_offset.append(
                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
                    )
                else:
                    # -1 so that requests without images never need vision.
                    max_image_offset.append(-1)

            # A request needs the vision tower only if its extend segment starts
            # at or before the end of its last image.
            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
            need_vision = start_positions <= np.array(max_image_offset)

            if need_vision.any():
                bs = forward_batch.batch_size
                pixel_values = [
                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
                ]
                image_sizes = [
                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
                ]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained
                    # from process_images
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    # Split back into one chunk per request.
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    height = width = self.num_patches_per_side
                    for image_idx, image_feature in enumerate(image_features):
                        if modalities_list[image_idx] == "image":
                            image_aspect_ratio = (
                                self.config.image_aspect_ratio
                            )  # single image
                        elif (
                            modalities_list[image_idx] == "multi-images"
                            or modalities_list[image_idx] == "video"
                        ):
                            image_aspect_ratio = "pad"  # multi image

                        if (
                            image_feature.shape[0] > 1
                            and "anyres" in image_aspect_ratio
                            and modalities_list[image_idx] == "image"
                        ):
                            # First entry is the base (whole-image) feature; the
                            # rest are the grid patches.
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            assert height * width == base_image_feature.shape[0]

                            if "anyres_max" in image_aspect_ratio:
                                matched_anyres_max_num_patches = re.match(
                                    r"anyres_max_(\d+)", image_aspect_ratio
                                )
                                if matched_anyres_max_num_patches:
                                    max_num_patches = int(
                                        matched_anyres_max_num_patches.group(1)
                                    )

                            if (
                                image_aspect_ratio == "anyres"
                                or "anyres_max" in image_aspect_ratio
                            ):
                                vision_tower_image_size = self.image_size
                                try:
                                    num_patch_width, num_patch_height = (
                                        get_anyres_image_grid_shape(
                                            image_sizes[image_idx][0],
                                            self.config.image_grid_pinpoints,
                                            vision_tower_image_size,
                                        )
                                    )
                                except Exception as e:
                                    # Best-effort fallback to a 2x2 grid.
                                    print(f"Error: {e}")
                                    num_patch_width, num_patch_height = 2, 2
                                image_feature = image_feature.view(
                                    num_patch_height, num_patch_width, height, width, -1
                                )
                            else:
                                image_feature = image_feature.view(
                                    2, 2, height, width, -1
                                )

                            if "unpad" in self.mm_patch_merge_type:
                                unit = image_feature.shape[2]
                                # -> (C, grid_h, H, grid_w, W), then merge the
                                # grid into one (C, grid_h*H, grid_w*W) map.
                                image_feature = image_feature.permute(
                                    4, 0, 2, 1, 3
                                ).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(
                                    2, 3
                                )
                                image_feature = unpad_image(
                                    image_feature, image_sizes[image_idx][0]
                                )
                                if (
                                    "anyres_max" in image_aspect_ratio
                                    and matched_anyres_max_num_patches
                                ):
                                    c, h, w = image_feature.shape
                                    times = math.sqrt(
                                        h * w / (max_num_patches * unit**2)
                                    )
                                    if times > 1.1:
                                        image_feature = image_feature[None]
                                        image_feature = nn.functional.interpolate(
                                            image_feature,
                                            [int(h // times), int(w // times)],
                                            mode="bilinear",
                                        )[0]
                                # Append one image_newline column to every row.
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[
                                            :, None, None
                                        ].expand(*image_feature.shape[:-1], 1),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(
                                    0, 1
                                )
                            else:
                                image_feature = image_feature.permute(
                                    0, 2, 1, 3, 4
                                ).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat(
                                (base_image_feature, image_feature), dim=0
                            )
                            image_feature = image_feature.unsqueeze(0)
                        else:
                            if modalities_list[image_idx] == "video":  # video
                                # 2x2 pooling
                                num_of_frames = image_feature.shape[0]
                                image_feature = image_feature.view(
                                    num_of_frames, height, width, -1
                                )
                                image_feature = image_feature.permute(
                                    0, 3, 1, 2
                                ).contiguous()  # N, C, H, W
                                # Dedicated locals so the loop-wide height/width
                                # are not rebound for later images.
                                frame_h, frame_w = image_feature.shape[2:]
                                scaled_shape = [
                                    math.ceil(frame_h / 2),
                                    math.ceil(frame_w / 2),
                                ]
                                image_feature = nn.functional.interpolate(
                                    image_feature, size=scaled_shape, mode="bilinear"
                                )
                                image_feature = (
                                    image_feature.flatten(2)
                                    .transpose(1, 2)
                                    .contiguous()
                                )  # N, H*W, C
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        # Expand to (bs, 1, hidden_dim) and concat at
                                        # the end of the image tokens
                                        self.language_model.model.image_newline[
                                            None, None
                                        ].expand(
                                            image_feature.shape[0],
                                            1,
                                            image_feature.shape[-1],
                                        ),
                                    ),
                                    dim=1,
                                )

                        new_image_features.append(image_feature)
                    image_features = new_image_features

                # Fill in the placeholder for the image
                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
                # pt indexes image_features, which only holds entries for the
                # requests with need_vision set.
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    seq_len = extend_seq_lens[i]
                    prefix_len = prefix_lens_cpu[i]

                    # Multiple images
                    for image_idx, image_offset in enumerate(
                        image_inputs[i].image_offsets
                    ):
                        if (
                            image_offset + image_inputs[i].image_pad_len[image_idx]
                            <= prefix_len
                        ):
                            # Image lies entirely inside the cached prefix.
                            continue
                        if image_offset >= prefix_len + seq_len:
                            # Image starts after this extend segment.
                            break

                        tmp_image_feature = image_features[pt][image_idx]
                        pad_len = tmp_image_feature.shape[0]

                        input_offset = image_offset - prefix_len
                        left_idx = start_idx + input_offset
                        right_idx = left_idx + pad_len
                        assert right_idx > start_idx
                        if input_offset < 0:
                            # Image begins inside the prefix: drop that head part.
                            left_idx = start_idx
                            tmp_image_feature = tmp_image_feature[-input_offset:]
                        if right_idx > start_idx + seq_len:
                            # Image runs past the segment: drop the tail part.
                            tmp_image_feature = tmp_image_feature[
                                : start_idx + seq_len - right_idx
                            ]
                            right_idx = start_idx + seq_len
                        try:
                            input_embeds[left_idx:right_idx] = tmp_image_feature
                        except RuntimeError as e:
                            # Best-effort: report the mismatch and continue.
                            print(f"RuntimeError in image encoding: {e}")
                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
                            print(
                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
                            )
                    pt += 1

            return self.language_model(
                input_ids, positions, forward_batch, input_embeds=input_embeds
            )
        elif forward_batch.forward_mode.is_decode():
            return self.language_model(input_ids, positions, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Instantiate the vision tower, derive image settings and load weights.

        Weights whose name contains "projector", "vision_tower" or
        "image_newline" are renamed via a fixed mapping and loaded into this
        module's own parameters; all others go to
        ``self.language_model.load_weights``.

        Raises:
            ValueError: If ``config.mm_vision_select_feature`` is not "patch",
                "full" or "cls_patch".
        """
        # Load clip vision model by cfg['mm_vision_tower']:
        # huggingface_name or path_of_clip_relative_to_llava_model_dir
        # We put the initialization here instead of __init__ to allow it being
        # reused by other subclasses.
        vision_path = self.config.mm_vision_tower
        if "clip" in vision_path:
            self.vision_tower = CLIPVisionModel.from_pretrained(
                vision_path, torch_dtype=torch.float16
            ).cuda()
        elif "siglip" in vision_path:
            self.vision_tower = SiglipVisionModel.from_pretrained(
                vision_path, torch_dtype=torch.float16
            ).cuda()
            # Siglip needs all feature tokens
            self.config.mm_vision_select_feature = "full"
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
        if self.vision_feature_select_strategy == "cls_patch":
            # One extra feature on top of the patch grid.
            self.image_feature_len += 1
        elif self.vision_feature_select_strategy not in ("patch", "full"):
            # Previously referenced a nonexistent attribute, which turned the
            # intended ValueError into an AttributeError.
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            # Update the vision tower weights if we find them in the checkpoint
            # (it may be finetuned).
            "model.vision_tower.vision_tower": "vision_tower",
            "model.image_newline": "language_model.model.image_newline",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                self.language_model.load_weights([(name, loaded_weight)])

    @property
    def num_patches_per_side(self):
        """Number of patches along one image side (image_size // patch_size)."""
        return self.image_size // self.patch_size


class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
    """LLaVA variant whose language model is ``LlamaForCausalLM``."""

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        # Created later by load_weights().
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = LlamaForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            # Uninitialized here; filled from the checkpoint by load_weights().
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
    """LLaVA variant whose language model is ``Qwen2ForCausalLM``."""

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        # Created later by load_weights().
        self.vision_tower = None

        # Fill in config fields that this checkpoint's config may not provide.
        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = Qwen2Config(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 151646

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = Qwen2ForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            # Uninitialized here; filled from the checkpoint by load_weights().
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
    """LLaVA variant whose language model is ``MistralForCausalLM``."""

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        # Created later by load_weights().
        self.vision_tower = None

        # Fill in config fields that this checkpoint's config may not provide.
        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = MistralConfig(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 32000

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = MistralForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            # Uninitialized here; filled from the checkpoint by load_weights().
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]