taixf/modules/tts/qwen3_tts.py

88 lines
2.6 KiB
Python

"""
Qwen3-TTS CustomVoice Provider for xiaozhi-server
Based on sherpa_tts.py structure.
GPU inference using qwen-tts package.
"""
import io
import os
import time
import wave
import asyncio
import numpy as np
from config.logger import setup_logging
from core.providers.tts.base import TTSProviderBase
# Module name; passed as ``tag`` when binding the logger in this file.
TAG = __name__
# Logger instance returned by the project's ``setup_logging`` helper.
logger = setup_logging()
class TTSProvider(TTSProviderBase):
    """Text-to-speech provider that runs a Qwen3-TTS CustomVoice model locally.

    Config keys read by ``__init__`` (default in parentheses):
        model_path ("Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"): passed to
            ``Qwen3TTSModel.from_pretrained``.
        tokenizer_path ("Qwen/Qwen3-TTS-Tokenizer-12Hz"): passed to
            ``Qwen3TTSTokenizer.from_pretrained``.
        device ("cuda:0"): passed as ``device_map`` when loading the model.
        dtype ("bfloat16"): name of a ``torch`` dtype attribute.
        speaker ("Chelsie"): speaker name passed to every synthesis call.
        language ("Chinese"): language passed to every synthesis call.
    """

    def __init__(self, config, delete_audio_file):
        """Load the model and tokenizer described by ``config``.

        Args:
            config: Mapping of provider settings (see the class docstring).
            delete_audio_file: Forwarded unchanged to ``TTSProviderBase``.
        """
        super().__init__(config, delete_audio_file)

        # Imported here rather than at module level, so these packages are
        # only required when the provider is actually instantiated.
        import torch
        from qwen_tts import Qwen3TTSModel, Qwen3TTSTokenizer

        model_path = config.get("model_path", "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice")
        tokenizer_path = config.get("tokenizer_path", "Qwen/Qwen3-TTS-Tokenizer-12Hz")
        device = config.get("device", "cuda:0")
        dtype_str = config.get("dtype", "bfloat16")

        # Only accept real torch dtypes: a bare getattr() would also return
        # unrelated attributes such as ``torch.nn`` for a mistyped setting.
        dtype = getattr(torch, str(dtype_str), None)
        if not isinstance(dtype, torch.dtype):
            logger.bind(tag=TAG).warning(
                "Qwen3TTS: unknown dtype %r, falling back to bfloat16" % (dtype_str,)
            )
            dtype = torch.bfloat16

        self.speaker = config.get("speaker", "Chelsie")
        self.language = config.get("language", "Chinese")

        logger.bind(tag=TAG).info(
            "Qwen3TTS loading: model=%s device=%s speaker=%s" % (model_path, device, self.speaker)
        )
        t0 = time.time()
        self.model = Qwen3TTSModel.from_pretrained(
            model_path,
            device_map=device,
            dtype=dtype,
        )
        self.tokenizer = Qwen3TTSTokenizer.from_pretrained(tokenizer_path)

        # Report the speakers the loaded checkpoint supports.
        speakers = self.model.get_supported_speakers()
        logger.bind(tag=TAG).info(
            "Qwen3TTS loaded in %.1fs, speakers=%s" % (time.time() - t0, speakers)
        )

        # Surface a misconfigured speaker at startup instead of at the first
        # synthesis request. This only warns; it never blocks startup.
        # NOTE(review): compared case-insensitively on the assumption that
        # qwen_tts treats speaker names that way - confirm against the package.
        if speakers:
            known = {str(name).lower() for name in speakers}
            if str(self.speaker).lower() not in known:
                logger.bind(tag=TAG).warning(
                    "Qwen3TTS: speaker %r is not in the supported list %s"
                    % (self.speaker, speakers)
                )

    @staticmethod
    def _encode_wav(audio, sample_rate):
        """Encode float samples as mono 16-bit PCM WAV bytes.

        Args:
            audio: Array-like of float samples, nominally within [-1, 1].
            sample_rate: Sample rate in Hz written to the WAV header.

        Returns:
            The complete WAV file contents as ``bytes``.
        """
        samples = np.asarray(audio, dtype=np.float32)
        # Clip before the int16 cast: anything beyond full scale would
        # otherwise wrap around and be heard as a loud click.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

        wav_io = io.BytesIO()
        with wave.open(wav_io, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(int(sample_rate))
            wf.writeframes(pcm.tobytes())
        return wav_io.getvalue()

    async def text_to_speak(self, text, output_file):
        """Synthesize ``text`` with the configured speaker and language.

        Args:
            text: Text to synthesize.
            output_file: Path to write the WAV file to; when falsy the WAV
                bytes are returned instead.

        Returns:
            ``None`` when ``output_file`` is given, otherwise the WAV bytes.
        """
        t0 = time.time()

        def _synthesize():
            return self.model.generate_custom_voice(
                text=text,
                speaker=self.speaker,
                language=self.language,
            )

        # Run the blocking model call in the default executor so the event
        # loop is not stalled while audio is generated.
        loop = asyncio.get_running_loop()
        wavs, sr = await loop.run_in_executor(None, _synthesize)

        audio = np.asarray(wavs[0])
        duration = len(audio) / sr
        logger.bind(tag=TAG).info(
            "TTS: %.2fs合成 %.1fs音频 sr=%d [%s]" % (time.time() - t0, duration, sr, text[:30])
        )

        wav_data = self._encode_wav(audio, sr)
        if output_file:
            with open(output_file, "wb") as f:
                f.write(wav_data)
            return None
        return wav_data