import io
import os
import wave
import asyncio
import threading
from math import gcd

import numpy as np
import sherpa_onnx

from config.logger import setup_logging
from core.providers.tts.base import TTSProviderBase

TAG = __name__
logger = setup_logging()


class TTSProvider(TTSProviderBase):
    """Offline text-to-speech provider backed by a sherpa-onnx VITS model.

    Configuration keys read from ``config``:
        model_dir: Directory holding the model, ``lexicon.txt``,
            ``tokens.txt``, ``dict/`` and the rule FST files
            (default ``"models/vits-melo-tts-zh_en"``).
        speed: Speech speed passed to the engine; falsy values fall back
            to ``1.0``.
        sid: Speaker id passed to the engine (default ``0``).
        num_threads: Number of threads given to the model config
            (default ``8``).
    """

    # Sample rate of the WAV data produced by this provider. The original
    # note states that the device requires 24000 Hz.
    TARGET_SAMPLE_RATE = 24000

    def __init__(self, config, delete_audio_file):
        """Load the model and build the sherpa-onnx TTS engine.

        Args:
            config: Mapping with the keys described on the class.
            delete_audio_file: Forwarded unchanged to ``TTSProviderBase``.
        """
        super().__init__(config, delete_audio_file)

        model_dir = config.get("model_dir", "models/vits-melo-tts-zh_en")
        speed = config.get("speed", 1.0)
        self.speed = float(speed) if speed else 1.0
        self.sid = int(config.get("sid", 0))

        # Prefer the int8-quantized model (faster). Fall back to the full
        # model when the int8 file is missing or smaller than 1 KiB
        # (presumably a placeholder rather than a real model -- confirm).
        model_file = f"{model_dir}/model.int8.onnx"
        if not os.path.exists(model_file) or os.path.getsize(model_file) < 1024:
            model_file = f"{model_dir}/model.onnx"

        num_threads = int(config.get("num_threads", 8))

        tts_config = sherpa_onnx.OfflineTtsConfig(
            model=sherpa_onnx.OfflineTtsModelConfig(
                vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                    model=model_file,
                    lexicon=f"{model_dir}/lexicon.txt",
                    tokens=f"{model_dir}/tokens.txt",
                    dict_dir=f"{model_dir}/dict",
                ),
                num_threads=num_threads,
            ),
            rule_fsts=f"{model_dir}/date.fst,{model_dir}/phone.fst,{model_dir}/number.fst,{model_dir}/new_heteronym.fst",
            max_num_sentences=1,
        )
        self.tts = sherpa_onnx.OfflineTts(tts_config)
        self.sample_rate = self.tts.sample_rate

        # Synthesis now runs on worker threads (see text_to_speak). This
        # lock keeps engine calls serialized, exactly as they were when the
        # synthesis ran inline on the event loop.
        self._generate_lock = threading.Lock()

        logger.bind(tag=TAG).info(
            f"SherpaOnnxTTS 初始化完成: model_dir={model_dir}, sample_rate={self.sample_rate}, sid={self.sid}"
        )

    def _generate_wav(self, text):
        """Synthesize ``text`` synchronously and return WAV file bytes.

        This is blocking, CPU-bound work and is meant to be called from a
        worker thread rather than directly on the event loop.

        Args:
            text: Text to synthesize.

        Returns:
            bytes: A mono, 16-bit PCM WAV file at ``TARGET_SAMPLE_RATE``.
        """
        # Imported lazily, as in the original, so that importing this module
        # does not require scipy.
        from scipy.signal import resample_poly

        with self._generate_lock:
            audio = self.tts.generate(text, sid=self.sid, speed=self.speed)
        samples = np.array(audio.samples, dtype=np.float32)

        # Resample to the target rate when the model rate differs (the
        # original note says the model outputs 44100 Hz). Empty output is
        # left alone so an empty synthesis yields a valid, empty WAV.
        target_sr = self.TARGET_SAMPLE_RATE
        if samples.size and self.sample_rate != target_sr:
            g = gcd(self.sample_rate, target_sr)
            samples = resample_poly(samples, target_sr // g, self.sample_rate // g)

        # Clip before scaling: values outside [-1, 1] (from the model or
        # from resampling filter overshoot) would otherwise wrap around in
        # int16 and be heard as loud clicks.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

        wav_io = io.BytesIO()
        with wave.open(wav_io, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(target_sr)
            wf.writeframes(pcm.tobytes())
        return wav_io.getvalue()

    async def text_to_speak(self, text, output_file):
        """Synthesize ``text`` without blocking the event loop.

        Args:
            text: Text to synthesize.
            output_file: Path to write the WAV data to. When falsy, the WAV
                bytes are returned instead of being written.

        Returns:
            bytes | None: The WAV bytes when ``output_file`` is falsy,
            otherwise ``None`` after the file has been written.
        """
        # Run the blocking synthesis on the default executor so other
        # coroutines keep running while audio is generated.
        loop = asyncio.get_running_loop()
        wav_data = await loop.run_in_executor(None, self._generate_wav, text)

        if not output_file:
            return wav_data

        with open(output_file, "wb") as f:
            f.write(wav_data)
        return None