80 lines
2.8 KiB
Python
80 lines
2.8 KiB
Python
import io
|
||
import os
|
||
import wave
|
||
import asyncio
|
||
import numpy as np
|
||
import sherpa_onnx
|
||
from config.logger import setup_logging
|
||
from core.providers.tts.base import TTSProviderBase
|
||
|
||
# Tag attached to this module's log records via logger.bind(tag=TAG).
TAG = __name__
# Project-wide logger obtained from config.logger.setup_logging().
logger = setup_logging()
|
||
|
||
|
||
class TTSProvider(TTSProviderBase):
    """Offline text-to-speech provider backed by a sherpa-onnx VITS model.

    Synthesized audio is resampled to ``TARGET_SAMPLE_RATE`` and produced as
    a mono, 16-bit PCM WAV file (returned as bytes or written to disk).
    """

    # Output sample rate required by the device. The model may emit a
    # different native rate (e.g. 44100 Hz); _generate_wav resamples to this.
    TARGET_SAMPLE_RATE = 24000

    def __init__(self, config, delete_audio_file):
        """Load the sherpa-onnx VITS model described by ``config``.

        Args:
            config: Mapping with optional keys ``model_dir`` (default
                ``"models/vits-melo-tts-zh_en"``), ``speed`` (default 1.0),
                ``sid`` (speaker id, default 0) and ``num_threads``
                (default 8).
            delete_audio_file: Forwarded unchanged to ``TTSProviderBase``.
        """
        import threading

        super().__init__(config, delete_audio_file)
        model_dir = config.get("model_dir", "models/vits-melo-tts-zh_en")

        speed = config.get("speed", 1.0)
        self.speed = float(speed) if speed else 1.0
        # A zero or negative speed is not a meaningful playback rate (and a
        # truthy string such as "0" slips past the check above); use 1.0.
        if self.speed <= 0:
            self.speed = 1.0
        self.sid = int(config.get("sid", 0))

        # Prefer the int8-quantized model (faster). Fall back to the
        # full-precision model when it is missing or suspiciously small
        # (< 1 KiB; presumably an un-fetched Git LFS pointer — TODO confirm).
        model_file = f"{model_dir}/model.int8.onnx"
        if not os.path.exists(model_file) or os.path.getsize(model_file) < 1024:
            model_file = f"{model_dir}/model.onnx"

        num_threads = int(config.get("num_threads", 8))

        tts_config = sherpa_onnx.OfflineTtsConfig(
            model=sherpa_onnx.OfflineTtsModelConfig(
                vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                    model=model_file,
                    lexicon=f"{model_dir}/lexicon.txt",
                    tokens=f"{model_dir}/tokens.txt",
                    dict_dir=f"{model_dir}/dict",
                ),
                num_threads=num_threads,
            ),
            rule_fsts=f"{model_dir}/date.fst,{model_dir}/phone.fst,{model_dir}/number.fst,{model_dir}/new_heteronym.fst",
            max_num_sentences=1,
        )
        self.tts = sherpa_onnx.OfflineTts(tts_config)
        # Native sample rate of the model output (NOT the rate of the WAV
        # produced by _generate_wav, which is TARGET_SAMPLE_RATE).
        self.sample_rate = self.tts.sample_rate
        # Serializes inference: text_to_speak runs _generate_wav on worker
        # threads, and every call shares self.tts (already multi-threaded
        # internally via num_threads).
        self._generate_lock = threading.Lock()
        logger.bind(tag=TAG).info(
            f"SherpaOnnxTTS 初始化完成: model_dir={model_dir}, sample_rate={self.sample_rate}, sid={self.sid}"
        )

    def _generate_wav(self, text):
        """Synthesize ``text`` synchronously and return WAV file bytes.

        This blocks for the whole model inference, so callers on an event
        loop must run it in a worker thread (see ``text_to_speak``).

        Returns:
            bytes: a mono, 16-bit PCM WAV at ``TARGET_SAMPLE_RATE`` Hz.
        """
        from math import gcd

        from scipy.signal import resample_poly

        with self._generate_lock:
            audio = self.tts.generate(text, sid=self.sid, speed=self.speed)
        samples = np.array(audio.samples, dtype=np.float32)

        target_sr = self.TARGET_SAMPLE_RATE
        # Resample from the model's native rate to the device rate using the
        # smallest integer up/down factors. Nothing to do for empty output.
        if samples.size and self.sample_rate != target_sr:
            g = gcd(self.sample_rate, target_sr)
            samples = resample_poly(samples, target_sr // g, self.sample_rate // g)

        # Polyphase filtering (and the model itself) can overshoot [-1, 1];
        # clip so peaks saturate instead of wrapping around in int16.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

        wav_io = io.BytesIO()
        with wave.open(wav_io, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(target_sr)
            wf.writeframes(pcm.tobytes())
        return wav_io.getvalue()

    async def text_to_speak(self, text, output_file):
        """Synthesize ``text`` without blocking the event loop.

        Args:
            text: Text to synthesize.
            output_file: If truthy, the WAV bytes are written to this path
                and ``None`` is returned; otherwise the WAV bytes are
                returned.
        """
        # Inference is CPU-bound and blocking; run it on the loop's default
        # executor so other coroutines keep running meanwhile.
        loop = asyncio.get_running_loop()
        wav_data = await loop.run_in_executor(None, self._generate_wav, text)

        if output_file:
            with open(output_file, "wb") as f:
                f.write(wav_data)
            return None
        return wav_data
|