101 lines
3.2 KiB
Python
101 lines
3.2 KiB
Python
"""
|
|
Qwen3-TTS CustomVoice Provider for xiaozhi-server
|
|
Based on sherpa_tts.py structure.
|
|
GPU inference using qwen-tts package.
|
|
"""
|
|
|
|
import io
|
|
import os
|
|
import time
|
|
import wave
|
|
import asyncio
|
|
import numpy as np
|
|
|
|
from config.logger import setup_logging
|
|
from core.providers.tts.base import TTSProviderBase
|
|
|
|
# Module name, bound onto log records below via logger.bind(tag=TAG).
TAG = __name__
|
|
# Module-level logger returned by the project's setup_logging() helper.
logger = setup_logging()
|
|
|
|
|
|
class TTSProvider(TTSProviderBase):
    """Qwen3-TTS CustomVoice provider that synthesizes speech with ``qwen_tts``.

    Config keys read by ``__init__`` (all optional):
        model_path: model id or path passed to ``Qwen3TTSModel.from_pretrained``.
        tokenizer_path: id or path passed to ``Qwen3TTSTokenizer.from_pretrained``.
        device: a single device string (e.g. ``"cuda:0"``) or a comma-separated
            list (e.g. ``"cuda:2,cuda:3"``) to spread the model over several GPUs.
        dtype: name of a ``torch`` dtype attribute (default ``"bfloat16"``).
        speaker: speaker name forwarded to ``generate_custom_voice``.
        language: language name forwarded to ``generate_custom_voice``.
        max_memory_per_gpu: per-GPU memory cap used only in multi-GPU mode
            (default ``"22GiB"``).
        attn_implementation: attention backend forwarded to ``from_pretrained``
            (default ``"flash_attention_2"``).
    """

    def __init__(self, config, delete_audio_file):
        """Load the model and tokenizer described by ``config``.

        Args:
            config: mapping with the keys listed in the class docstring.
            delete_audio_file: forwarded unchanged to ``TTSProviderBase``.
        """
        super().__init__(config, delete_audio_file)

        # Heavy dependencies are imported here rather than at module import time.
        import torch
        from qwen_tts import Qwen3TTSModel, Qwen3TTSTokenizer

        model_path = config.get("model_path", "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice")
        tokenizer_path = config.get("tokenizer_path", "Qwen/Qwen3-TTS-Tokenizer-12Hz")
        device = config.get("device", "cuda:0")
        dtype_str = config.get("dtype", "bfloat16")
        per_gpu_memory = config.get("max_memory_per_gpu", "22GiB")
        attn_implementation = config.get("attn_implementation", "flash_attention_2")

        # getattr alone would happily return any torch attribute (e.g. "nn"),
        # so make sure the configured name really resolves to a dtype.
        dtype = getattr(torch, str(dtype_str), None)
        if not isinstance(dtype, torch.dtype):
            logger.bind(tag=TAG).warning(
                "Qwen3TTS: unknown dtype %r, falling back to bfloat16" % (dtype_str,)
            )
            dtype = torch.bfloat16

        self.speaker = config.get("speaker", "Chelsie")
        self.language = config.get("language", "Chinese")

        logger.bind(tag=TAG).info(
            "Qwen3TTS loading: model=%s device=%s speaker=%s" % (model_path, device, self.speaker)
        )
        started = time.perf_counter()

        load_kwargs = {
            "dtype": dtype,
            "attn_implementation": attn_implementation,
        }
        if "," in device:
            # Multiple GPUs (e.g. "cuda:2,cuda:3"): let the loader place layers
            # automatically, capped per listed GPU. Empty entries produced by a
            # stray comma are skipped instead of crashing on int("").
            gpu_ids = [
                int(part.strip().replace("cuda:", ""))
                for part in device.split(",")
                if part.strip()
            ]
            load_kwargs["device_map"] = "auto"
            load_kwargs["max_memory"] = {gpu_id: per_gpu_memory for gpu_id in gpu_ids}
        else:
            load_kwargs["device_map"] = device

        self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
        self.tokenizer = Qwen3TTSTokenizer.from_pretrained(tokenizer_path)

        # Log the speakers the loaded model reports as supported.
        speakers = self.model.get_supported_speakers()
        logger.bind(tag=TAG).info(
            "Qwen3TTS loaded in %.1fs, speakers=%s" % (time.perf_counter() - started, speakers)
        )

    def _qwen_generate(self, text):
        """Run blocking synthesis for ``text`` with the configured speaker/language."""
        return self.model.generate_custom_voice(
            text=text,
            speaker=self.speaker,
            language=self.language,
        )

    @staticmethod
    def _pcm16_wav_bytes(audio, sr):
        """Encode a float waveform as mono 16-bit PCM WAV bytes.

        Samples are clipped to [-1, 1] before scaling: without the clip, any
        overshoot wraps around when cast to int16 and is heard as a loud click.
        """
        pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        wav_io = io.BytesIO()
        with wave.open(wav_io, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sr)
            wf.writeframes(pcm.tobytes())
        return wav_io.getvalue()

    async def text_to_speak(self, text, output_file):
        """Synthesize ``text`` and deliver it as WAV audio.

        Args:
            text: text to synthesize.
            output_file: destination path; when falsy the audio is returned
                instead of being written to disk.

        Returns:
            ``None`` when ``output_file`` is given (the WAV is written there),
            otherwise the WAV file content as ``bytes``.
        """
        started = time.perf_counter()

        # Generation is blocking, so run it in the default executor to keep
        # the event loop responsive.
        loop = asyncio.get_running_loop()
        wavs, sr = await loop.run_in_executor(None, self._qwen_generate, text)

        # Only the first returned waveform is used.
        # NOTE(review): assumes it is a 1-D float array-like - confirm against qwen_tts.
        audio = np.asarray(wavs[0])
        duration = len(audio) / sr
        logger.bind(tag=TAG).info(
            "TTS: %.2fs合成 %.1fs音频 sr=%d [%s]" % (time.perf_counter() - started, duration, sr, text[:30])
        )

        wav_data = self._pcm16_wav_bytes(audio, sr)

        if output_file:
            with open(output_file, "wb") as f:
                f.write(wav_data)
            return None
        return wav_data
|