# Copied and adapted from: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/modeling_minicpmo.py # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Inference-only MiniCPM-o model compatible with HuggingFace weights.""" import math from dataclasses import dataclass from typing import Any, Iterable, List, Literal, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F import torch.nn.utils.parametrize as P import torch.types from torch import nn from torch.nn.utils import weight_norm from tqdm import tqdm from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN from transformers.cache_utils import DynamicCache, EncoderDecoderCache from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput from transformers.models.whisper.modeling_whisper import ( WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder, ) from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, embed_mm_inputs, get_multimodal_data_bounds, ) from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.minicpmv import ( Idefics2VisionTransformer, MiniCPMVBaseModel, Resampler2_5, ) from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.srt.utils import logger try: from transformers import LogitsWarper from vector_quantize_pytorch import GroupedResidualFSQ from vocos import Vocos from vocos.pretrained import instantiate_class _tts_deps = True except: LogitsWarper = None _tts_deps = False def apply_spk_emb( input_ids: torch.Tensor = None, spk_emb: torch.Tensor = None, input_embeds: torch.Tensor = None, spk_emb_token_id: int = 0, num_spk_embs: int = 1, ): """ Replace consecutive `num_spk_embs` speaker embedding placeholders in input_embeds with pre-prepared speaker embeddings. This is an in-place replacement, no new tensor is created, so no value is returned. 
    Args:
        input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max]
        spk_emb (torch.Tensor): Speaker embedding tensor, shape [batch_size, num_spk_emb, hidden_dim]
        input_embeds (torch.Tensor): Input embedding tensor, shape [batch_size, seq_len_max, hidden_dim]
        spk_emb_token_id (int): ID of the speaker embedding placeholder token
        num_spk_embs (int): Number of speaker embeddings

    Returns:
        None
    """
    batch_size = input_ids.shape[0]

    for idx in range(batch_size):
        input_ids_ = input_ids[idx]  # [seq_len_max]
        spk_emb_ = spk_emb[idx]  # [num_spk_emb, hidden_dim]
        mask_ = input_ids_ == spk_emb_token_id  # [seq_len_max]
        nonzero_position_idx = mask_.nonzero(as_tuple=False)  # [num_spk_emb, 1]
        assert nonzero_position_idx.shape[0] == num_spk_embs
        begin_idx = nonzero_position_idx.min()
        end_idx = nonzero_position_idx.max()
        input_embeds[idx, begin_idx : end_idx + 1, :] = spk_emb_

    return


@dataclass
class ConditionalChatTTSGenerationOutput(ModelOutput):
    """
    Output class for ConditionalChatTTS generation.

    Args:
        new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
        audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
        past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for the attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
        finished (bool): Boolean indicating whether generation is complete.
    """

    new_ids: torch.LongTensor = None
    audio_input_ids: torch.LongTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    finished: bool = None


def make_streaming_chunk_mask_generation(
    inputs_embeds: torch.Tensor,
    past_seen_tokens: int,
    streaming_tts_text_mask: torch.Tensor,
    streaming_reserved_length: int = 300,
    streaming_audio_chunk_size: int = 50,
    streaming_text_chunk_size: int = 10,
    num_spk_emb: int = 1,
    use_spk_emb: bool = True,
) -> torch.Tensor:
    """
    In streaming audio generation, determine which `text` positions the TTS model can attend to when generating each chunk of `audio` tokens.

    This function creates a mask that allows the model to attend to a specific chunk of text tokens when generating each chunk of audio tokens, enabling streaming TTS generation.

    Args:
        inputs_embeds (torch.Tensor): Input embeddings tensor.
        past_seen_tokens (int): Number of tokens already seen by the model.
        streaming_tts_text_mask (torch.Tensor): Mask for the text tokens.
        streaming_reserved_length (int, optional): Number of reserved text positions for streaming. Defaults to 300.
        streaming_audio_chunk_size (int, optional): Size of each audio chunk. Defaults to 50.
        streaming_text_chunk_size (int, optional): Size of each text chunk. Defaults to 10.
        num_spk_emb (int, optional): Number of speaker embeddings. Defaults to 1.
        use_spk_emb (bool, optional): Whether speaker embeddings are used. Defaults to True.

    Returns:
        torch.Tensor: Causal mask for streaming TTS generation, shape is [batch_size=1, 1, seq_len=1, past_seen_tokens+1].

    Raises:
        AssertionError: If the batch size is not 1 (only supports batch size of 1 for inference).
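    Example:
        A minimal sketch using the default settings above; `inputs_embeds` and
        `streaming_tts_text_mask` are assumed to have been prepared as described in
        `ConditionalChatTTS`:

        ```python
        mask = make_streaming_chunk_mask_generation(
            inputs_embeds=inputs_embeds,            # [1, 1, hidden_size]
            past_seen_tokens=302,                   # [Stts] + [spk_emb] + 300 reserved text slots
            streaming_tts_text_mask=streaming_tts_text_mask,
        )
        # mask has shape [1, 1, 1, 303]. With the defaults, ceil((302 - 300) / 50) * 10 = 10,
        # so only the first 10 text tokens are visible; every further 50 generated audio
        # tokens reveals 10 more text tokens.
        ```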
""" assert inputs_embeds.shape[0] == 1 dtype = inputs_embeds.dtype device = inputs_embeds.device min_dtype = torch.finfo(dtype).min # Add `1` to the past seen tokens to account for new `tokens` during `generate` causal_mask = torch.full( (1, past_seen_tokens + inputs_embeds.shape[1]), fill_value=0, dtype=dtype, device=device, ) # Calculate the start of invisible text tokens invisible_text_tokens_start = ( min( math.ceil( (past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size ) * streaming_text_chunk_size, streaming_reserved_length, ) + 1 + num_spk_emb * use_spk_emb ) # Add 1 for [Stts] and N for [spk_emb] tokens if `use_spk_emb` is True invisible_text_tokens_end = ( streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1 ) # Add 1 for [Ptts] (aka `audio_bos_token_id`) # Set invisible text tokens to min_dtype (effectively -inf) causal_mask[0, invisible_text_tokens_start:invisible_text_tokens_end] = min_dtype # Mask padding positions in the text mask causal_mask[ 0, 0 : 1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1 ].masked_fill_(streaming_tts_text_mask == 0, min_dtype) # Add extra dimensions for batch and heads causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) return causal_mask # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py` class ConvNeXtBlock(nn.Module): def __init__( self, dim: int, intermediate_dim: int, kernel: int, dilation: int, layer_scale_init_value: float = 1e-6, ): # ConvNeXt Block copied from Vocos. super().__init__() self.dwconv = nn.Conv1d( dim, dim, kernel_size=kernel, padding=dilation * (kernel // 2), dilation=dilation, groups=dim, ) self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, intermediate_dim) self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) self.coef = ( nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None ) def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: residual = x y = self.dwconv(x) y.transpose_(1, 2) # (B, C, T) -> (B, T, C) x = self.norm(y) del y y = self.pwconv1(x) del x x = self.act(y) del y y = self.pwconv2(x) del x if self.coef is not None: y *= self.coef y.transpose_(1, 2) # (B, T, C) -> (B, C, T) x = y + residual del y return x # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py` class DVAEDecoder(nn.Module): def __init__( self, idim: int, odim: int, n_layer=12, bn_dim=64, hidden=256, kernel=7, dilation=2, up=False, ): super().__init__() self.up = up self.conv_in = nn.Sequential( nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(), nn.Conv1d(bn_dim, hidden, 3, 1, 1), ) self.decoder_block = nn.ModuleList( [ ConvNeXtBlock( hidden, hidden * 4, kernel, dilation, ) for _ in range(n_layer) ] ) self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: # B, C, T y = self.conv_in(x) del x for f in self.decoder_block: y = f(y, conditioning) x = self.conv_out(y) del y return x # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py` class GFSQ(nn.Module): def __init__( self, dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose=True, ): super(GFSQ, self).__init__() self.quantizer = GroupedResidualFSQ( dim=dim, levels=list(levels), num_quantizers=R, groups=G, ) self.n_ind = math.prod(levels) self.eps = eps self.transpose = transpose self.G = G self.R = R def _embed(self, x: torch.Tensor): if self.transpose: x = x.transpose(1, 2) x = 
x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3) feat = self.quantizer.get_output_from_indices(x) return feat.transpose_(1, 2) if self.transpose else feat def __call__(self, x: torch.Tensor) -> torch.Tensor: return super().__call__(x) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.transpose: x.transpose_(1, 2) _, ind = self.quantizer(x) ind = ind.permute(1, 2, 0, 3).contiguous() ind = ind.view(ind.size(0), ind.size(1), -1) return ind.transpose_(1, 2) if self.transpose else ind # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py` class DVAE(nn.Module): def __init__( self, ): super().__init__() coef = torch.rand(100) self.coef = nn.Parameter(coef.unsqueeze(0).unsqueeze_(2)) self.downsample_conv = nn.Sequential( nn.Conv1d(100, 512, 3, 1, 1), nn.GELU(), nn.Conv1d(512, 512, 4, 2, 1), nn.GELU(), ) self.encoder = DVAEDecoder( idim=512, odim=1024, hidden=256, n_layer=12, bn_dim=128, ) self.decoder = DVAEDecoder( idim=512, odim=512, hidden=256, n_layer=12, bn_dim=128, ) self.out_conv = nn.Conv1d(512, 100, 3, 1, 1, bias=False) self.vq_layer = GFSQ( dim=1024, levels=(5, 5, 5, 5), G=2, R=2, ) @torch.inference_mode() def forward( self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode" ) -> torch.Tensor: if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None: mel = inp.clone() x: torch.Tensor = self.downsample_conv( torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel), ).unsqueeze_(0) del mel x = self.encoder(x) ind = self.vq_layer(x) del x return ind if self.vq_layer is not None: vq_feats = self.vq_layer._embed(inp) else: vq_feats = inp vq_feats = ( vq_feats.view( (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)), ) .permute(0, 2, 3, 1) .flatten(2) ) dec_out = self.out_conv( self.decoder( x=vq_feats, ), ) del vq_feats return torch.mul(dec_out, self.coef, out=dec_out) # Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py` class CustomRepetitionPenaltyLogitsProcessorRepeat: def __init__(self, penalty: float, max_input_ids: int, past_window: int): if not isinstance(penalty, float) or not (penalty > 0): raise ValueError( f"`penalty` has to be a strictly positive float, but is {penalty}" ) self.penalty = penalty self.max_input_ids = max_input_ids self.past_window = past_window def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor ) -> torch.FloatTensor: if input_ids.size(1) > self.past_window: input_ids = input_ids.narrow(1, -self.past_window, self.past_window) freq = F.one_hot(input_ids, scores.size(1)).sum(1) if freq.size(0) > self.max_input_ids: freq.narrow( 0, self.max_input_ids, freq.size(0) - self.max_input_ids ).zero_() alpha = torch.pow(self.penalty, freq) scores = scores.contiguous() inp = scores.multiply(alpha) oth = scores.divide(alpha) con = scores < 0 out = torch.where(con, inp, oth) del inp, oth, scores, con, alpha return out class ConditionalChatTTS(PreTrainedModel): """A conditional text-to-speech model that can generate speech from text with speaker conditioning. This model extends PreTrainedModel to provide text-to-speech capabilities with: - LLM hidden state conditioning - Streaming generation The model uses a transformer architecture with LLM hidden states and can operate in both streaming and non-streaming modes for flexible deployment. 
    The model processes sequences in the following format:
    | text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed) | audio eos token |

    The format is designed to support LLM-conditioned streaming audio generation.

    Usage:
    To support streaming generation, two global variables should be maintained outside of the model.
        1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq].
        2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples; each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads]

    where `num_vq` is the number of audio codebooks (`4` by default).

    1. Create an empty `past_key_values` and an empty `audio_input_ids`.

    ```python
    initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token
    dtype = model.emb_text.weight.dtype
    device = model.emb_text.weight.device
    past_key_values = [
        (
            torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device),
            torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device)
        )
        for _ in range(model.config.num_hidden_layers)
    ]
    ```

    `audio_input_ids` has shape [1, sequence length, num_vq], where `num_vq` denotes the number of audio codebook layers. The text positions are also included in this sequence, but they stay zero and are never read; they are only placeholders.

    ```python
    initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1
    # [bos token, speaker embeddings, text tokens, audio bos token]
    audio_input_ids = torch.zeros(1, initial_audio_input_ids_length, model.num_vq, dtype=torch.long)
    ```

    2. Prefill a chunk of text tokens to the TTS model (for example, 10 tokens) using the `prefill_text` method, and update `streaming_tts_text_mask` to mark which positions contain valid text tokens (similar to `attention_mask` in standard causal attention).

    ```python
    outputs = llm.generate(**kwargs)
    llm_tokens = some_function_to_extract_llm_tokens(outputs)
    lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs)
    tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens))
    # here we assume we are prefilling text token 0 to text token 9 (included), 10 tokens in total.
    begin = 0
    end = 9 + 1
    position_ids = torch.arange(begin, end, dtype=torch.long, device=device).unsqueeze(0)

    past_key_values = model.prefill_text(
        input_ids=tts_text_input_ids,
        position_ids=position_ids,
        past_key_values=past_key_values,
        lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
    )

    streaming_tts_text_mask = torch.zeros(model.streaming_text_reserved_len)
    streaming_tts_text_mask[0:end] = 1  # mark these positions as valid text tokens
    ```

    3. Generate audio codes using the `generate` method (the sampling `temperature` and the audio `eos_token` id must also be supplied; see `generate`).

    ```python
    outputs = model.generate(
        input_ids=audio_input_ids,
        past_key_values=past_key_values,
        streaming_tts_text_mask=streaming_tts_text_mask,
        max_new_token=50,
    )

    # update past_key_values and audio_input_ids
    past_key_values = outputs.past_key_values
    audio_input_ids = outputs.audio_input_ids
    ```

    After calling `generate`, both `past_key_values` and `audio_input_ids` are extended by up to `max_new_token=50` positions.

    4. 
Notice that after prefilling `10` text tokens, the model can generate up to `50` audio tokens, if you want to generate more audio tokens, you need to prefill next `10` text tokens. And it is okay to only generate `25` audio tokens for faster initial response. 5. Repeat steps `2,3,4` as needed in your streaming audio generation cases, but ensure usage complies with the following guidelines discussed above. """ config_class = PretrainedConfig _no_split_modules = [] def __init__(self, config: PretrainedConfig): super().__init__(config) self.use_speaker_embedding = config.use_speaker_embedding self.use_llm_hidden_state = config.use_llm_hidden_state self.num_spk_embs = config.num_spk_embs self.spk_emb_token_id = config.spk_emb_token_id self.use_text = config.use_text self.streaming = config.streaming self.streaming_text_chunk_size = config.streaming_text_chunk_size self.streaming_audio_chunk_size = config.streaming_audio_chunk_size self.streaming_text_reserved_len = config.streaming_text_reserved_len self.audio_bos_token_id = config.audio_bos_token_id self.num_mel_bins = config.num_mel_bins self.num_vq = config.num_vq self.num_audio_tokens = config.num_audio_tokens self.top_p = config.top_p self.top_k = config.top_k self.repetition_penalty = config.repetition_penalty if self.config.use_mlp: self.projector = MultiModalProjector(config.llm_dim, config.hidden_size) else: self.projector = nn.Linear(config.llm_dim, config.hidden_size, bias=False) self.emb_code = nn.ModuleList( [ nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq) ] ) self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size) self.head_code = nn.ModuleList( [ weight_norm( nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False), name="weight", ) for _ in range(config.num_vq) ] ) dvae = DVAE() self.dvae = dvae model_config = LlamaConfig( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, num_attention_heads=config.num_attention_heads, num_hidden_layers=config.num_hidden_layers, max_position_embeddings=config.max_position_embeddings, attn_implementation=config.attn_implementation, ) model = LlamaModel(model_config) self.model = model @torch.inference_mode() def merge_inputs_embeds( self, input_ids: torch.Tensor, lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None, ): """Merge `input_ids` and `lm_spk_emb_last_hidden_states` to `inputs_embeds`. Args: input_ids (torch.Tensor): Input token IDs. lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states of speaker embeddings from the language model. Defaults to None. Raises: NotImplementedError: If speaker embedding is not used and language model hidden states are not implemented. Returns: torch.Tensor: Prepared input embeddings for the model. 
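        Example:
            A minimal sketch; the variable names are illustrative, and `tts_text_input_ids`
            must contain `num_spk_embs` consecutive `spk_emb_token_id` placeholders when
            speaker embedding is enabled:

            ```python
            inputs_embeds = model.merge_inputs_embeds(
                input_ids=tts_text_input_ids,                                 # [1, seq_len]
                lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,  # [1, num_spk_embs, llm_dim]
            )
            # inputs_embeds: [1, seq_len, hidden_size], with the placeholder positions
            # replaced by the projected, L2-normalized speaker embedding.
            ```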
""" assert input_ids.shape[0] == 1 # Embed input_ids to input_embeds inputs_embeds = self.emb_text(input_ids) # Inject speaker embedding to input_embeds if it exists if self.use_speaker_embedding: spk_emb_mask = input_ids == self.spk_emb_token_id if spk_emb_mask.any(): assert lm_spk_emb_last_hidden_states is not None # Project spk emb to tts hidden size first, [batch_size, num_spk_emb, llm_dim] -> [batch_size, num_spk_emb, self.hidden_size] lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to( self.projector.linear1.weight.dtype ) projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states) projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1) apply_spk_emb( input_ids=input_ids, spk_emb=projected_spk_emb, input_embeds=inputs_embeds, spk_emb_token_id=self.spk_emb_token_id, num_spk_embs=self.num_spk_embs, ) else: raise NotImplementedError return inputs_embeds @torch.inference_mode() def prefill_text( self, input_ids: torch.Tensor, position_ids: torch.LongTensor, past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None, ): """Prefill a chunk of new text tokens in streaming setting. Specifically speaking, update `past_key_values` using new text tokens, then the model will read the new text tokens. Args: input_ids (Tensor): Tensor of shape [batch_size, seq_len] position_ids (LongTensor): Tensor of shape [batch_size, seq_len] past_key_values (List[Tuple[Tensor]]): KV Cache of all layers, each layer is a tuple (Tensor, Tensor) denoting keys and values. Each tensor is of seq_len = `self.streaming_text_reserved_len`. `past_key_values` will be updated. lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape [batch_size, num_spk_emb, llm_dim]. Defaults to None. Note that all `batch_size` should be `1`. 
""" assert input_ids.shape[0] == 1 assert past_key_values is not None # Merge text and LLM embeddings inputs_embeds = self.merge_inputs_embeds( input_ids=input_ids, lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states, ) # Clone KV Cache past_key_values_for_prefill = [] for i in range(len(past_key_values)): past_key_values_for_prefill.append( ( past_key_values[i][0][:, :, : position_ids[:, 0], :].clone(), past_key_values[i][1][:, :, : position_ids[:, 0], :].clone(), ) ) # ModelMiniCPMVBaseModel outputs_prefill: BaseModelOutputWithPast = self.model( attention_mask=None, # because for text, it is standard causal attention mask, do nothing position_ids=position_ids, # position_ids denotes the position of new text tokens in the sequence past_key_values=past_key_values_for_prefill, # `past_key_values` will be updated by the model inputs_embeds=inputs_embeds, # contains text and language model embedding use_cache=True, output_attentions=False, cache_position=position_ids, # which new positions will use this cache, basically the same as position_ids ) # Get model updated KV Cache past_key_values_for_prefill_updated = outputs_prefill.past_key_values # Update generated KV Cache to input `past_key_values` for layer_idx in range(len(past_key_values)): # Update keys past_key_values[layer_idx][0][ :, :, position_ids[:, 0] : position_ids[:, -1] + 1, : ] = past_key_values_for_prefill_updated[layer_idx][0][ :, :, position_ids[:, 0] : position_ids[:, -1] + 1 ].clone() # Update values past_key_values[layer_idx][1][ :, :, position_ids[:, 0] : position_ids[:, -1] + 1, : ] = past_key_values_for_prefill_updated[layer_idx][1][ :, :, position_ids[:, 0] : position_ids[:, -1] + 1 ].clone() # TODO: del past_key_values_for_prefill_updated recursively # TODO: del outputs_prefill recursively return past_key_values @torch.inference_mode() def prefill_audio_ids( self, input_ids: torch.Tensor, past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], streaming_tts_text_mask=None, add_audio_bos: bool = True, ): """Prefill a chunk of audio ids to the model. Used in sliding-window long audio generation. Specifically, prefill many audio ids (typically from last window) to the model in the new window. Args: input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids. past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism. 
""" assert input_ids.shape[0] == 1 assert past_key_values is not None code_emb = [self.emb_code[i](input_ids[:, :, i]) for i in range(self.num_vq)] inputs_embeds = torch.stack(code_emb, 3).sum(3) # [1,seq_len,768] input_len = input_ids.shape[1] if add_audio_bos: narrowed_input_ids = torch.tensor( [[self.audio_bos_token_id]], dtype=torch.long, device=self.device ) bos_inputs_embeds = self.emb_text(narrowed_input_ids) inputs_embeds = torch.cat([bos_inputs_embeds, inputs_embeds], dim=1) input_len += 1 past_key_values_length = past_key_values[0][0].shape[2] position_ids = torch.arange( past_key_values_length, past_key_values_length + input_len, dtype=torch.long, device=self.device, ).unsqueeze(0) cache_position = position_ids.clone() causal_mask = make_streaming_chunk_mask_generation( inputs_embeds=inputs_embeds, past_seen_tokens=past_key_values[0][0].shape[2], streaming_tts_text_mask=streaming_tts_text_mask, streaming_reserved_length=self.streaming_text_reserved_len, streaming_text_chunk_size=self.streaming_text_chunk_size, ) # [1, 1, 1, past_key_values_length + input_len] # Model forward outputs: BaseModelOutputWithPast = self.model( attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=True, output_attentions=False, cache_position=cache_position, ) past_key_values = outputs.past_key_values return past_key_values @torch.inference_mode() def generate( self, input_ids: torch.Tensor, past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], temperature: torch.Tensor, eos_token: Union[int, torch.Tensor], streaming_tts_text_mask=None, force_no_stop=False, min_new_token=10, max_new_token=50, logits_warpers: List[LogitsWarper] = [], logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [], show_tqdm=False, ): """Generate audio codes in streaming setting or non-streaming setting. Specifically speaking, generate audio codes when not all text tokens are prefilled. Always pass a valid `past_key_values` to the method. The method does not do `prefill` by itself. It relies on `prefill_text` method to provide valid `past_key_values`. Please refer to docstring of this class for more details. In this method, we borrowed a lot of codes from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`. Args: input_ids (torch.Tensor): Input token ids. past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism. temperature (torch.Tensor): Temperature for sampling. eos_token (Union[int, torch.Tensor]): End of sequence token. streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None. max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50. logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to []. logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to []. show_tqdm (bool, optional): Whether to show progress bar. Defaults to True. Returns: GenerationOutputs: Generation outputs. 
""" # We only support batch size `1` for now assert input_ids.shape[0] == 1 assert past_key_values is not None # fix: this should not be `input_ids.shape[1]` # start_idx = input_ids.shape[1] start_idx = ( 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1 ) finish = torch.zeros(input_ids.shape[0], device=input_ids.device).bool() temperature = ( temperature.unsqueeze(0) .expand(input_ids.shape[0], -1) .contiguous() .view(-1, 1) ) progress = input_ids.shape[1] # Pre-allocate input_ids, shape is [batch_size=1, max_possible_seq_len, self.num_vqs] input_ids_buf = torch.zeros( input_ids.shape[0], # batch_size progress + max_new_token, # max_possible_seq_len = input_ids.shape[1] + max_new_token input_ids.shape[2], # self.num_vqs dtype=input_ids.dtype, device=input_ids.device, ) # Copy existing `input_ids` to `input_ids_buf` input_ids_buf.narrow(1, 0, progress).copy_(input_ids) del input_ids input_ids = input_ids_buf.narrow(1, 0, progress) pbar: Optional[tqdm] = None if show_tqdm: pbar = tqdm( total=max_new_token, desc="code", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]", ) condition_length = ( 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1 ) for i in range(max_new_token): # Prepare generation inputs audio_bos = False # If this is the first audio token, the case is SPECIAL if progress == condition_length: audio_bos = True assert progress == ( past_key_values[0][0].shape[2] + 1 ) # If you are using according to the guidelines, this should be passed. if audio_bos: # Generate the first token, activate the model with `self.audio_bos_token_id`, the model will predict # a new audio token. This is a special case because without the `audio bos token`, it is impossible # to generate the first audio token in our streaming setting. narrowed_input_ids = torch.tensor( [[self.audio_bos_token_id]], dtype=torch.long, device=self.device ) inputs_embeds = self.emb_text(narrowed_input_ids) del narrowed_input_ids else: # Generate the following audio tokens, it is applicable to all other cases, including second and the # following calling of `generate`. 
narrowed_input_ids = input_ids.narrow( dim=1, start=input_ids.shape[1] - 1, length=1 ) code_emb = [ self.emb_code[i](narrowed_input_ids[:, :, i]) for i in range(self.num_vq) ] inputs_embeds = torch.stack(code_emb, 3).sum(3) position_ids = torch.tensor( [past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device ).unsqueeze(0) cache_position = position_ids.clone() # Make causal mask causal_mask = make_streaming_chunk_mask_generation( inputs_embeds=inputs_embeds, past_seen_tokens=past_key_values[0][0].shape[2], streaming_tts_text_mask=streaming_tts_text_mask, streaming_reserved_length=self.streaming_text_reserved_len, streaming_text_chunk_size=self.streaming_text_chunk_size, ) # Model forward outputs: BaseModelOutputWithPast = self.model( attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=True, output_attentions=False, cache_position=cache_position, ) del position_ids del inputs_embeds del cache_position del causal_mask hidden_states = outputs.last_hidden_state past_key_values = outputs.past_key_values with P.cached(): logits = torch.empty( hidden_states.size(0), hidden_states.size(1), self.num_audio_tokens, self.num_vq, dtype=torch.float, device=self.device, ) for num_vq_iter in range(self.num_vq): x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) logits[..., num_vq_iter] = x del x del hidden_states # logits = logits[:, -1].float() logits = logits.narrow(1, -1, 1).squeeze_(1).float() # logits = rearrange(logits, "b c n -> (b n) c") logits = logits.permute(0, 2, 1) logits = logits.reshape(-1, logits.size(2)) # logits_token = rearrange(input_ids[:, start_idx:], "b c n -> (b n) c") input_ids_sliced = input_ids.narrow( 1, start_idx, input_ids.size(1) - start_idx, ).permute(0, 2, 1) logits_token = input_ids_sliced.reshape( input_ids_sliced.size(0) * input_ids_sliced.size(1), -1, ).to(self.device) del input_ids_sliced logits /= temperature if not audio_bos: for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) if not audio_bos: for logitsWarpers in logits_warpers: logits = logitsWarpers(logits_token, logits) del logits_token if i < min_new_token: logits[:, eos_token] = -torch.inf if force_no_stop: logits[:, eos_token] = -torch.inf scores = F.softmax(logits, dim=-1) del logits idx_next = torch.multinomial(scores, num_samples=1) # .to(finish.device) del scores # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) idx_next = idx_next.view(-1, self.num_vq) finish_or = idx_next.eq(eos_token).any(1) finish.logical_or_(finish_or) del finish_or # Store new `token` into `input_ids_buf` input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1)) if i == 0 and finish.any(): # raise Exception break del idx_next progress += 1 input_ids = input_ids_buf.narrow(1, 0, progress) if finish.all(): break if pbar is not None: pbar.update(1) if pbar is not None: pbar.close() if not finish.all(): if show_tqdm: logger.info(f"incomplete result. 
hit max_new_token: {max_new_token}") del input_ids_buf if finish.all(): # the last may contains eos token genrated_input_ids = input_ids[:, condition_length:-1, :] else: # there is no eos token genrated_input_ids = input_ids[:, condition_length:, :] return ConditionalChatTTSGenerationOutput( new_ids=genrated_input_ids, audio_input_ids=input_ids, # for update purpose past_key_values=past_key_values, # for update purpose finished=finish.all(), ) @torch.inference_mode() def decode_to_mel_specs( self, result_list: List[torch.Tensor], ): """Decode discrete audio codes to mel spectrograms. Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py` Args: result_list (List[torch.Tensor]): Audio codes output from `generate`. Returns: torch.Tensor: Mel spectrograms. """ decoder = self.dvae max_x_len = -1 if len(result_list) == 0: return np.array([], dtype=np.float32) for result in result_list: if result.size(0) > max_x_len: max_x_len = result.size(0) batch_result = torch.zeros( (len(result_list), result_list[0].size(1), max_x_len), dtype=result_list[0].dtype, device=result_list[0].device, ) for i in range(len(result_list)): src = result_list[i] batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0)) del src mel_specs = decoder(batch_result) del batch_result return mel_specs # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer and add use_cache for streaming inference class MiniCPMWhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.d_model self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, layer_head_mask: torch.Tensor, output_attentions: bool = False, past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = False, ) -> torch.Tensor: r""" Args: hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`): Hidden states to be fed into the encoder layer. attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`): Attention mask where padding elements are indicated by large negative values. layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`): Mask to nullify selected heads of the attention modules. output_attentions (`bool`, *optional*): Whether or not to return the attention weights. past_key_values (`EncoderDecoderCache`, *optional*): Past key-value pairs used for incremental decoding. use_cache (`bool`, *optional*): Whether or not to return updated `past_key_values` for caching. Returns: A tuple of shape `(hidden_states, optional(attn_weights), optional(past_key_values))`. 
""" residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, past_key_values = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, past_key_value=past_key_values, ) hidden_states = nn.functional.dropout( hidden_states, p=self.dropout, training=False ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout( hidden_states, p=self.activation_dropout, training=False ) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout( hidden_states, p=self.dropout, training=False ) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp( hidden_states, min=-clamp_value, max=clamp_value ) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) if use_cache: outputs += (past_key_values,) return outputs # Copied from from transformers.models.whisper.modeling_whisper.WhisperEncoder and add use_cache for streaming inference class MiniCPMWhisperEncoder(WhisperEncoder): def __init__(self, config: WhisperConfig): super().__init__(config) self.layers = nn.ModuleList( [ MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers) ] ) def forward( self, input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = None, ): r""" Forward pass of the Whisper encoder. Args: input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): Float values of log-mel features extracted from the raw audio waveform. Typically generated by a feature extractor (e.g., `WhisperFeatureExtractor`) that processes `.flac` or `.wav` files into padded 2D mel spectrogram frames. These features are projected via convolution layers (`conv1` and `conv2`) and then transformed into embeddings for the encoder. attention_mask (`torch.Tensor`, *optional*): Not used by Whisper for masking `input_features`, but included for API compatibility with other models. If provided, it is simply ignored within the model. By default, Whisper effectively ignores silence in the input log-mel spectrogram. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected attention heads. The elements should be either 1 or 0, where: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked** (i.e., the attention head is dropped). output_attentions (`bool`, *optional*): Whether or not to return the attention tensors of all encoder layers. If set to `True`, the returned tuple (or `BaseModelOutputWithPast`) will contain an additional element with attention weights for each encoder layer. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. If set to `True`, the returned tuple (or `BaseModelOutputWithPast`) will contain a tuple of hidden states, including the initial embedding output as well as the outputs of each layer. 
return_dict (`bool`, *optional*): Whether or not to return a `BaseModelOutputWithPast` (a subclass of `ModelOutput`) instead of a plain tuple. If set to `True`, the output will be a `BaseModelOutputWithPast` object, otherwise it will be a tuple. past_key_values (`EncoderDecoderCache`, *optional*): When using caching for faster inference, this is an object that stores the key-value pairs for attention states. If provided, the model will append new states to the existing cache and return the updated cache. This speeds up sequential decoding or chunked inference. - If `past_key_values` is `None`, no past states are used or returned. - If `past_key_values` is not `None` and `use_cache=True`, the model will use the provided cache and return the updated cache (as `next_encoder_cache`). use_cache (`bool`, *optional*): Whether or not the model should use caching (`past_key_values`) to speed up processing during inference. When set to `True`, the model will: - Inspect and use `past_key_values` if provided. - Return updated `past_key_values` (under the name `next_encoder_cache` in `BaseModelOutputWithPast`). Returns: `BaseModelOutputWithPast` or `tuple` (depending on `return_dict`): If `return_dict=True`, a `BaseModelOutputWithPast` is returned, which contains: - **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The output of the final encoder layer. - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True`): Hidden states of the model at each layer (including the initial projection). - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True`): Attention weights from each encoder layer. - **past_key_values** (an object of type `EncoderDecoderCache` or `None`, *optional*): Updated cache of key-value pairs if `use_cache=True`. If `return_dict=False`, a tuple is returned, where the format is: `(last_hidden_state, hidden_states, attentions)`, with `hidden_states` and `attentions` only present if their respective `output_*` arguments are set to `True`. """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # Ignore copy input_features = input_features.to( dtype=self.conv1.weight.dtype, device=self.conv1.weight.device ) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) embed_pos = self.embed_positions.weight past_key_values_length = 0 if use_cache: if past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) elif isinstance(past_key_values, list): past_key_values = EncoderDecoderCache( DynamicCache.from_legacy_cache(past_key_values), DynamicCache() ) elif isinstance(past_key_values, DynamicCache): past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) else: pass past_key_values_length = ( past_key_values.self_attention_cache.get_usable_length( inputs_embeds.shape[1] ) ) if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]: logger.warning( "seems the audio is longer than 30s. 
repeating the last part of the audio" ) embed_pos_front = embed_pos[past_key_values_length:, :] embed_pos = torch.cat( ( embed_pos_front, torch.repeat_interleave( embed_pos[-1, :].unsqueeze(0), inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length, dim=0, ), ) ) else: embed_pos = embed_pos[ past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :, ] else: embed_pos = embed_pos[: inputs_embeds.shape[1], :] hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout( hidden_states, p=self.dropout, training=False ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( len(self.layers) ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) to_drop = False # Ignore copy if to_drop: layer_outputs = (None, None) else: layer_outputs = encoder_layer( hidden_states, attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, past_key_values=past_key_values, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: next_encoder_cache = layer_outputs[2 if output_attentions else 1] else: next_encoder_cache = None if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) hidden_states = self.layer_norm(hidden_states) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, encoder_states, all_attentions] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, past_key_values=next_encoder_cache, ) class MultiModalProjector(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) self.relu = nn.ReLU() self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) def forward(self, audio_features): hidden_states = self.relu(self.linear1(audio_features)) hidden_states = self.linear2(hidden_states) return hidden_states class MiniCPMO(MiniCPMVBaseModel): def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__(config=config, quant_config=quant_config) self.llm = self.init_llm(config=config, quant_config=quant_config) self.embed_dim = self.llm.config.hidden_size # init vision module if self.config.init_vision: # print("vision-understanding enabled") self.vpm = self.init_vision_module(config=config, quant_config=quant_config) self.vision_dim = self.vpm.embed_dim self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) # init audio module self.config.init_audio = True if self.config.init_audio: # print("audio-understanding enabled") self.apm = self.init_audio_module() audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4) self.audio_avg_pooler = nn.AvgPool1d( self.config.audio_pool_step, stride=self.config.audio_pool_step ) self.audio_projection_layer = MultiModalProjector( in_dim=audio_output_dim, out_dim=self.embed_dim ) self.audio_encoder_layer = -1 # init tts module self.config.init_tts = 
False logger.info("TTS is disabled for now") if self.config.init_tts: # print("tts enabled") assert ( _tts_deps ), "please make sure vector_quantize_pytorch and vocos are installed." self.tts = self.init_tts_module() def init_tts_module(self): model = ConditionalChatTTS(self.config.tts_config) return model def init_audio_module(self): model = MiniCPMWhisperEncoder(self.config.audio_config) return model def init_llm( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> nn.Module: return Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix) def init_vision_module( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], prefix: str = "", ): if self.config._attn_implementation == "flash_attention_2": self.config.vision_config._attn_implementation = "flash_attention_2" else: self.config.vision_config._attn_implementation = "eager" model = Idefics2VisionTransformer( config=config.vision_config, quant_config=quant_config, prefix=prefix ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] setattr(model, "embed_dim", model.embeddings.embed_dim) setattr(model, "patch_size", model.embeddings.patch_size) return model def init_resampler( self, embed_dim: int, vision_dim: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> nn.Module: with set_default_torch_dtype(torch.float16): # The resampler in 2.6 remains consistent with the one in 2.5. resampler = Resampler2_5( num_queries=self.config.query_num, embed_dim=embed_dim, num_heads=embed_dim // 128, kv_dim=vision_dim, quant_config=quant_config, prefix=prefix, ) return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def pad_input_ids(self, input_ids: List[int], mm_input: MultimodalInputs): # Get all special token IDs im_start_id: int = mm_input.im_start_id im_end_id: int = mm_input.im_end_id slice_start_id: int = mm_input.slice_start_id slice_end_id: int = mm_input.slice_end_id media_token_pairs = [ (im_start_id, im_end_id), (slice_start_id, slice_end_id), (mm_input.audio_start_id, mm_input.audio_end_id), ] pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) return pattern.pad_input_tokens(input_ids, mm_input) def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers and the output length of the audio encoder """ input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 input_lengths_after_pooling = ( input_lengths_after_cnn - self.config.audio_pool_step ) // self.config.audio_pool_step + 1 input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) return input_lengths_after_cnn, input_lengths_after_pooling def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs): r""" Extract audio embeddings in a streaming manner using cached key-value pairs. This method processes incoming audio features incrementally and stores/updates `past_key_values` for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended for streaming scenarios. Args: multimodal_input (dict): - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch. 
Returns: List[List[torch.Tensor]]: audio embeddings """ # print("audio embedding") wavforms = ( [] if multimodal_input.audio_features is None else multimodal_input.audio_features ) # list, [[x1, x2], [y1], [z1]] audio_feature_lens_raw = ( [] if multimodal_input.audio_feature_lens is None else multimodal_input.audio_feature_lens ) # exist audio if len(wavforms) > 0: audio_feature_lens = torch.hstack(audio_feature_lens_raw) batch_size, _, max_mel_seq_len = wavforms.shape assert batch_size == 1 max_seq_len = (max_mel_seq_len - 1) // 2 + 1 if self.audio_past_key_values is not None: cache_length = self.audio_past_key_values[0][0].shape[2] apm_max_len = self.apm.embed_positions.weight.shape[0] if cache_length + max_seq_len >= apm_max_len: logger.warning( f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset." ) self.audio_past_key_values = None audio_outputs = self.apm( wavforms, past_key_values=self.audio_past_key_values, use_cache=True ) audio_states = ( audio_outputs.last_hidden_state ) # [:, :audio_feat_lengths, :] self.audio_past_key_values = audio_outputs.past_key_values audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) _, feature_lens_after_pooling = self._get_feat_extract_output_lengths( audio_feature_lens ) num_audio_tokens = feature_lens_after_pooling final_audio_embeds = [] idx = 0 for i in range(len(audio_feature_lens_raw)): target_audio_embeds = [] for _ in range(len(audio_feature_lens_raw[i])): target_audio_embeds.append( audio_embeds[idx, : num_audio_tokens[idx], :] ) idx += 1 final_audio_embeds.append(target_audio_embeds) return final_audio_embeds else: return [] def subsequent_chunk_mask( self, size: int, chunk_size: int, num_left_chunks: int = -1, device: torch.device = torch.device("cpu"), num_lookhead: int = 0, ) -> torch.Tensor: """Create mask for subsequent steps (size, size) with chunk size, this is for streaming encoder Args: size (int): size of mask chunk_size (int): size of chunk num_left_chunks (int): number of left chunks <0: use full chunk >=0: use num_left_chunks device (torch.device): "cpu" or "cuda" or torch.Tensor.device Returns: torch.Tensor: mask """ ret = torch.zeros(size, size, device=device, dtype=torch.bool) for i in range(size): if num_left_chunks < 0: start = 0 else: start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size) ret[i, start:ending] = True return ret def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1): r""" Extract full audio embeddings with optional chunk-based attention. This method computes embeddings for all audio frames at once, either using full attention (when `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does not use key-value caching and is suitable for non-streaming inference. Args: multimodal_input (dict): - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch. chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based attention (>0) during embedding computation. 
Returns: List[List[torch.Tensor]]: audio embeddings """ # print("audio embedding") # (bs, 80, frames) or [], multi audios need filled in advance wavforms = ( [] if multimodal_input.audio_features is None else multimodal_input.audio_features ) # list, [[x1, x2], [y1], [z1]] audio_feature_lens_raw = ( [] if multimodal_input.audio_feature_lens is None else multimodal_input.audio_feature_lens ) final_audio_embeds = [] # exist audio for wavform in wavforms: if len(wavform) > 0: audio_feature_lens = torch.hstack(audio_feature_lens_raw) batch_size, _, max_mel_seq_len = wavform.shape max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) seq_range = ( torch.arange( 0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device, ) .unsqueeze(0) .expand(batch_size, max_seq_len) ) lengths_expand = audio_feature_lens.unsqueeze(1).expand( batch_size, max_seq_len ) # Create mask padding_mask = seq_range >= lengths_expand # 1 for padded values audio_attention_mask_ = padding_mask.view( batch_size, 1, 1, max_seq_len ).expand(batch_size, 1, max_seq_len, max_seq_len) audio_attention_mask = audio_attention_mask_.to( dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device, ) if chunk_length > 0: chunk_num_frame = int(chunk_length * 50) chunk_mask = self.subsequent_chunk_mask( size=max_seq_len, chunk_size=chunk_num_frame, num_left_chunks=-1, device=audio_attention_mask_.device, ) audio_attention_mask_ = torch.logical_or( audio_attention_mask_, torch.logical_not(chunk_mask) ) audio_attention_mask[audio_attention_mask_] = float("-inf") audio_states = self.apm( wavform, output_hidden_states=True, attention_mask=audio_attention_mask, ).hidden_states[self.audio_encoder_layer] audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) _, feature_lens_after_pooling = self._get_feat_extract_output_lengths( audio_feature_lens ) num_audio_tokens = feature_lens_after_pooling idx = 0 for i in range(len(audio_feature_lens_raw)): target_audio_embeds = [] for _ in range(len(audio_feature_lens_raw[i])): target_audio_embeds.append( audio_embeds[idx, : num_audio_tokens[idx], :] ) idx += 1 final_audio_embeds.append(target_audio_embeds) return final_audio_embeds def get_omni_embedding( self, input_ids, multimodal_input: MultimodalInputs, input_embeds: torch.Tensor, forward_mode: ForwardMode, chunk_length=-1, stream_input=False, ): """ Args: multimodal_input: input_embeds: chunk_length: whisper use full attention or chunk attention stream_input: use streaming audio embedding Returns: final embeddings with audio feature """ input_embeds = input_embeds.unsqueeze(0) if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs(): audio_bounds = get_multimodal_data_bounds( input_ids=input_ids, pad_values=multimodal_input.pad_values, token_pairs=[ (multimodal_input.audio_start_id, multimodal_input.audio_end_id) ], ) if audio_bounds.numel() == 0: input_embeds = input_embeds.squeeze(0) # TODO logger.warn("Unimplemented logic. 
Please try disabling chunked prefill") return input_embeds audio_bounds = audio_bounds.unsqueeze(0) bs = len(input_embeds) if stream_input: audio_embeddings = self.get_audio_embedding_streaming(multimodal_input) else: audio_embeddings = self.get_audio_embedding( multimodal_input, chunk_length ) # batch size assert len(audio_embeddings) == len(input_embeds) if len(audio_embeddings) > 0: if self.config.chunk_input: for i in range(bs): audio_embs = torch.cat(audio_embeddings[i], dim=0).to( device=input_embeds.device, dtype=input_embeds.dtype ) audio_start_pos = 0 for bound in audio_bounds[i]: audio_len = bound[1] - bound[0] + 1 input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[ audio_start_pos : audio_start_pos + audio_len, : ] audio_start_pos += audio_len else: for i in range(bs): audio_embs = audio_embeddings[i] bounds = audio_bounds[i] for embs, bound in zip(audio_embs, bounds): audio_indices = torch.arange( bound[0], bound[1], dtype=torch.long ).to(input_embeds.device) if embs.shape[0] != len(audio_indices): raise ValueError( f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " f"to input indices of length {len(audio_indices)}" ) input_embeds[i, audio_indices] = embs.to(input_embeds.dtype) input_embeds = input_embeds.squeeze(0) return input_embeds def get_image_features( self, image_inputs: MultimodalInputs, ) -> torch.Tensor: pixel_values = image_inputs.pixel_values tgt_sizes = image_inputs.tgt_sizes device = self.vpm.embeddings.position_embedding.weight.device dtype = self.vpm.embeddings.position_embedding.weight.dtype all_pixel_values_lst = [ i.flatten(end_dim=1).permute(1, 0) for i in pixel_values ] max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() assert isinstance(max_patches, int) all_pixel_values = torch.nn.utils.rnn.pad_sequence( all_pixel_values_lst, batch_first=True, padding_value=0.0 ) B, L, _ = all_pixel_values.shape all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) patch_attn_mask = torch.zeros( (B, 1, max_patches), dtype=torch.bool, device=device ) tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device) mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1] patch_attn_mask[:, 0, :] = torch.arange( patch_attn_mask.size(2), device=patch_attn_mask.device ).unsqueeze(0) < mask_shapes.unsqueeze(1) vision_embedding = self.vpm( all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes, ) return self.resampler(vision_embedding, tgt_sizes) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, **kwargs: Any, ) -> torch.Tensor: inputs_embeds = None # TODO(mick): optimize the logic here: clamp, merge and embedding should happens at most once if ( not forward_batch.forward_mode.is_decode() and forward_batch.contains_image_inputs() ): mm_inputs = forward_batch.merge_mm_inputs() inputs_embeds = embed_mm_inputs( mm_input=mm_inputs, input_ids=input_ids, input_embedding=self.get_input_embeddings(), mm_data_embedding_func=self.get_image_features, placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values, ) input_ids = input_ids.clamp( min=0, max=self.get_input_embeddings().num_embeddings - 1 ) if inputs_embeds is None: inputs_embeds = self.llm.get_input_embeddings(input_ids) if ( not forward_batch.forward_mode.is_decode() and self.config.init_audio and forward_batch.contains_audio_inputs() ): mm_input = forward_batch.merge_mm_inputs() inputs_embeds = self.get_omni_embedding( input_ids=input_ids, multimodal_input=mm_input, 
input_embeds=inputs_embeds, forward_mode=forward_batch.forward_mode, chunk_length=self.config.audio_chunk_length, stream_input=False, ) forward_batch.mm_inputs = None hidden_states = self.llm.model( input_ids=None, positions=positions, forward_batch=forward_batch, input_embeds=inputs_embeds, ) return self.logits_processor( input_ids, hidden_states, self.llm.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq~" in name or "projector" in name: continue if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue # adapt to parametrization if self.config.init_tts and "tts" in name: name = name.replace(".parametrizations", "") name = name.replace(".weight.original0", ".weight_g") name = name.replace(".weight.original1", ".weight_v") # adapt to VisionAttention if "vpm" in name: name = name.replace(r"self_attn.out_proj", r"self_attn.proj") if not self.config.init_tts and "tts" in name: continue if not self.config.init_audio and ("apm" in name or "audio" in name): continue if not self.config.init_vision and "vpm" in name: continue if ( "sampler" in name or "apm" in name or ("tts" in name and "self_attn" in name) or ("tts.model.layers" in name and ".mlp" in name) ): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) continue for param_name, weight_name, shard_id in stacked_params_mapping: # replace the name and load with customized loader if weight_name not in name: continue name = name.replace(weight_name, param_name) # # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) EntryClass = [MiniCPMO]