"""Offline text-to-speech provider backed by a sherpa-onnx VITS model."""

import io
import wave
import numpy as np
import sherpa_onnx
from config.logger import setup_logging
from core.providers.tts.base import TTSProviderBase

TAG = __name__
logger = setup_logging()

# Rule FST files looked up inside the model directory, in the order they are
# passed to sherpa-onnx.
_RULE_FST_NAMES = ("date.fst", "phone.fst", "number.fst", "new_heteronym.fst")


def _float_samples_to_wav(samples, sample_rate):
    """Encode float audio samples as a mono 16-bit PCM WAV byte string.

    Args:
        samples: Sequence of float samples, treated as normalized to [-1, 1].
        sample_rate: Sample rate in Hz written to the WAV header.

    Returns:
        The complete WAV file content as ``bytes``.
    """
    floats = np.asarray(samples, dtype=np.float32)
    # Clip before scaling: a value even slightly outside [-1, 1] would fall
    # outside the int16 range and wrap around on the cast, producing loud
    # clicks in the output audio.
    pcm = (np.clip(floats, -1.0, 1.0) * 32767).astype(np.int16)

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)  # mono
        wf.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()


class TTSProvider(TTSProviderBase):
    """TTS provider that synthesizes speech locally with ``sherpa_onnx.OfflineTts``."""

    def __init__(self, config, delete_audio_file):
        """Load the VITS model described by ``config``.

        Args:
            config: Mapping with optional keys ``model_dir`` (directory holding
                ``model.onnx``, ``lexicon.txt``, ``tokens.txt``, ``dict`` and
                the rule FSTs), ``speed`` (float, default 1.0) and ``sid``
                (speaker id, default 0).
            delete_audio_file: Forwarded unchanged to ``TTSProviderBase``.
        """
        super().__init__(config, delete_audio_file)
        model_dir = config.get("model_dir", "models/vits-melo-tts-zh_en")

        # Blank / None values in the config fall back to the defaults instead
        # of raising during conversion.
        speed = config.get("speed", 1.0)
        self.speed = float(speed) if speed else 1.0
        sid = config.get("sid", 0)
        self.sid = int(sid) if sid else 0

        rule_fsts = ",".join(f"{model_dir}/{name}" for name in _RULE_FST_NAMES)
        tts_config = sherpa_onnx.OfflineTtsConfig(
            model=sherpa_onnx.OfflineTtsModelConfig(
                vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                    model=f"{model_dir}/model.onnx",
                    lexicon=f"{model_dir}/lexicon.txt",
                    tokens=f"{model_dir}/tokens.txt",
                    dict_dir=f"{model_dir}/dict",
                ),
            ),
            rule_fsts=rule_fsts,
            max_num_sentences=1,
        )
        self.tts = sherpa_onnx.OfflineTts(tts_config)
        self.sample_rate = self.tts.sample_rate
        logger.bind(tag=TAG).info(
            f"SherpaOnnxTTS 初始化完成: model_dir={model_dir}, sample_rate={self.sample_rate}, sid={self.sid}"
        )

    async def text_to_speak(self, text, output_file):
        """Synthesize ``text`` and deliver it as WAV audio.

        Args:
            text: Text to synthesize.
            output_file: Path to write the WAV data to. When falsy, nothing is
                written and the WAV bytes are returned instead.

        Returns:
            ``None`` when ``output_file`` is given, otherwise the WAV content
            as ``bytes``.
        """
        # NOTE(review): generate() is a synchronous call inside a coroutine;
        # confirm the caller does not run this on a shared event loop.
        audio = self.tts.generate(text, sid=self.sid, speed=self.speed)
        wav_data = _float_samples_to_wav(audio.samples, self.sample_rate)

        if output_file:
            with open(output_file, "wb") as f:
                f.write(wav_data)
            return None
        return wav_data