taixf/modules/tts/sherpa_tts.py

80 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import os
import wave
import asyncio
import numpy as np
import sherpa_onnx
from config.logger import setup_logging
from core.providers.tts.base import TTSProviderBase
# Module name, passed as the `tag` field when binding the logger below.
TAG = __name__
# Project logger obtained from config.logger.setup_logging().
logger = setup_logging()
class TTSProvider(TTSProviderBase):
    """Offline text-to-speech provider backed by a sherpa-onnx VITS model.

    Audio is produced as a mono, 16-bit PCM WAV resampled to
    ``TARGET_SAMPLE_RATE``.
    """

    # Sample rate of the generated WAV. Per the original author's note the
    # playback device requires 24000 Hz while the model emits 44100 Hz.
    TARGET_SAMPLE_RATE = 24000

    def __init__(self, config, delete_audio_file):
        """Load the VITS model described by *config*.

        Args:
            config: Mapping of provider settings. Keys read here:
                ``model_dir`` (default ``"models/vits-melo-tts-zh_en"``),
                ``speed`` (default ``1.0``; falsy values fall back to 1.0),
                ``sid`` (speaker id, default ``0``) and
                ``num_threads`` (default ``8``).
            delete_audio_file: Forwarded unchanged to ``TTSProviderBase``.
        """
        import threading

        super().__init__(config, delete_audio_file)
        model_dir = config.get("model_dir", "models/vits-melo-tts-zh_en")
        speed = config.get("speed", 1.0)
        self.speed = float(speed) if speed else 1.0
        self.sid = int(config.get("sid", 0))

        # Prefer the int8-quantised model (faster). Fall back to the full
        # model when the int8 file is missing or smaller than 1 KiB
        # (presumably a placeholder rather than real weights).
        model_file = f"{model_dir}/model.int8.onnx"
        if not os.path.exists(model_file) or os.path.getsize(model_file) < 1024:
            model_file = f"{model_dir}/model.onnx"

        num_threads = int(config.get("num_threads", 8))
        tts_config = sherpa_onnx.OfflineTtsConfig(
            model=sherpa_onnx.OfflineTtsModelConfig(
                vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                    model=model_file,
                    lexicon=f"{model_dir}/lexicon.txt",
                    tokens=f"{model_dir}/tokens.txt",
                    dict_dir=f"{model_dir}/dict",
                ),
                num_threads=num_threads,
            ),
            rule_fsts=f"{model_dir}/date.fst,{model_dir}/phone.fst,{model_dir}/number.fst,{model_dir}/new_heteronym.fst",
            max_num_sentences=1,
        )
        self.tts = sherpa_onnx.OfflineTts(tts_config)
        self.sample_rate = self.tts.sample_rate

        # Synthesis runs on worker threads (see text_to_speak); the lock keeps
        # calls into the shared engine serialized, as they were when synthesis
        # ran inline on the event loop.
        self._synth_lock = threading.Lock()

        logger.bind(tag=TAG).info(
            f"SherpaOnnxTTS 初始化完成: model_dir={model_dir}, sample_rate={self.sample_rate}, sid={self.sid}"
        )

    def _generate_wav(self, text):
        """Synthesise *text* synchronously and return the WAV file bytes.

        This call blocks for the duration of synthesis and is meant to be run
        on a worker thread.

        Args:
            text: Text to synthesise.

        Returns:
            bytes: A complete WAV file (mono, 16-bit PCM) at
            ``TARGET_SAMPLE_RATE``.
        """
        from math import gcd

        with self._synth_lock:
            audio = self.tts.generate(text, sid=self.sid, speed=self.speed)
        samples = np.asarray(audio.samples, dtype=np.float32)

        # Resample from the model's rate to the rate required by the device.
        target_sr = self.TARGET_SAMPLE_RATE
        if samples.size and self.sample_rate != target_sr:
            # Imported lazily: scipy is only needed when the rates differ.
            from scipy.signal import resample_poly

            g = gcd(self.sample_rate, target_sr)
            samples = resample_poly(samples, target_sr // g, self.sample_rate // g)

        # Clip before the int16 conversion: values outside [-1, 1] (model
        # peaks or resampling overshoot) would otherwise wrap around and be
        # heard as loud clicks.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

        wav_io = io.BytesIO()
        with wave.open(wav_io, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(target_sr)
            wf.writeframes(pcm.tobytes())
        return wav_io.getvalue()

    async def text_to_speak(self, text, output_file):
        """Synthesise *text* without blocking the event loop.

        Args:
            text: Text to synthesise.
            output_file: Path to write the WAV to. When falsy, the WAV bytes
                are returned instead of being written.

        Returns:
            bytes | None: The WAV data when *output_file* is falsy, otherwise
            ``None`` after the file has been written.
        """
        # Synthesis is CPU-bound and can take a long time; run it in the
        # default executor so other coroutines keep running meanwhile.
        loop = asyncio.get_running_loop()
        wav_data = await loop.run_in_executor(None, self._generate_wav, text)

        if output_file:
            with open(output_file, "wb") as f:
                f.write(wav_data)
            return None
        return wav_data