diff --git a/backend/main/xiaozhi-server/core/providers/tts/sherpa_tts.py b/backend/main/xiaozhi-server/core/providers/tts/sherpa_tts.py index 04b2e5c..a76cdc4 100644 --- a/backend/main/xiaozhi-server/core/providers/tts/sherpa_tts.py +++ b/backend/main/xiaozhi-server/core/providers/tts/sherpa_tts.py @@ -1,5 +1,7 @@ import io +import os import wave +import asyncio import numpy as np import sherpa_onnx from config.logger import setup_logging @@ -19,8 +21,7 @@ class TTSProvider(TTSProviderBase): # 优先使用 int8 量化模型(更快) model_file = f"{model_dir}/model.int8.onnx" - import os - if not os.path.exists(model_file): + if not os.path.exists(model_file) or os.path.getsize(model_file) < 1024: model_file = f"{model_dir}/model.onnx" num_threads = int(config.get("num_threads", 8)) @@ -44,7 +45,8 @@ class TTSProvider(TTSProviderBase): f"SherpaOnnxTTS 初始化完成: model_dir={model_dir}, sample_rate={self.sample_rate}, sid={self.sid}" ) - async def text_to_speak(self, text, output_file): + def _generate_wav(self, text): + """同步合成,在线程池中调用""" audio = self.tts.generate(text, sid=self.sid, speed=self.speed) samples = np.array(audio.samples, dtype=np.float32) pcm = (samples * 32767).astype(np.int16) @@ -55,7 +57,11 @@ class TTSProvider(TTSProviderBase): wf.setsampwidth(2) wf.setframerate(self.sample_rate) wf.writeframes(pcm.tobytes()) - wav_data = wav_io.getvalue() + return wav_io.getvalue() + + async def text_to_speak(self, text, output_file): + loop = asyncio.get_event_loop() + wav_data = await loop.run_in_executor(None, self._generate_wav, text) if output_file: with open(output_file, "wb") as f: