taixf/modules/tts/sherpa_tts.py

92 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import os
import wave
import asyncio
import numpy as np
import sherpa_onnx
from config.logger import setup_logging
from core.providers.tts.base import TTSProviderBase
# Tag attached to this module's log records (bound via logger.bind(tag=TAG)).
TAG = __name__
# Module-level logger obtained from the project's setup_logging() helper.
logger = setup_logging()
class TTSProvider(TTSProviderBase):
    """Offline text-to-speech provider backed by a sherpa-onnx VITS model.

    The model is loaded once in ``__init__``; every synthesis request is
    rendered to an in-memory 16-bit mono WAV at ``TARGET_SAMPLE_RATE``.
    """

    # Sample rate (Hz) of the WAV handed back to callers. Per the original
    # author's note the device expects 24000 Hz while the model emits 44100 Hz.
    TARGET_SAMPLE_RATE = 24000

    def __init__(self, config, delete_audio_file):
        """Load the VITS model described by ``config``.

        Args:
            config: Mapping of provider settings. Keys read here:
                ``model_dir`` (default ``"models/vits-melo-tts-zh_en"``),
                ``speed`` (default 1.0; falsy values fall back to 1.0),
                ``sid`` (speaker id, default 0) and
                ``num_threads`` (default 8).
            delete_audio_file: Forwarded unchanged to the base class.
        """
        super().__init__(config, delete_audio_file)
        model_dir = config.get("model_dir", "models/vits-melo-tts-zh_en")
        speed = config.get("speed", 1.0)
        self.speed = float(speed) if speed else 1.0
        self.sid = int(config.get("sid", 0))

        # Prefer the int8-quantized model (faster). Fall back to the full
        # model when the quantized file is missing or smaller than 1 KiB
        # (too small to be a usable model file).
        model_file = f"{model_dir}/model.int8.onnx"
        if not os.path.exists(model_file) or os.path.getsize(model_file) < 1024:
            model_file = f"{model_dir}/model.onnx"

        num_threads = int(config.get("num_threads", 8))
        tts_config = sherpa_onnx.OfflineTtsConfig(
            model=sherpa_onnx.OfflineTtsModelConfig(
                vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                    model=model_file,
                    lexicon=f"{model_dir}/lexicon.txt",
                    tokens=f"{model_dir}/tokens.txt",
                    dict_dir=f"{model_dir}/dict",
                ),
                num_threads=num_threads,
            ),
            rule_fsts=f"{model_dir}/date.fst,{model_dir}/phone.fst,{model_dir}/number.fst,{model_dir}/new_heteronym.fst",
            max_num_sentences=1,
        )
        self.tts = sherpa_onnx.OfflineTts(tts_config)
        # Native sample rate reported by the loaded model.
        self.sample_rate = self.tts.sample_rate
        logger.bind(tag=TAG).info(
            f"SherpaOnnxTTS 初始化完成: model_dir={model_dir}, sample_rate={self.sample_rate}, sid={self.sid}"
        )

    def _generate_wav(self, text):
        """Synthesize ``text`` synchronously and return it as WAV bytes.

        The model output is resampled to ``TARGET_SAMPLE_RATE`` when the
        model's native rate differs, converted to 16-bit PCM and wrapped in
        a mono WAV container.

        Args:
            text: The text to synthesize.

        Returns:
            bytes: A complete WAV file (1 channel, 16-bit,
            ``TARGET_SAMPLE_RATE`` Hz).
        """
        # Kept function-local, as in the original, so scipy is only imported
        # when synthesis is actually requested.
        import time
        from math import gcd

        from scipy.signal import resample_poly

        logger.bind(tag=TAG).info(f"TTS 收到文字: [{text}]")

        # perf_counter is monotonic, so the stage timings below cannot be
        # skewed by wall-clock adjustments.
        t0 = time.perf_counter()
        audio = self.tts.generate(text, sid=self.sid, speed=self.speed)
        samples = np.array(audio.samples, dtype=np.float32)
        t1 = time.perf_counter()

        # Resample to the target rate. Empty output is skipped: there is
        # nothing to resample and a valid empty WAV is still produced below.
        target_sr = self.TARGET_SAMPLE_RATE
        if samples.size and self.sample_rate != target_sr:
            g = gcd(self.sample_rate, target_sr)
            samples = resample_poly(samples, target_sr // g, self.sample_rate // g)
        t2 = time.perf_counter()

        # Clip before the int16 conversion: values outside [-1, 1] (model
        # peaks or resampling-filter overshoot) would otherwise wrap around
        # and be heard as loud clicks.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
        audio_duration = len(pcm) / target_sr

        wav_io = io.BytesIO()
        with wave.open(wav_io, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(target_sr)
            wf.writeframes(pcm.tobytes())
        wav_data = wav_io.getvalue()

        logger.bind(tag=TAG).info(
            f"TTS 完成: 合成={t1-t0:.2f}s 重采样={t2-t1:.2f}s 音频时长={audio_duration:.1f}s 大小={len(wav_data)}B [{text[:20]}]"
        )
        return wav_data

    async def text_to_speak(self, text, output_file):
        """Synthesize ``text`` and either write it to a file or return it.

        Args:
            text: The text to synthesize.
            output_file: Destination path for the WAV data. When falsy, the
                WAV bytes are returned instead of being written.

        Returns:
            bytes | None: The WAV bytes when ``output_file`` is falsy,
            otherwise ``None``.
        """
        # NOTE(review): synthesis runs synchronously inside this coroutine and
        # therefore blocks the running event loop for its whole duration.
        # Offloading it to a worker thread is a candidate improvement, but it
        # needs confirmation that self.tts.generate is safe to call from
        # another thread.
        wav_data = self._generate_wav(text)
        if not output_file:
            return wav_data
        with open(output_file, "wb") as f:
            f.write(wav_data)
        return None