# Copied and adapted from: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/modeling_minicpmo.py
|
|
|
|
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Inference-only MiniCPM-o model compatible with HuggingFace weights."""
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn.utils.parametrize as P
|
|
import torch.types
|
|
from torch import nn
|
|
from torch.nn.utils import weight_norm
|
|
from tqdm import tqdm
|
|
from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
|
|
from transformers.activations import ACT2FN
|
|
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
|
from transformers.models.whisper.modeling_whisper import (
|
|
WHISPER_ATTENTION_CLASSES,
|
|
WhisperConfig,
|
|
WhisperEncoder,
|
|
)
|
|
|
|
from sglang.srt.layers.quantization import QuantizationConfig
|
|
from sglang.srt.managers.mm_utils import (
|
|
MultiModalityDataPaddingPatternTokenPairs,
|
|
embed_mm_inputs,
|
|
get_multimodal_data_bounds,
|
|
)
|
|
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
from sglang.srt.models.minicpmv import (
|
|
Idefics2VisionTransformer,
|
|
MiniCPMVBaseModel,
|
|
Resampler2_5,
|
|
)
|
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
|
from sglang.srt.utils import logger
|
|
|
|
try:
|
|
from transformers import LogitsWarper
|
|
from vector_quantize_pytorch import GroupedResidualFSQ
|
|
from vocos import Vocos
|
|
from vocos.pretrained import instantiate_class
|
|
|
|
_tts_deps = True
|
|
except:
|
|
LogitsWarper = None
|
|
_tts_deps = False
|
|
|
|
|
|
def apply_spk_emb(
|
|
input_ids: torch.Tensor = None,
|
|
spk_emb: torch.Tensor = None,
|
|
input_embeds: torch.Tensor = None,
|
|
spk_emb_token_id: int = 0,
|
|
num_spk_embs: int = 1,
|
|
):
|
|
"""
|
|
Replace consecutive `num_spk_embs` speaker embedding placeholders in input_embeds with pre-prepared speaker embeddings. This is an in-place replacement, no new tensor is created, so no value is returned.
|
|
|
|
Args:
|
|
input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max]
|
|
spk_emb (torch.Tensor): Speaker embedding tensor, shape [batch_size, num_spk_emb, hidden_dim]
|
|
input_embeds (torch.Tensor): Input embedding tensor, shape [batch_size, seq_len_max, hidden_dim]
|
|
spk_emb_token_id (int): ID of the speaker embedding token
|
|
num_spk_embs (int): Number of speaker embeddings
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
|
|
batch_size = input_ids.shape[0]
|
|
|
|
    for idx in range(batch_size):
        input_ids_ = input_ids[idx]  # [seq_len_max]
        spk_emb_ = spk_emb[idx]  # [num_spk_emb, hidden_dim]
        mask_ = input_ids_ == spk_emb_token_id  # [seq_len_max]
        nonzero_position_idx = mask_.nonzero(as_tuple=False)  # [num_spk_emb, 1]
        assert nonzero_position_idx.shape[0] == num_spk_embs
        begin_idx = nonzero_position_idx.min()
        end_idx = nonzero_position_idx.max()
        input_embeds[idx, begin_idx : end_idx + 1, :] = spk_emb_
|
|
|
|
return
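

# Illustrative usage sketch (not part of the upstream model code): toy tensors only,
# showing how `apply_spk_emb` patches the speaker-embedding placeholder in place.
def _example_apply_spk_emb():
    hidden_dim, spk_emb_token_id = 8, 3
    input_ids = torch.tensor([[5, 3, 7, 9, 2, 4]])  # [1, seq_len_max]
    spk_emb = torch.randn(1, 1, hidden_dim)  # [1, num_spk_emb, hidden_dim]
    input_embeds = torch.zeros(1, 6, hidden_dim)  # [1, seq_len_max, hidden_dim]
    apply_spk_emb(
        input_ids=input_ids,
        spk_emb=spk_emb,
        input_embeds=input_embeds,
        spk_emb_token_id=spk_emb_token_id,
        num_spk_embs=1,
    )
    return input_embeds  # position 1 now holds the speaker embedding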
|
|
|
|
|
|
@dataclass
|
|
class ConditionalChatTTSGenerationOutput(ModelOutput):
|
|
"""
|
|
Output class for ConditionalChatTTS generation.
|
|
|
|
Args:
|
|
new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
|
|
audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
|
|
past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
|
|
finished (bool): Boolean indicating whether generation is complete.
|
|
|
|
"""
|
|
|
|
new_ids: torch.LongTensor = None
|
|
audio_input_ids: torch.LongTensor = None
|
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
|
finished: bool = None
|
|
|
|
|
|
def make_streaming_chunk_mask_generation(
|
|
inputs_embeds: torch.Tensor,
|
|
past_seen_tokens: int,
|
|
streaming_tts_text_mask: torch.Tensor,
|
|
streaming_reserved_length: int = 300,
|
|
streaming_audio_chunk_size: int = 50,
|
|
streaming_text_chunk_size: int = 10,
|
|
num_spk_emb: int = 1,
|
|
use_spk_emb: bool = True,
|
|
) -> torch.Tensor:
|
|
"""
|
|
In streaming audio generation, determine which `text` positions the TTS model can attend to when generating each chunk of `audio` tokens.
|
|
|
|
This function creates a mask that allows the model to attend to a specific chunk of text
|
|
tokens when generating each chunk of audio tokens, enabling streaming TTS generation.
|
|
|
|
Args:
|
|
inputs_embeds (torch.Tensor): Input embeddings tensor.
|
|
past_seen_tokens (int): Number of tokens already seen by the model.
|
|
streaming_tts_text_mask (torch.Tensor): Mask for the text tokens.
|
|
        streaming_reserved_length (int, optional): Number of reserved tokens for streaming. Defaults to 300.
        streaming_audio_chunk_size (int, optional): Size of each audio chunk. Defaults to 50.
        streaming_text_chunk_size (int, optional): Size of each text chunk. Defaults to 10.
        num_spk_emb (int, optional): Number of speaker embeddings. Defaults to 1.
        use_spk_emb (bool, optional): Whether speaker embeddings are used. Defaults to True.
|
|
|
|
Returns:
|
|
torch.Tensor: Causal mask for streaming TTS generation, shape is [batch_size=1, 1, seq_len=1, past_seen_tokens+1]
|
|
|
|
Raises:
|
|
AssertionError: If the batch size is not 1 (only supports batch size of 1 for inference).
|
|
"""
|
|
assert inputs_embeds.shape[0] == 1
|
|
|
|
dtype = inputs_embeds.dtype
|
|
device = inputs_embeds.device
|
|
min_dtype = torch.finfo(dtype).min
|
|
|
|
# Add `1` to the past seen tokens to account for new `tokens` during `generate`
|
|
causal_mask = torch.full(
|
|
(1, past_seen_tokens + inputs_embeds.shape[1]),
|
|
fill_value=0,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
|
|
# Calculate the start of invisible text tokens
|
|
invisible_text_tokens_start = (
|
|
min(
|
|
math.ceil(
|
|
(past_seen_tokens - streaming_reserved_length)
|
|
/ streaming_audio_chunk_size
|
|
)
|
|
* streaming_text_chunk_size,
|
|
streaming_reserved_length,
|
|
)
|
|
+ 1
|
|
+ num_spk_emb * use_spk_emb
|
|
) # Add 1 for [Stts] and N for [spk_emb] tokens if `use_spk_emb` is True
|
|
|
|
invisible_text_tokens_end = (
|
|
streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1
|
|
) # Add 1 for [Ptts] (aka `audio_bos_token_id`)
|
|
|
|
# Set invisible text tokens to min_dtype (effectively -inf)
|
|
causal_mask[0, invisible_text_tokens_start:invisible_text_tokens_end] = min_dtype
|
|
|
|
# Mask padding positions in the text mask
|
|
causal_mask[
|
|
0, 0 : 1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1
|
|
].masked_fill_(streaming_tts_text_mask == 0, min_dtype)
|
|
|
|
# Add extra dimensions for batch and heads
|
|
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
|
|
|
|
return causal_mask
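

# Minimal sketch (illustrative values, not tied to a real checkpoint): build the
# streaming chunk mask for a single decoding step. The sizes below assume one speaker
# embedding and the default reserved text length of 300.
def _example_streaming_chunk_mask():
    streaming_reserved_length, num_spk_emb = 300, 1
    condition_length = 1 + num_spk_emb + streaming_reserved_length + 1
    past_seen_tokens = condition_length  # pretend the condition prefix is already cached
    inputs_embeds = torch.zeros(1, 1, 8)  # [batch=1, seq_len=1, hidden]
    streaming_tts_text_mask = torch.ones(condition_length)  # all text positions valid
    mask = make_streaming_chunk_mask_generation(
        inputs_embeds=inputs_embeds,
        past_seen_tokens=past_seen_tokens,
        streaming_tts_text_mask=streaming_tts_text_mask,
        streaming_reserved_length=streaming_reserved_length,
        streaming_audio_chunk_size=50,
        streaming_text_chunk_size=10,
        num_spk_emb=num_spk_emb,
        use_spk_emb=True,
    )
    return mask  # shape [1, 1, 1, past_seen_tokens + 1]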
|
|
|
|
|
|
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
|
|
class ConvNeXtBlock(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim: int,
|
|
intermediate_dim: int,
|
|
kernel: int,
|
|
dilation: int,
|
|
layer_scale_init_value: float = 1e-6,
|
|
):
|
|
# ConvNeXt Block copied from Vocos.
|
|
super().__init__()
|
|
self.dwconv = nn.Conv1d(
|
|
dim,
|
|
dim,
|
|
kernel_size=kernel,
|
|
padding=dilation * (kernel // 2),
|
|
dilation=dilation,
|
|
groups=dim,
|
|
)
|
|
|
|
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
|
self.pwconv1 = nn.Linear(dim, intermediate_dim)
|
|
self.act = nn.GELU()
|
|
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
|
self.coef = (
|
|
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
|
if layer_scale_init_value > 0
|
|
else None
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
|
|
residual = x
|
|
|
|
y = self.dwconv(x)
|
|
y.transpose_(1, 2) # (B, C, T) -> (B, T, C)
|
|
x = self.norm(y)
|
|
del y
|
|
y = self.pwconv1(x)
|
|
del x
|
|
x = self.act(y)
|
|
del y
|
|
y = self.pwconv2(x)
|
|
del x
|
|
if self.coef is not None:
|
|
y *= self.coef
|
|
y.transpose_(1, 2) # (B, T, C) -> (B, C, T)
|
|
|
|
x = y + residual
|
|
del y
|
|
|
|
return x
|
|
|
|
|
|
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
|
|
class DVAEDecoder(nn.Module):
|
|
def __init__(
|
|
self,
|
|
idim: int,
|
|
odim: int,
|
|
n_layer=12,
|
|
bn_dim=64,
|
|
hidden=256,
|
|
kernel=7,
|
|
dilation=2,
|
|
up=False,
|
|
):
|
|
super().__init__()
|
|
self.up = up
|
|
self.conv_in = nn.Sequential(
|
|
nn.Conv1d(idim, bn_dim, 3, 1, 1),
|
|
nn.GELU(),
|
|
nn.Conv1d(bn_dim, hidden, 3, 1, 1),
|
|
)
|
|
self.decoder_block = nn.ModuleList(
|
|
[
|
|
ConvNeXtBlock(
|
|
hidden,
|
|
hidden * 4,
|
|
kernel,
|
|
dilation,
|
|
)
|
|
for _ in range(n_layer)
|
|
]
|
|
)
|
|
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
|
|
|
|
def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
|
|
# B, C, T
|
|
y = self.conv_in(x)
|
|
del x
|
|
for f in self.decoder_block:
|
|
y = f(y, conditioning)
|
|
|
|
x = self.conv_out(y)
|
|
del y
|
|
return x
|
|
|
|
|
|
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
|
|
class GFSQ(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim: int,
|
|
levels: List[int],
|
|
G: int,
|
|
R: int,
|
|
eps=1e-5,
|
|
transpose=True,
|
|
):
|
|
super(GFSQ, self).__init__()
|
|
self.quantizer = GroupedResidualFSQ(
|
|
dim=dim,
|
|
levels=list(levels),
|
|
num_quantizers=R,
|
|
groups=G,
|
|
)
|
|
self.n_ind = math.prod(levels)
|
|
self.eps = eps
|
|
self.transpose = transpose
|
|
self.G = G
|
|
self.R = R
|
|
|
|
def _embed(self, x: torch.Tensor):
|
|
if self.transpose:
|
|
x = x.transpose(1, 2)
|
|
x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
|
|
feat = self.quantizer.get_output_from_indices(x)
|
|
return feat.transpose_(1, 2) if self.transpose else feat
|
|
|
|
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
|
return super().__call__(x)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
if self.transpose:
|
|
x.transpose_(1, 2)
|
|
_, ind = self.quantizer(x)
|
|
ind = ind.permute(1, 2, 0, 3).contiguous()
|
|
ind = ind.view(ind.size(0), ind.size(1), -1)
|
|
return ind.transpose_(1, 2) if self.transpose else ind
|
|
|
|
|
|
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
|
|
class DVAE(nn.Module):
|
|
def __init__(
|
|
self,
|
|
):
|
|
super().__init__()
|
|
|
|
coef = torch.rand(100)
|
|
self.coef = nn.Parameter(coef.unsqueeze(0).unsqueeze_(2))
|
|
|
|
self.downsample_conv = nn.Sequential(
|
|
nn.Conv1d(100, 512, 3, 1, 1),
|
|
nn.GELU(),
|
|
nn.Conv1d(512, 512, 4, 2, 1),
|
|
nn.GELU(),
|
|
)
|
|
|
|
self.encoder = DVAEDecoder(
|
|
idim=512,
|
|
odim=1024,
|
|
hidden=256,
|
|
n_layer=12,
|
|
bn_dim=128,
|
|
)
|
|
|
|
self.decoder = DVAEDecoder(
|
|
idim=512,
|
|
odim=512,
|
|
hidden=256,
|
|
n_layer=12,
|
|
bn_dim=128,
|
|
)
|
|
|
|
self.out_conv = nn.Conv1d(512, 100, 3, 1, 1, bias=False)
|
|
|
|
self.vq_layer = GFSQ(
|
|
dim=1024,
|
|
levels=(5, 5, 5, 5),
|
|
G=2,
|
|
R=2,
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def forward(
|
|
self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
|
|
) -> torch.Tensor:
|
|
if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
|
|
mel = inp.clone()
|
|
x: torch.Tensor = self.downsample_conv(
|
|
torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
|
|
).unsqueeze_(0)
|
|
del mel
|
|
x = self.encoder(x)
|
|
ind = self.vq_layer(x)
|
|
del x
|
|
return ind
|
|
|
|
if self.vq_layer is not None:
|
|
vq_feats = self.vq_layer._embed(inp)
|
|
else:
|
|
vq_feats = inp
|
|
|
|
vq_feats = (
|
|
vq_feats.view(
|
|
(vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
|
|
)
|
|
.permute(0, 2, 3, 1)
|
|
.flatten(2)
|
|
)
|
|
|
|
dec_out = self.out_conv(
|
|
self.decoder(
|
|
x=vq_feats,
|
|
),
|
|
)
|
|
|
|
del vq_feats
|
|
|
|
return torch.mul(dec_out, self.coef, out=dec_out)
|
|
|
|
|
|
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py`
|
|
class CustomRepetitionPenaltyLogitsProcessorRepeat:
|
|
def __init__(self, penalty: float, max_input_ids: int, past_window: int):
|
|
if not isinstance(penalty, float) or not (penalty > 0):
|
|
raise ValueError(
|
|
f"`penalty` has to be a strictly positive float, but is {penalty}"
|
|
)
|
|
|
|
self.penalty = penalty
|
|
self.max_input_ids = max_input_ids
|
|
self.past_window = past_window
|
|
|
|
def __call__(
|
|
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
|
) -> torch.FloatTensor:
|
|
if input_ids.size(1) > self.past_window:
|
|
input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
|
|
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
|
|
if freq.size(0) > self.max_input_ids:
|
|
freq.narrow(
|
|
0, self.max_input_ids, freq.size(0) - self.max_input_ids
|
|
).zero_()
|
|
alpha = torch.pow(self.penalty, freq)
|
|
scores = scores.contiguous()
|
|
inp = scores.multiply(alpha)
|
|
oth = scores.divide(alpha)
|
|
con = scores < 0
|
|
out = torch.where(con, inp, oth)
|
|
del inp, oth, scores, con, alpha
|
|
return out
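

# Illustrative sketch (toy values): apply the repetition penalty to a small logits row.
# Token 2 appears twice in the recent window, so its positive score is divided by
# penalty**2 (a negative score would be multiplied instead).
def _example_repetition_penalty():
    processor = CustomRepetitionPenaltyLogitsProcessorRepeat(
        penalty=2.0, max_input_ids=8, past_window=4
    )
    recent_ids = torch.tensor([[1, 2, 2, 5]])  # [batch=1, window]
    scores = torch.zeros(1, 10)
    scores[0, 2] = 3.0
    return processor(recent_ids, scores)  # score of token 2 becomes 3.0 / 2**2 = 0.75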
|
|
|
|
|
|
class ConditionalChatTTS(PreTrainedModel):
|
|
"""A conditional text-to-speech model that can generate speech from text with speaker conditioning.
|
|
|
|
This model extends PreTrainedModel to provide text-to-speech capabilities with:
|
|
- LLM hidden state conditioning
|
|
- Streaming generation
|
|
|
|
The model uses a transformer architecture with LLM hidden states and can operate in both
|
|
streaming and non-streaming modes for flexible deployment.
|
|
|
|
    The model processes sequences in the following format:
|
|
| text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed)| audio eos token |
|
|
|
|
The format is designed to support LLM-conditioned streaming audio generation.
|
|
|
|
Usage:
|
|
To support streaming generation, two global variables should be maintained outside of the model.
|
|
1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq].
|
|
2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples, each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads]
|
|
|
|
where `num_vq` is the number of audio codebooks, in default setting, it is `4`.
|
|
|
|
1. Create an empty `past_key_values` with
|
|
```python
|
|
initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token
|
|
dtype = model.emb_text.weight.dtype
|
|
device = model.emb_text.weight.device
|
|
past_key_values = [
|
|
(
|
|
torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device),
|
|
torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device)
|
|
)
|
|
for _ in range(model.config.num_hidden_layers)
|
|
    ]
    ```

    2. At the same time, create an empty `audio_input_ids` with shape [1, sequence length, num_vq], where `num_vq` denotes the number of audio codebook layers. Text-token positions are also part of this sequence, but they stay zero and serve only as placeholders.
|
|
|
|
```python
|
|
initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1
|
|
# [bos token, speaker embeddings, text tokens, audio bos token]
|
|
    audio_input_ids = torch.zeros(1, initial_audio_input_ids_length, model.num_vq)  # batch_size = 1
|
|
```
|
|
|
|
    3. Prefill some text tokens to the TTS model (for example, 10 tokens) using the `prefill_text` method.
|
|
|
|
```python
|
|
outputs = llm.generate(**kwargs)
|
|
llm_tokens = some_function_to_extract_llm_tokens(outputs)
|
|
lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs)
|
|
tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens))
|
|
    # here we assume we are prefilling text tokens 0 through 9 (inclusive), 10 tokens in total.
|
|
begin = 0
|
|
end = 9+1
|
|
position_ids = torch.arange(begin, end, dtype=torch.long, device=device)
|
|
|
|
past_key_values = model.prefill_text(
|
|
input_ids=tts_text_input_ids,
|
|
position_ids=position_ids,
|
|
past_key_values=past_key_values,
|
|
lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
|
|
)
|
|
```
|
|
|
|
    4. Make a `streaming_tts_text_mask` to denote which positions contain valid text tokens, similar to `attention_mask` in standard causal attention.
|
|
|
|
```python
|
|
    streaming_tts_text_mask = torch.zeros(model.streaming_text_reserved_len)
    streaming_tts_text_mask[0:end] = 1  # these positions contain valid text tokens
|
|
```
|
|
|
|
    5. Generate audio codes using the `generate` method.
|
|
|
|
```python
|
|
outputs = model.generate(
|
|
input_ids=audio_input_ids,
|
|
past_key_values=past_key_values,
|
|
streaming_tts_text_mask=streaming_tts_text_mask,
|
|
max_new_token=50,
|
|
)
|
|
|
|
# update past_key_values and input_ids
|
|
past_key_values = outputs.past_key_values
|
|
    audio_input_ids = outputs.audio_input_ids
|
|
```
|
|
|
|
    After calling `generate`, both `past_key_values` and `audio_input_ids` are extended by up to `max_new_token=50` positions.
|
|
|
|
    6. Notice that after prefilling `10` text tokens, the model can generate up to `50` audio tokens. If you want to generate more audio tokens, you need to prefill the next `10` text tokens. It is also okay to generate only `25` audio tokens for a faster initial response.
|
|
|
|
    7. Repeat steps `3,4,5` as needed in your streaming audio generation use case, making sure usage complies with the guidelines discussed above.
|
|
"""
|
|
|
|
config_class = PretrainedConfig
|
|
_no_split_modules = []
|
|
|
|
def __init__(self, config: PretrainedConfig):
|
|
super().__init__(config)
|
|
|
|
self.use_speaker_embedding = config.use_speaker_embedding
|
|
self.use_llm_hidden_state = config.use_llm_hidden_state
|
|
self.num_spk_embs = config.num_spk_embs
|
|
self.spk_emb_token_id = config.spk_emb_token_id
|
|
|
|
self.use_text = config.use_text
|
|
self.streaming = config.streaming
|
|
self.streaming_text_chunk_size = config.streaming_text_chunk_size
|
|
self.streaming_audio_chunk_size = config.streaming_audio_chunk_size
|
|
self.streaming_text_reserved_len = config.streaming_text_reserved_len
|
|
self.audio_bos_token_id = config.audio_bos_token_id
|
|
self.num_mel_bins = config.num_mel_bins
|
|
self.num_vq = config.num_vq
|
|
self.num_audio_tokens = config.num_audio_tokens
|
|
|
|
self.top_p = config.top_p
|
|
self.top_k = config.top_k
|
|
self.repetition_penalty = config.repetition_penalty
|
|
|
|
if self.config.use_mlp:
|
|
self.projector = MultiModalProjector(config.llm_dim, config.hidden_size)
|
|
else:
|
|
self.projector = nn.Linear(config.llm_dim, config.hidden_size, bias=False)
|
|
self.emb_code = nn.ModuleList(
|
|
[
|
|
nn.Embedding(config.num_audio_tokens, config.hidden_size)
|
|
for _ in range(config.num_vq)
|
|
]
|
|
)
|
|
self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
|
|
self.head_code = nn.ModuleList(
|
|
[
|
|
weight_norm(
|
|
nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
|
|
name="weight",
|
|
)
|
|
for _ in range(config.num_vq)
|
|
]
|
|
)
|
|
|
|
dvae = DVAE()
|
|
self.dvae = dvae
|
|
|
|
model_config = LlamaConfig(
|
|
hidden_size=config.hidden_size,
|
|
intermediate_size=config.intermediate_size,
|
|
num_attention_heads=config.num_attention_heads,
|
|
num_hidden_layers=config.num_hidden_layers,
|
|
max_position_embeddings=config.max_position_embeddings,
|
|
attn_implementation=config.attn_implementation,
|
|
)
|
|
|
|
model = LlamaModel(model_config)
|
|
self.model = model
|
|
|
|
@torch.inference_mode()
|
|
def merge_inputs_embeds(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
|
|
):
|
|
"""Merge `input_ids` and `lm_spk_emb_last_hidden_states` to `inputs_embeds`.
|
|
|
|
Args:
|
|
input_ids (torch.Tensor): Input token IDs.
|
|
lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states of speaker embeddings from the language model. Defaults to None.
|
|
|
|
Raises:
|
|
NotImplementedError: If speaker embedding is not used and language model hidden states are not implemented.
|
|
|
|
Returns:
|
|
torch.Tensor: Prepared input embeddings for the model.
|
|
"""
|
|
assert input_ids.shape[0] == 1
|
|
|
|
# Embed input_ids to input_embeds
|
|
inputs_embeds = self.emb_text(input_ids)
|
|
|
|
# Inject speaker embedding to input_embeds if it exists
|
|
if self.use_speaker_embedding:
|
|
spk_emb_mask = input_ids == self.spk_emb_token_id
|
|
if spk_emb_mask.any():
|
|
assert lm_spk_emb_last_hidden_states is not None
|
|
# Project spk emb to tts hidden size first, [batch_size, num_spk_emb, llm_dim] -> [batch_size, num_spk_emb, self.hidden_size]
|
|
lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to(
|
|
self.projector.linear1.weight.dtype
|
|
)
|
|
projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states)
|
|
projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1)
|
|
apply_spk_emb(
|
|
input_ids=input_ids,
|
|
spk_emb=projected_spk_emb,
|
|
input_embeds=inputs_embeds,
|
|
spk_emb_token_id=self.spk_emb_token_id,
|
|
num_spk_embs=self.num_spk_embs,
|
|
)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
return inputs_embeds
|
|
|
|
@torch.inference_mode()
|
|
def prefill_text(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
position_ids: torch.LongTensor,
|
|
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
|
|
):
|
|
"""Prefill a chunk of new text tokens in streaming setting.
|
|
        Specifically, update `past_key_values` with the new text tokens so that the model can attend to them in later steps.
|
|
|
|
Args:
|
|
input_ids (Tensor): Tensor of shape [batch_size, seq_len]
|
|
position_ids (LongTensor): Tensor of shape [batch_size, seq_len]
|
|
past_key_values (List[Tuple[Tensor]]): KV Cache of all layers, each layer is a tuple (Tensor, Tensor) denoting keys and values. Each tensor is of seq_len = `self.streaming_text_reserved_len`. `past_key_values` will be updated.
|
|
lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape [batch_size, num_spk_emb, llm_dim]. Defaults to None.
|
|
|
|
Note that all `batch_size` should be `1`.
|
|
"""
|
|
assert input_ids.shape[0] == 1
|
|
assert past_key_values is not None
|
|
|
|
# Merge text and LLM embeddings
|
|
inputs_embeds = self.merge_inputs_embeds(
|
|
input_ids=input_ids,
|
|
lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
|
|
)
|
|
|
|
# Clone KV Cache
|
|
past_key_values_for_prefill = []
|
|
for i in range(len(past_key_values)):
|
|
past_key_values_for_prefill.append(
|
|
(
|
|
past_key_values[i][0][:, :, : position_ids[:, 0], :].clone(),
|
|
past_key_values[i][1][:, :, : position_ids[:, 0], :].clone(),
|
|
)
|
|
)
|
|
|
|
        # Model forward: prefill the new text chunk
|
|
outputs_prefill: BaseModelOutputWithPast = self.model(
|
|
attention_mask=None, # because for text, it is standard causal attention mask, do nothing
|
|
position_ids=position_ids, # position_ids denotes the position of new text tokens in the sequence
|
|
past_key_values=past_key_values_for_prefill, # `past_key_values` will be updated by the model
|
|
inputs_embeds=inputs_embeds, # contains text and language model embedding
|
|
use_cache=True,
|
|
output_attentions=False,
|
|
cache_position=position_ids, # which new positions will use this cache, basically the same as position_ids
|
|
)
|
|
|
|
# Get model updated KV Cache
|
|
past_key_values_for_prefill_updated = outputs_prefill.past_key_values
|
|
|
|
# Update generated KV Cache to input `past_key_values`
|
|
for layer_idx in range(len(past_key_values)):
|
|
# Update keys
|
|
past_key_values[layer_idx][0][
|
|
:, :, position_ids[:, 0] : position_ids[:, -1] + 1, :
|
|
] = past_key_values_for_prefill_updated[layer_idx][0][
|
|
:, :, position_ids[:, 0] : position_ids[:, -1] + 1
|
|
].clone()
|
|
# Update values
|
|
past_key_values[layer_idx][1][
|
|
:, :, position_ids[:, 0] : position_ids[:, -1] + 1, :
|
|
] = past_key_values_for_prefill_updated[layer_idx][1][
|
|
:, :, position_ids[:, 0] : position_ids[:, -1] + 1
|
|
].clone()
|
|
|
|
# TODO: del past_key_values_for_prefill_updated recursively
|
|
# TODO: del outputs_prefill recursively
|
|
|
|
return past_key_values
|
|
|
|
@torch.inference_mode()
|
|
def prefill_audio_ids(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
streaming_tts_text_mask=None,
|
|
add_audio_bos: bool = True,
|
|
):
|
|
"""Prefill a chunk of audio ids to the model. Used in sliding-window long audio generation.
|
|
        Specifically, prefill audio ids (typically those from the last window) to the model in the new window.
|
|
|
|
Args:
|
|
input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids.
|
|
past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
|
|
"""
|
|
assert input_ids.shape[0] == 1
|
|
assert past_key_values is not None
|
|
|
|
code_emb = [self.emb_code[i](input_ids[:, :, i]) for i in range(self.num_vq)]
|
|
inputs_embeds = torch.stack(code_emb, 3).sum(3) # [1,seq_len,768]
|
|
input_len = input_ids.shape[1]
|
|
|
|
if add_audio_bos:
|
|
narrowed_input_ids = torch.tensor(
|
|
[[self.audio_bos_token_id]], dtype=torch.long, device=self.device
|
|
)
|
|
bos_inputs_embeds = self.emb_text(narrowed_input_ids)
|
|
inputs_embeds = torch.cat([bos_inputs_embeds, inputs_embeds], dim=1)
|
|
input_len += 1
|
|
|
|
past_key_values_length = past_key_values[0][0].shape[2]
|
|
position_ids = torch.arange(
|
|
past_key_values_length,
|
|
past_key_values_length + input_len,
|
|
dtype=torch.long,
|
|
device=self.device,
|
|
).unsqueeze(0)
|
|
|
|
cache_position = position_ids.clone()
|
|
causal_mask = make_streaming_chunk_mask_generation(
|
|
inputs_embeds=inputs_embeds,
|
|
past_seen_tokens=past_key_values[0][0].shape[2],
|
|
streaming_tts_text_mask=streaming_tts_text_mask,
|
|
streaming_reserved_length=self.streaming_text_reserved_len,
|
|
streaming_text_chunk_size=self.streaming_text_chunk_size,
|
|
) # [1, 1, 1, past_key_values_length + input_len]
|
|
|
|
# Model forward
|
|
outputs: BaseModelOutputWithPast = self.model(
|
|
attention_mask=causal_mask,
|
|
position_ids=position_ids,
|
|
past_key_values=past_key_values,
|
|
inputs_embeds=inputs_embeds,
|
|
use_cache=True,
|
|
output_attentions=False,
|
|
cache_position=cache_position,
|
|
)
|
|
past_key_values = outputs.past_key_values
|
|
return past_key_values
|
|
|
|
@torch.inference_mode()
|
|
def generate(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
temperature: torch.Tensor,
|
|
eos_token: Union[int, torch.Tensor],
|
|
streaming_tts_text_mask=None,
|
|
force_no_stop=False,
|
|
min_new_token=10,
|
|
max_new_token=50,
|
|
logits_warpers: List[LogitsWarper] = [],
|
|
logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
|
|
show_tqdm=False,
|
|
):
|
|
"""Generate audio codes in streaming setting or non-streaming setting.
|
|
Specifically speaking, generate audio codes when not all text tokens are prefilled.
|
|
|
|
        Always pass a valid `past_key_values` to this method. The method does not prefill by itself; it relies on the `prefill_text` method to provide a valid `past_key_values`. Please refer to the docstring of this class for more details.
|
|
|
|
In this method, we borrowed a lot of codes from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.
|
|
|
|
Args:
|
|
input_ids (torch.Tensor): Input token ids.
|
|
past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
|
|
temperature (torch.Tensor): Temperature for sampling.
|
|
eos_token (Union[int, torch.Tensor]): End of sequence token.
|
|
            streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
            force_no_stop (bool, optional): Whether to suppress the eos token so generation never stops early. Defaults to False.
            min_new_token (int, optional): Minimum number of new tokens to generate before eos is allowed. Defaults to 10.
            max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
            logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
            logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
            show_tqdm (bool, optional): Whether to show a progress bar. Defaults to False.
|
|
|
|
Returns:
|
|
GenerationOutputs: Generation outputs.
|
|
"""
|
|
|
|
# We only support batch size `1` for now
|
|
assert input_ids.shape[0] == 1
|
|
assert past_key_values is not None
|
|
|
|
# fix: this should not be `input_ids.shape[1]`
|
|
# start_idx = input_ids.shape[1]
|
|
start_idx = (
|
|
1
|
|
+ self.num_spk_embs * self.use_speaker_embedding
|
|
+ self.streaming_text_reserved_len
|
|
+ 1
|
|
)
|
|
|
|
finish = torch.zeros(input_ids.shape[0], device=input_ids.device).bool()
|
|
|
|
temperature = (
|
|
temperature.unsqueeze(0)
|
|
.expand(input_ids.shape[0], -1)
|
|
.contiguous()
|
|
.view(-1, 1)
|
|
)
|
|
|
|
progress = input_ids.shape[1]
|
|
|
|
# Pre-allocate input_ids, shape is [batch_size=1, max_possible_seq_len, self.num_vqs]
|
|
input_ids_buf = torch.zeros(
|
|
input_ids.shape[0], # batch_size
|
|
progress
|
|
+ max_new_token, # max_possible_seq_len = input_ids.shape[1] + max_new_token
|
|
input_ids.shape[2], # self.num_vqs
|
|
dtype=input_ids.dtype,
|
|
device=input_ids.device,
|
|
)
|
|
|
|
# Copy existing `input_ids` to `input_ids_buf`
|
|
input_ids_buf.narrow(1, 0, progress).copy_(input_ids)
|
|
|
|
del input_ids
|
|
input_ids = input_ids_buf.narrow(1, 0, progress)
|
|
|
|
pbar: Optional[tqdm] = None
|
|
if show_tqdm:
|
|
pbar = tqdm(
|
|
total=max_new_token,
|
|
desc="code",
|
|
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
|
|
)
|
|
|
|
condition_length = (
|
|
1
|
|
+ self.num_spk_embs * self.use_speaker_embedding
|
|
+ self.streaming_text_reserved_len
|
|
+ 1
|
|
)
|
|
|
|
for i in range(max_new_token):
|
|
# Prepare generation inputs
|
|
audio_bos = False
|
|
|
|
# If this is the first audio token, the case is SPECIAL
|
|
if progress == condition_length:
|
|
audio_bos = True
|
|
|
|
assert progress == (
|
|
past_key_values[0][0].shape[2] + 1
|
|
                )  # If the model is used according to the guidelines, this assertion should pass.
|
|
|
|
if audio_bos:
|
|
# Generate the first token, activate the model with `self.audio_bos_token_id`, the model will predict
|
|
# a new audio token. This is a special case because without the `audio bos token`, it is impossible
|
|
# to generate the first audio token in our streaming setting.
|
|
narrowed_input_ids = torch.tensor(
|
|
[[self.audio_bos_token_id]], dtype=torch.long, device=self.device
|
|
)
|
|
inputs_embeds = self.emb_text(narrowed_input_ids)
|
|
del narrowed_input_ids
|
|
else:
|
|
# Generate the following audio tokens, it is applicable to all other cases, including second and the
|
|
# following calling of `generate`.
|
|
narrowed_input_ids = input_ids.narrow(
|
|
dim=1, start=input_ids.shape[1] - 1, length=1
|
|
)
|
|
code_emb = [
|
|
self.emb_code[i](narrowed_input_ids[:, :, i])
|
|
for i in range(self.num_vq)
|
|
]
|
|
inputs_embeds = torch.stack(code_emb, 3).sum(3)
|
|
|
|
position_ids = torch.tensor(
|
|
[past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
|
|
).unsqueeze(0)
|
|
|
|
cache_position = position_ids.clone()
|
|
|
|
# Make causal mask
|
|
causal_mask = make_streaming_chunk_mask_generation(
|
|
inputs_embeds=inputs_embeds,
|
|
past_seen_tokens=past_key_values[0][0].shape[2],
|
|
streaming_tts_text_mask=streaming_tts_text_mask,
|
|
streaming_reserved_length=self.streaming_text_reserved_len,
|
|
streaming_text_chunk_size=self.streaming_text_chunk_size,
|
|
)
|
|
|
|
# Model forward
|
|
outputs: BaseModelOutputWithPast = self.model(
|
|
attention_mask=causal_mask,
|
|
position_ids=position_ids,
|
|
past_key_values=past_key_values,
|
|
inputs_embeds=inputs_embeds,
|
|
use_cache=True,
|
|
output_attentions=False,
|
|
cache_position=cache_position,
|
|
)
|
|
|
|
del position_ids
|
|
del inputs_embeds
|
|
del cache_position
|
|
del causal_mask
|
|
|
|
hidden_states = outputs.last_hidden_state
|
|
past_key_values = outputs.past_key_values
|
|
|
|
with P.cached():
|
|
logits = torch.empty(
|
|
hidden_states.size(0),
|
|
hidden_states.size(1),
|
|
self.num_audio_tokens,
|
|
self.num_vq,
|
|
dtype=torch.float,
|
|
device=self.device,
|
|
)
|
|
for num_vq_iter in range(self.num_vq):
|
|
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
|
|
logits[..., num_vq_iter] = x
|
|
del x
|
|
|
|
del hidden_states
|
|
|
|
# logits = logits[:, -1].float()
|
|
logits = logits.narrow(1, -1, 1).squeeze_(1).float()
|
|
|
|
# logits = rearrange(logits, "b c n -> (b n) c")
|
|
logits = logits.permute(0, 2, 1)
|
|
logits = logits.reshape(-1, logits.size(2))
|
|
# logits_token = rearrange(input_ids[:, start_idx:], "b c n -> (b n) c")
|
|
input_ids_sliced = input_ids.narrow(
|
|
1,
|
|
start_idx,
|
|
input_ids.size(1) - start_idx,
|
|
).permute(0, 2, 1)
|
|
logits_token = input_ids_sliced.reshape(
|
|
input_ids_sliced.size(0) * input_ids_sliced.size(1),
|
|
-1,
|
|
).to(self.device)
|
|
del input_ids_sliced
|
|
|
|
logits /= temperature
|
|
|
|
if not audio_bos:
|
|
for logitsProcessors in logits_processors:
|
|
logits = logitsProcessors(logits_token, logits)
|
|
if not audio_bos:
|
|
for logitsWarpers in logits_warpers:
|
|
logits = logitsWarpers(logits_token, logits)
|
|
|
|
del logits_token
|
|
|
|
if i < min_new_token:
|
|
logits[:, eos_token] = -torch.inf
|
|
|
|
if force_no_stop:
|
|
logits[:, eos_token] = -torch.inf
|
|
|
|
scores = F.softmax(logits, dim=-1)
|
|
|
|
del logits
|
|
idx_next = torch.multinomial(scores, num_samples=1) # .to(finish.device)
|
|
|
|
del scores
|
|
|
|
# idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
|
|
idx_next = idx_next.view(-1, self.num_vq)
|
|
finish_or = idx_next.eq(eos_token).any(1)
|
|
finish.logical_or_(finish_or)
|
|
|
|
del finish_or
|
|
# Store new `token` into `input_ids_buf`
|
|
input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
|
|
|
|
if i == 0 and finish.any():
|
|
# raise Exception
|
|
break
|
|
|
|
del idx_next
|
|
progress += 1
|
|
input_ids = input_ids_buf.narrow(1, 0, progress)
|
|
|
|
if finish.all():
|
|
break
|
|
|
|
if pbar is not None:
|
|
pbar.update(1)
|
|
|
|
if pbar is not None:
|
|
pbar.close()
|
|
|
|
if not finish.all():
|
|
if show_tqdm:
|
|
logger.info(f"incomplete result. hit max_new_token: {max_new_token}")
|
|
|
|
del input_ids_buf
|
|
|
|
        if finish.all():
            # the last position may contain an eos token
            generated_input_ids = input_ids[:, condition_length:-1, :]
        else:
            # there is no eos token
            generated_input_ids = input_ids[:, condition_length:, :]

        return ConditionalChatTTSGenerationOutput(
            new_ids=generated_input_ids,
|
|
audio_input_ids=input_ids, # for update purpose
|
|
past_key_values=past_key_values, # for update purpose
|
|
finished=finish.all(),
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def decode_to_mel_specs(
|
|
self,
|
|
result_list: List[torch.Tensor],
|
|
):
|
|
"""Decode discrete audio codes to mel spectrograms.
|
|
|
|
Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py`
|
|
|
|
Args:
|
|
result_list (List[torch.Tensor]): Audio codes output from `generate`.
|
|
|
|
Returns:
|
|
torch.Tensor: Mel spectrograms.
|
|
"""
|
|
|
|
decoder = self.dvae
|
|
max_x_len = -1
|
|
if len(result_list) == 0:
|
|
return np.array([], dtype=np.float32)
|
|
for result in result_list:
|
|
if result.size(0) > max_x_len:
|
|
max_x_len = result.size(0)
|
|
batch_result = torch.zeros(
|
|
(len(result_list), result_list[0].size(1), max_x_len),
|
|
dtype=result_list[0].dtype,
|
|
device=result_list[0].device,
|
|
)
|
|
for i in range(len(result_list)):
|
|
src = result_list[i]
|
|
batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))
|
|
del src
|
|
|
|
mel_specs = decoder(batch_result)
|
|
del batch_result
|
|
return mel_specs
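

# Minimal end-to-end sketch (not part of the upstream code) tying together the steps
# from the ConditionalChatTTS docstring for one streaming round: allocate the caches,
# prefill a text chunk, then sample a chunk of audio codes. It assumes the released
# configuration with use_speaker_embedding=True; `tts_text_input_ids` ([1, chunk_len]),
# `eos_token_id` and `lm_spk_emb_last_hidden_states` are caller-provided assumptions.
def _example_streaming_tts_step(
    model, tts_text_input_ids, eos_token_id, lm_spk_emb_last_hidden_states=None
):
    device = model.emb_text.weight.device
    dtype = model.emb_text.weight.dtype
    num_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // num_heads
    # Length of the fixed condition prefix: [Stts] + speaker embeddings + reserved text.
    kv_len = 1 + model.num_spk_embs + model.streaming_text_reserved_len

    # Step 1: pre-allocated KV cache for the condition prefix.
    past_key_values = [
        (
            torch.zeros(1, num_heads, kv_len, head_dim, dtype=dtype, device=device),
            torch.zeros(1, num_heads, kv_len, head_dim, dtype=dtype, device=device),
        )
        for _ in range(model.config.num_hidden_layers)
    ]
    # Step 2: audio_input_ids placeholder, one extra slot for the audio bos position.
    audio_input_ids = torch.zeros(
        1, kv_len + 1, model.num_vq, dtype=torch.long, device=device
    )

    # Step 3: prefill the first chunk of text tokens.
    chunk_len = tts_text_input_ids.shape[1]
    position_ids = torch.arange(0, chunk_len, dtype=torch.long, device=device).unsqueeze(0)
    past_key_values = model.prefill_text(
        input_ids=tts_text_input_ids,
        position_ids=position_ids,
        past_key_values=past_key_values,
        lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
    )

    # Step 4: mark which condition positions already hold valid text tokens.
    streaming_tts_text_mask = torch.zeros(kv_len + 1, device=device)
    streaming_tts_text_mask[0:chunk_len] = 1

    # Step 5: sample one chunk of audio codes conditioned on the prefilled text.
    outputs = model.generate(
        input_ids=audio_input_ids,
        past_key_values=past_key_values,
        temperature=torch.tensor([0.7] * model.num_vq, device=device),
        eos_token=eos_token_id,
        streaming_tts_text_mask=streaming_tts_text_mask,
        max_new_token=50,
    )
    return outputs.new_ids, outputs.audio_input_ids, outputs.past_key_values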
|
|
|
|
|
|
# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer; adds use_cache for streaming inference
|
|
class MiniCPMWhisperEncoderLayer(nn.Module):
|
|
def __init__(self, config: WhisperConfig, layer_idx: int = None):
|
|
super().__init__()
|
|
self.embed_dim = config.d_model
|
|
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
|
|
embed_dim=self.embed_dim,
|
|
num_heads=config.encoder_attention_heads,
|
|
dropout=config.attention_dropout,
|
|
config=config,
|
|
layer_idx=layer_idx,
|
|
)
|
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
|
self.dropout = config.dropout
|
|
self.activation_fn = ACT2FN[config.activation_function]
|
|
self.activation_dropout = config.activation_dropout
|
|
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
attention_mask: torch.Tensor,
|
|
layer_head_mask: torch.Tensor,
|
|
output_attentions: bool = False,
|
|
past_key_values: Optional[EncoderDecoderCache] = None,
|
|
use_cache: Optional[bool] = False,
|
|
) -> torch.Tensor:
|
|
r"""
|
|
Args:
|
|
hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`):
|
|
Hidden states to be fed into the encoder layer.
|
|
attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`):
|
|
Attention mask where padding elements are indicated by large negative values.
|
|
layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`):
|
|
Mask to nullify selected heads of the attention modules.
|
|
output_attentions (`bool`, *optional*):
|
|
Whether or not to return the attention weights.
|
|
past_key_values (`EncoderDecoderCache`, *optional*):
|
|
Past key-value pairs used for incremental decoding.
|
|
use_cache (`bool`, *optional*):
|
|
Whether or not to return updated `past_key_values` for caching.
|
|
|
|
Returns:
|
|
A tuple of shape `(hidden_states, optional(attn_weights), optional(past_key_values))`.
|
|
"""
|
|
residual = hidden_states
|
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
hidden_states, attn_weights, past_key_values = self.self_attn(
|
|
hidden_states=hidden_states,
|
|
attention_mask=attention_mask,
|
|
layer_head_mask=layer_head_mask,
|
|
output_attentions=output_attentions,
|
|
past_key_value=past_key_values,
|
|
)
|
|
hidden_states = nn.functional.dropout(
|
|
hidden_states, p=self.dropout, training=False
|
|
)
|
|
hidden_states = residual + hidden_states
|
|
|
|
residual = hidden_states
|
|
hidden_states = self.final_layer_norm(hidden_states)
|
|
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
|
hidden_states = nn.functional.dropout(
|
|
hidden_states, p=self.activation_dropout, training=False
|
|
)
|
|
hidden_states = self.fc2(hidden_states)
|
|
hidden_states = nn.functional.dropout(
|
|
hidden_states, p=self.dropout, training=False
|
|
)
|
|
hidden_states = residual + hidden_states
|
|
|
|
if hidden_states.dtype == torch.float16 and (
|
|
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
|
):
|
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
|
hidden_states = torch.clamp(
|
|
hidden_states, min=-clamp_value, max=clamp_value
|
|
)
|
|
|
|
outputs = (hidden_states,)
|
|
|
|
if output_attentions:
|
|
outputs += (attn_weights,)
|
|
|
|
if use_cache:
|
|
outputs += (past_key_values,)
|
|
|
|
return outputs
|
|
|
|
|
|
# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder; adds use_cache for streaming inference
|
|
class MiniCPMWhisperEncoder(WhisperEncoder):
|
|
|
|
def __init__(self, config: WhisperConfig):
|
|
super().__init__(config)
|
|
self.layers = nn.ModuleList(
|
|
[
|
|
MiniCPMWhisperEncoderLayer(config, layer_idx=i)
|
|
for i in range(config.encoder_layers)
|
|
]
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
input_features,
|
|
attention_mask=None,
|
|
head_mask=None,
|
|
output_attentions=None,
|
|
output_hidden_states=None,
|
|
return_dict=None,
|
|
past_key_values: Optional[EncoderDecoderCache] = None,
|
|
use_cache: Optional[bool] = None,
|
|
):
|
|
r"""
|
|
Forward pass of the Whisper encoder.
|
|
|
|
Args:
|
|
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
|
|
Float values of log-mel features extracted from the raw audio waveform. Typically generated
|
|
by a feature extractor (e.g., `WhisperFeatureExtractor`) that processes `.flac` or `.wav`
|
|
files into padded 2D mel spectrogram frames. These features are projected via convolution layers
|
|
(`conv1` and `conv2`) and then transformed into embeddings for the encoder.
|
|
|
|
attention_mask (`torch.Tensor`, *optional*):
|
|
Not used by Whisper for masking `input_features`, but included for API compatibility with
|
|
other models. If provided, it is simply ignored within the model. By default, Whisper
|
|
effectively ignores silence in the input log-mel spectrogram.
|
|
|
|
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
|
Mask to nullify selected attention heads. The elements should be either 1 or 0, where:
|
|
- 1 indicates the head is **not masked**,
|
|
- 0 indicates the head is **masked** (i.e., the attention head is dropped).
|
|
|
|
output_attentions (`bool`, *optional*):
|
|
Whether or not to return the attention tensors of all encoder layers. If set to `True`, the
|
|
returned tuple (or `BaseModelOutputWithPast`) will contain an additional element with
|
|
attention weights for each encoder layer.
|
|
|
|
output_hidden_states (`bool`, *optional*):
|
|
Whether or not to return the hidden states of all layers. If set to `True`, the returned
|
|
tuple (or `BaseModelOutputWithPast`) will contain a tuple of hidden states, including the
|
|
initial embedding output as well as the outputs of each layer.
|
|
|
|
return_dict (`bool`, *optional*):
|
|
Whether or not to return a `BaseModelOutputWithPast` (a subclass of `ModelOutput`) instead
|
|
of a plain tuple. If set to `True`, the output will be a `BaseModelOutputWithPast` object,
|
|
otherwise it will be a tuple.
|
|
|
|
past_key_values (`EncoderDecoderCache`, *optional*):
|
|
When using caching for faster inference, this is an object that stores the key-value pairs
|
|
for attention states. If provided, the model will append new states to the existing cache
|
|
and return the updated cache. This speeds up sequential decoding or chunked inference.
|
|
|
|
- If `past_key_values` is `None`, no past states are used or returned.
|
|
- If `past_key_values` is not `None` and `use_cache=True`, the model will use the provided
|
|
cache and return the updated cache (as `next_encoder_cache`).
|
|
|
|
use_cache (`bool`, *optional*):
|
|
Whether or not the model should use caching (`past_key_values`) to speed up processing
|
|
during inference. When set to `True`, the model will:
|
|
- Inspect and use `past_key_values` if provided.
|
|
- Return updated `past_key_values` (under the name `next_encoder_cache` in
|
|
`BaseModelOutputWithPast`).
|
|
|
|
Returns:
|
|
`BaseModelOutputWithPast` or `tuple` (depending on `return_dict`):
|
|
If `return_dict=True`, a `BaseModelOutputWithPast` is returned, which contains:
|
|
- **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
The output of the final encoder layer.
|
|
- **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True`):
|
|
Hidden states of the model at each layer (including the initial projection).
|
|
- **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True`):
|
|
Attention weights from each encoder layer.
|
|
- **past_key_values** (an object of type `EncoderDecoderCache` or `None`, *optional*):
|
|
Updated cache of key-value pairs if `use_cache=True`.
|
|
|
|
If `return_dict=False`, a tuple is returned, where the format is:
|
|
`(last_hidden_state, hidden_states, attentions)`, with `hidden_states` and `attentions`
|
|
only present if their respective `output_*` arguments are set to `True`.
|
|
|
|
"""
|
|
output_attentions = (
|
|
output_attentions
|
|
if output_attentions is not None
|
|
else self.config.output_attentions
|
|
)
|
|
output_hidden_states = (
|
|
output_hidden_states
|
|
if output_hidden_states is not None
|
|
else self.config.output_hidden_states
|
|
)
|
|
return_dict = (
|
|
return_dict if return_dict is not None else self.config.use_return_dict
|
|
)
|
|
|
|
# Ignore copy
|
|
input_features = input_features.to(
|
|
dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
|
|
)
|
|
|
|
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
|
|
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
|
|
|
|
inputs_embeds = inputs_embeds.permute(0, 2, 1)
|
|
|
|
embed_pos = self.embed_positions.weight
|
|
past_key_values_length = 0
|
|
if use_cache:
|
|
if past_key_values is None:
|
|
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
|
elif isinstance(past_key_values, list):
|
|
past_key_values = EncoderDecoderCache(
|
|
DynamicCache.from_legacy_cache(past_key_values), DynamicCache()
|
|
)
|
|
elif isinstance(past_key_values, DynamicCache):
|
|
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
|
|
else:
|
|
pass
|
|
past_key_values_length = (
|
|
past_key_values.self_attention_cache.get_usable_length(
|
|
inputs_embeds.shape[1]
|
|
)
|
|
)
|
|
if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]:
|
|
logger.warning(
|
|
"seems the audio is longer than 30s. repeating the last part of the audio"
|
|
)
|
|
embed_pos_front = embed_pos[past_key_values_length:, :]
|
|
embed_pos = torch.cat(
|
|
(
|
|
embed_pos_front,
|
|
torch.repeat_interleave(
|
|
embed_pos[-1, :].unsqueeze(0),
|
|
inputs_embeds.shape[1]
|
|
- embed_pos.shape[0]
|
|
+ past_key_values_length,
|
|
dim=0,
|
|
),
|
|
)
|
|
)
|
|
else:
|
|
embed_pos = embed_pos[
|
|
past_key_values_length : inputs_embeds.shape[1]
|
|
+ past_key_values_length,
|
|
:,
|
|
]
|
|
else:
|
|
embed_pos = embed_pos[: inputs_embeds.shape[1], :]
|
|
|
|
hidden_states = inputs_embeds + embed_pos
|
|
hidden_states = nn.functional.dropout(
|
|
hidden_states, p=self.dropout, training=False
|
|
)
|
|
|
|
encoder_states = () if output_hidden_states else None
|
|
all_attentions = () if output_attentions else None
|
|
|
|
# check if head_mask has a correct number of layers specified if desired
|
|
if head_mask is not None:
|
|
assert head_mask.size()[0] == (
|
|
len(self.layers)
|
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
|
|
|
for idx, encoder_layer in enumerate(self.layers):
|
|
if output_hidden_states:
|
|
encoder_states = encoder_states + (hidden_states,)
|
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
|
to_drop = False
|
|
|
|
# Ignore copy
|
|
if to_drop:
|
|
layer_outputs = (None, None)
|
|
else:
|
|
layer_outputs = encoder_layer(
|
|
hidden_states,
|
|
attention_mask,
|
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
|
output_attentions=output_attentions,
|
|
past_key_values=past_key_values,
|
|
use_cache=use_cache,
|
|
)
|
|
|
|
hidden_states = layer_outputs[0]
|
|
|
|
if use_cache:
|
|
next_encoder_cache = layer_outputs[2 if output_attentions else 1]
|
|
else:
|
|
next_encoder_cache = None
|
|
|
|
if output_attentions:
|
|
all_attentions = all_attentions + (layer_outputs[1],)
|
|
|
|
hidden_states = self.layer_norm(hidden_states)
|
|
if output_hidden_states:
|
|
encoder_states = encoder_states + (hidden_states,)
|
|
|
|
if not return_dict:
|
|
return tuple(
|
|
v
|
|
for v in [hidden_states, encoder_states, all_attentions]
|
|
if v is not None
|
|
)
|
|
return BaseModelOutputWithPast(
|
|
last_hidden_state=hidden_states,
|
|
hidden_states=encoder_states,
|
|
attentions=all_attentions,
|
|
past_key_values=next_encoder_cache,
|
|
)
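

# Minimal sketch (not part of the upstream code): encode a long audio stream in chunks
# while carrying the encoder's self-attention cache between calls, which is what the
# added `past_key_values`/`use_cache` arguments enable. `mel_chunks` is an assumed
# iterable of log-mel tensors of shape [1, num_mel_bins, frames_per_chunk].
def _example_chunked_whisper_encoding(encoder: "MiniCPMWhisperEncoder", mel_chunks):
    past_key_values = None
    hidden_chunks = []
    for chunk in mel_chunks:
        outputs = encoder(chunk, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values  # reuse the cache for the next chunk
        hidden_chunks.append(outputs.last_hidden_state)
    return torch.cat(hidden_chunks, dim=1), past_key_values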
|
|
|
|
|
|
class MultiModalProjector(nn.Module):
|
|
def __init__(self, in_dim, out_dim):
|
|
super().__init__()
|
|
self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
|
|
self.relu = nn.ReLU()
|
|
self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)
|
|
|
|
def forward(self, audio_features):
|
|
hidden_states = self.relu(self.linear1(audio_features))
|
|
hidden_states = self.linear2(hidden_states)
|
|
return hidden_states
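

# Illustrative shapes only (the real dimensions depend on the checkpoint config):
# the projector maps pooled audio features into the LLM embedding space.
def _example_multimodal_projector():
    projector = MultiModalProjector(in_dim=1280, out_dim=3584)  # assumed dims
    audio_features = torch.randn(1, 25, 1280)  # [batch, pooled_frames, audio_dim]
    return projector(audio_features)  # -> [1, 25, 3584]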
|
|
|
|
|
|
class MiniCPMO(MiniCPMVBaseModel):
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
) -> None:
|
|
super().__init__(config=config, quant_config=quant_config)
|
|
|
|
self.llm = self.init_llm(config=config, quant_config=quant_config)
|
|
|
|
self.embed_dim = self.llm.config.hidden_size
|
|
|
|
# init vision module
|
|
if self.config.init_vision:
|
|
# print("vision-understanding enabled")
|
|
self.vpm = self.init_vision_module(config=config, quant_config=quant_config)
|
|
self.vision_dim = self.vpm.embed_dim
|
|
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
|
|
|
|
# init audio module
|
|
self.config.init_audio = True
|
|
if self.config.init_audio:
|
|
# print("audio-understanding enabled")
|
|
self.apm = self.init_audio_module()
|
|
audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
|
|
self.audio_avg_pooler = nn.AvgPool1d(
|
|
self.config.audio_pool_step, stride=self.config.audio_pool_step
|
|
)
|
|
self.audio_projection_layer = MultiModalProjector(
|
|
in_dim=audio_output_dim, out_dim=self.embed_dim
|
|
)
|
|
self.audio_encoder_layer = -1
|
|
|
|
# init tts module
|
|
self.config.init_tts = False
|
|
logger.info("TTS is disabled for now")
|
|
if self.config.init_tts:
|
|
# print("tts enabled")
|
|
assert (
|
|
_tts_deps
|
|
), "please make sure vector_quantize_pytorch and vocos are installed."
|
|
self.tts = self.init_tts_module()
|
|
|
|
def init_tts_module(self):
|
|
model = ConditionalChatTTS(self.config.tts_config)
|
|
return model
|
|
|
|
def init_audio_module(self):
|
|
model = MiniCPMWhisperEncoder(self.config.audio_config)
|
|
return model
|
|
|
|
def init_llm(
|
|
self,
|
|
config: PretrainedConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
) -> nn.Module:
|
|
return Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
|
|
|
|
def init_vision_module(
|
|
self,
|
|
config: PretrainedConfig,
|
|
quant_config: Optional[QuantizationConfig],
|
|
prefix: str = "",
|
|
):
|
|
if self.config._attn_implementation == "flash_attention_2":
|
|
self.config.vision_config._attn_implementation = "flash_attention_2"
|
|
else:
|
|
self.config.vision_config._attn_implementation = "eager"
|
|
model = Idefics2VisionTransformer(
|
|
config=config.vision_config, quant_config=quant_config, prefix=prefix
|
|
)
|
|
if self.config.drop_vision_last_layer:
|
|
model.encoder.layers = model.encoder.layers[:-1]
|
|
|
|
setattr(model, "embed_dim", model.embeddings.embed_dim)
|
|
setattr(model, "patch_size", model.embeddings.patch_size)
|
|
|
|
return model
|
|
|
|
def init_resampler(
|
|
self,
|
|
embed_dim: int,
|
|
vision_dim: int,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
) -> nn.Module:
|
|
with set_default_torch_dtype(torch.float16):
|
|
# The resampler in 2.6 remains consistent with the one in 2.5.
|
|
resampler = Resampler2_5(
|
|
num_queries=self.config.query_num,
|
|
embed_dim=embed_dim,
|
|
num_heads=embed_dim // 128,
|
|
kv_dim=vision_dim,
|
|
quant_config=quant_config,
|
|
prefix=prefix,
|
|
)
|
|
|
|
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
|
|
|
|
def pad_input_ids(self, input_ids: List[int], mm_input: MultimodalInputs):
|
|
# Get all special token IDs
|
|
im_start_id: int = mm_input.im_start_id
|
|
im_end_id: int = mm_input.im_end_id
|
|
slice_start_id: int = mm_input.slice_start_id
|
|
slice_end_id: int = mm_input.slice_end_id
|
|
|
|
media_token_pairs = [
|
|
(im_start_id, im_end_id),
|
|
(slice_start_id, slice_end_id),
|
|
(mm_input.audio_start_id, mm_input.audio_end_id),
|
|
]
|
|
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
|
|
|
return pattern.pad_input_tokens(input_ids, mm_input)
|
|
|
|
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
|
"""
|
|
Computes the output length of the convolutional layers and the output length of the audio encoder
|
|
"""
|
|
input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
|
|
input_lengths_after_pooling = (
|
|
input_lengths_after_cnn - self.config.audio_pool_step
|
|
) // self.config.audio_pool_step + 1
|
|
input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)
|
|
|
|
return input_lengths_after_cnn, input_lengths_after_pooling
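
    # Worked example (illustrative, assuming audio_pool_step = 2): 100 mel frames give
    # (100 - 1) // 2 + 1 = 50 frames after the conv front-end, and then
    # (50 - 2) // 2 + 1 = 25 audio tokens after average pooling.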
|
|
|
|
def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
|
|
r"""
|
|
Extract audio embeddings in a streaming manner using cached key-value pairs.
|
|
|
|
This method processes incoming audio features incrementally and stores/updates `past_key_values`
|
|
for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
|
|
for streaming scenarios.
|
|
|
|
Args:
|
|
multimodal_input (dict):
|
|
- **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
|
|
- **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
|
|
|
|
Returns:
|
|
List[List[torch.Tensor]]: audio embeddings
|
|
"""
|
|
# print("audio embedding")
|
|
|
|
wavforms = (
|
|
[]
|
|
if multimodal_input.audio_features is None
|
|
else multimodal_input.audio_features
|
|
)
|
|
# list, [[x1, x2], [y1], [z1]]
|
|
audio_feature_lens_raw = (
|
|
[]
|
|
if multimodal_input.audio_feature_lens is None
|
|
else multimodal_input.audio_feature_lens
|
|
)
|
|
|
|
# exist audio
|
|
if len(wavforms) > 0:
|
|
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
|
|
batch_size, _, max_mel_seq_len = wavforms.shape
|
|
assert batch_size == 1
|
|
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
|
|
|
|
if self.audio_past_key_values is not None:
|
|
cache_length = self.audio_past_key_values[0][0].shape[2]
|
|
apm_max_len = self.apm.embed_positions.weight.shape[0]
|
|
if cache_length + max_seq_len >= apm_max_len:
|
|
logger.warning(
|
|
f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset."
|
|
)
|
|
self.audio_past_key_values = None
|
|
|
|
audio_outputs = self.apm(
|
|
wavforms, past_key_values=self.audio_past_key_values, use_cache=True
|
|
)
|
|
audio_states = (
|
|
audio_outputs.last_hidden_state
|
|
) # [:, :audio_feat_lengths, :]
|
|
self.audio_past_key_values = audio_outputs.past_key_values
|
|
|
|
audio_embeds = self.audio_projection_layer(audio_states)

            audio_embeds = audio_embeds.transpose(1, 2)
            audio_embeds = self.audio_avg_pooler(audio_embeds)
            audio_embeds = audio_embeds.transpose(1, 2)

            _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
                audio_feature_lens
            )

            num_audio_tokens = feature_lens_after_pooling

            final_audio_embeds = []
            idx = 0
            for i in range(len(audio_feature_lens_raw)):
                target_audio_embeds = []
                for _ in range(len(audio_feature_lens_raw[i])):
                    target_audio_embeds.append(
                        audio_embeds[idx, : num_audio_tokens[idx], :]
                    )
                    idx += 1
                final_audio_embeds.append(target_audio_embeds)
            return final_audio_embeds
        else:
            return []

    def subsequent_chunk_mask(
        self,
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
        num_lookhead: int = 0,
    ) -> torch.Tensor:
        """Create a mask for subsequent steps (size, size) with a given chunk size;
        used by the streaming encoder.

        Args:
            size (int): size of mask
            chunk_size (int): size of chunk
            num_left_chunks (int): number of left chunks
                <0: use full chunk
                >=0: use num_left_chunks
            device (torch.device): "cpu" or "cuda" or torch.Tensor.device
            num_lookhead (int): number of extra frames visible beyond the end of the
                current chunk (spelling follows the code)

        Returns:
            torch.Tensor: mask
        """
        ret = torch.zeros(size, size, device=device, dtype=torch.bool)
        for i in range(size):
            if num_left_chunks < 0:
                start = 0
            else:
                start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
            ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size)
            ret[i, start:ending] = True
        return ret

    def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
        r"""
        Extract full audio embeddings with optional chunk-based attention.

        This method computes embeddings for all audio frames at once, either using full attention (when
        `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does
        not use key-value caching and is suitable for non-streaming inference.

        Args:
            multimodal_input (MultimodalInputs):
                - **audio_features** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
                - **audio_feature_lens** (List[List[int]]): Lengths of each audio segment for each item in the batch.
            chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
                attention (>0) during embedding computation.

        Returns:
            List[List[torch.Tensor]]: audio embeddings
        """
        # print("audio embedding")
        # (bs, 80, frames) or []; multiple audio clips must already be padded into a batch
        wavforms = (
            []
            if multimodal_input.audio_features is None
            else multimodal_input.audio_features
        )
        # list, [[x1, x2], [y1], [z1]]
        audio_feature_lens_raw = (
            []
            if multimodal_input.audio_feature_lens is None
            else multimodal_input.audio_feature_lens
        )

        final_audio_embeds = []

        # audio inputs are present
        for wavform in wavforms:
            if len(wavform) > 0:
                audio_feature_lens = torch.hstack(audio_feature_lens_raw)
                batch_size, _, max_mel_seq_len = wavform.shape
                max_seq_len = (max_mel_seq_len - 1) // 2 + 1

                # Create a sequence tensor of shape (batch_size, max_seq_len)
                seq_range = (
                    torch.arange(
                        0,
                        max_seq_len,
                        dtype=audio_feature_lens.dtype,
                        device=audio_feature_lens.device,
                    )
                    .unsqueeze(0)
                    .expand(batch_size, max_seq_len)
                )
                lengths_expand = audio_feature_lens.unsqueeze(1).expand(
                    batch_size, max_seq_len
                )
                # Create mask
                padding_mask = seq_range >= lengths_expand  # True at padded positions

                audio_attention_mask_ = padding_mask.view(
                    batch_size, 1, 1, max_seq_len
                ).expand(batch_size, 1, max_seq_len, max_seq_len)
                audio_attention_mask = audio_attention_mask_.to(
                    dtype=self.apm.conv1.weight.dtype,
                    device=self.apm.conv1.weight.device,
                )

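                # chunk_length is in seconds; with 100 mel frames per second and the
                # stride-2 conv front-end, one second corresponds to roughly 50
                # encoder frames, hence the factor of 50 below (an interpretation of
                # the upstream constant, not a configurable value here).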
                if chunk_length > 0:
                    chunk_num_frame = int(chunk_length * 50)
                    chunk_mask = self.subsequent_chunk_mask(
                        size=max_seq_len,
                        chunk_size=chunk_num_frame,
                        num_left_chunks=-1,
                        device=audio_attention_mask_.device,
                    )
                    audio_attention_mask_ = torch.logical_or(
                        audio_attention_mask_, torch.logical_not(chunk_mask)
                    )

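                # Convert the boolean "masked" positions into an additive bias of
                # -inf so they receive zero attention weight after the softmax.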
                audio_attention_mask[audio_attention_mask_] = float("-inf")
                audio_states = self.apm(
                    wavform,
                    output_hidden_states=True,
                    attention_mask=audio_attention_mask,
                ).hidden_states[self.audio_encoder_layer]
                audio_embeds = self.audio_projection_layer(audio_states)

                audio_embeds = audio_embeds.transpose(1, 2)
                audio_embeds = self.audio_avg_pooler(audio_embeds)
                audio_embeds = audio_embeds.transpose(1, 2)

                _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
                    audio_feature_lens
                )

                num_audio_tokens = feature_lens_after_pooling

                idx = 0
                for i in range(len(audio_feature_lens_raw)):
                    target_audio_embeds = []
                    for _ in range(len(audio_feature_lens_raw[i])):
                        target_audio_embeds.append(
                            audio_embeds[idx, : num_audio_tokens[idx], :]
                        )
                        idx += 1
                    final_audio_embeds.append(target_audio_embeds)
        return final_audio_embeds

    def get_omni_embedding(
        self,
        input_ids,
        multimodal_input: MultimodalInputs,
        input_embeds: torch.Tensor,
        forward_mode: ForwardMode,
        chunk_length=-1,
        stream_input=False,
    ):
        """
        Args:
            multimodal_input: multimodal inputs carrying the audio features and token ids
            input_embeds: text token embeddings to splice the audio embeddings into
            chunk_length: -1 for full attention in the Whisper encoder, >0 for chunk attention
            stream_input: use streaming audio embedding
        Returns:
            final embeddings with audio feature
        """
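        # Embeddings arrive flattened as (num_tokens, hidden); a leading batch
        # dimension is added here (and removed again before returning) so the
        # per-sample bookkeeping below can index a batch dimension.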
        input_embeds = input_embeds.unsqueeze(0)
        if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
            audio_bounds = get_multimodal_data_bounds(
                input_ids=input_ids,
                pad_values=multimodal_input.pad_values,
                token_pairs=[
                    (multimodal_input.audio_start_id, multimodal_input.audio_end_id)
                ],
            )
            if audio_bounds.numel() == 0:
                input_embeds = input_embeds.squeeze(0)
                # TODO
                logger.warning(
                    "Unimplemented logic. Please try disabling chunked prefill"
                )
                return input_embeds
            audio_bounds = audio_bounds.unsqueeze(0)
            bs = len(input_embeds)

            if stream_input:
                audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
            else:
                audio_embeddings = self.get_audio_embedding(
                    multimodal_input, chunk_length
                )
            # batch size
            assert len(audio_embeddings) == len(input_embeds)
            if len(audio_embeddings) > 0:
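                # Each bound marks the token span reserved for one audio segment in
                # input_ids; the pooled audio embeddings are copied into those slots
                # in order, overwriting the placeholder embeddings.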
                if self.config.chunk_input:
                    for i in range(bs):
                        audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
                            device=input_embeds.device, dtype=input_embeds.dtype
                        )
                        audio_start_pos = 0
                        for bound in audio_bounds[i]:
                            audio_len = bound[1] - bound[0] + 1
                            input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
                                audio_start_pos : audio_start_pos + audio_len, :
                            ]
                            audio_start_pos += audio_len
                else:
                    for i in range(bs):
                        audio_embs = audio_embeddings[i]
                        bounds = audio_bounds[i]
                        for embs, bound in zip(audio_embs, bounds):
                            audio_indices = torch.arange(
                                bound[0], bound[1], dtype=torch.long
                            ).to(input_embeds.device)

                            if embs.shape[0] != len(audio_indices):
                                raise ValueError(
                                    f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
                                    f"to input indices of length {len(audio_indices)}"
                                )
                            input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
        input_embeds = input_embeds.squeeze(0)
        return input_embeds

    def get_image_features(
        self,
        image_inputs: MultimodalInputs,
    ) -> torch.Tensor:
        pixel_values = image_inputs.pixel_values
        tgt_sizes = image_inputs.tgt_sizes
        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)
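        # Images in the batch have different patch counts (tgt_sizes gives each
        # image's patch grid); the flattened patch sequences are padded to
        # max_patches and patch_attn_mask marks which positions are real patches.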

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0
        )
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
        patch_attn_mask = torch.zeros(
            (B, 1, max_patches), dtype=torch.bool, device=device
        )

        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
        patch_attn_mask[:, 0, :] = torch.arange(
            patch_attn_mask.size(2), device=patch_attn_mask.device
        ).unsqueeze(0) < mask_shapes.unsqueeze(1)

        vision_embedding = self.vpm(
            all_pixel_values.type(dtype),
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return self.resampler(vision_embedding, tgt_sizes)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: Any,
    ) -> torch.Tensor:
        inputs_embeds = None
        # TODO(mick): optimize the logic here: clamp, merge and embedding should happen at most once
        if (
            not forward_batch.forward_mode.is_decode()
            and forward_batch.contains_image_inputs()
        ):
            mm_inputs = forward_batch.merge_mm_inputs()
            inputs_embeds = embed_mm_inputs(
                mm_input=mm_inputs,
                input_ids=input_ids,
                input_embedding=self.get_input_embeddings(),
                mm_data_embedding_func=self.get_image_features,
                placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
            )

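        # Placeholder/pad token ids inserted for multimodal data can lie outside the
        # text vocabulary, so clamp input_ids into range before any embedding lookup.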
        input_ids = input_ids.clamp(
            min=0, max=self.get_input_embeddings().num_embeddings - 1
        )
        if inputs_embeds is None:
            inputs_embeds = self.llm.get_input_embeddings(input_ids)
        if (
            not forward_batch.forward_mode.is_decode()
            and self.config.init_audio
            and forward_batch.contains_audio_inputs()
        ):
            mm_input = forward_batch.merge_mm_inputs()
            inputs_embeds = self.get_omni_embedding(
                input_ids=input_ids,
                multimodal_input=mm_input,
                input_embeds=inputs_embeds,
                forward_mode=forward_batch.forward_mode,
                chunk_length=self.config.audio_chunk_length,
                stream_input=False,
            )

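        # The multimodal inputs have been folded into inputs_embeds above; clearing
        # them here (presumably) avoids re-processing and lets the data be freed.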
        forward_batch.mm_inputs = None

        hidden_states = self.llm.model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.llm.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
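        # The separate q/k/v and gate/up checkpoint tensors are merged into the fused
        # qkv_proj / gate_up_proj parameters; each parameter's shard-aware
        # weight_loader places the shard identified by shard_id.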

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:

            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
|
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
|
# Models trained using ColossalAI may include these tensors in
|
|
# the checkpoint. Skip them.
|
|
continue
|
|
|
|
# adapt to parametrization
|
|
if self.config.init_tts and "tts" in name:
|
|
name = name.replace(".parametrizations", "")
|
|
name = name.replace(".weight.original0", ".weight_g")
|
|
name = name.replace(".weight.original1", ".weight_v")

            # adapt to VisionAttention
            if "vpm" in name:
                name = name.replace(r"self_attn.out_proj", r"self_attn.proj")

            if not self.config.init_tts and "tts" in name:
                continue
            if not self.config.init_audio and ("apm" in name or "audio" in name):
                continue
            if not self.config.init_vision and "vpm" in name:
                continue

            if (
                "sampler" in name
                or "apm" in name
                or ("tts" in name and "self_attn" in name)
                or ("tts.model.layers" in name and ".mlp" in name)
            ):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # replace the name and load with customized loader
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = [MiniCPMO]