From 17923f3bdc2baa8efd5b1ed69f6bed4d5aa3e248 Mon Sep 17 00:00:00 2001 From: hailin Date: Tue, 7 Apr 2026 03:28:16 -0700 Subject: [PATCH] feat: TTS on 2 GPUs (cuda:2,cuda:3) for faster inference --- config/.config.yaml | 2 +- modules/tts/qwen3_tts.py | 24 ++++++++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/config/.config.yaml b/config/.config.yaml index 66758c8..6e67c9b 100644 --- a/config/.config.yaml +++ b/config/.config.yaml @@ -42,7 +42,7 @@ TTS: type: qwen3_tts model_path: /home/ZeroStack/xiaozhi/Qwen3-TTS-12Hz-1.7B-CustomVoice tokenizer_path: /home/ZeroStack/xiaozhi/Qwen3-TTS-Tokenizer-12Hz - device: cuda:2 + device: cuda:2,cuda:3 dtype: bfloat16 speaker: uncle_fu language: Chinese diff --git a/modules/tts/qwen3_tts.py b/modules/tts/qwen3_tts.py index 8c6ff59..0dd21bc 100644 --- a/modules/tts/qwen3_tts.py +++ b/modules/tts/qwen3_tts.py @@ -37,12 +37,24 @@ class TTSProvider(TTSProviderBase): ) t0 = time.time() - self.model = Qwen3TTSModel.from_pretrained( - model_path, - device_map=device, - dtype=dtype, - attn_implementation="flash_attention_2", - ) + # Use multiple GPUs if specified (e.g. "cuda:2,cuda:3") + if "," in device: + gpu_ids = [d.strip().replace("cuda:", "") for d in device.split(",")] + max_memory = {int(g): "22GiB" for g in gpu_ids} + self.model = Qwen3TTSModel.from_pretrained( + model_path, + device_map="auto", + max_memory=max_memory, + dtype=dtype, + attn_implementation="flash_attention_2", + ) + else: + self.model = Qwen3TTSModel.from_pretrained( + model_path, + device_map=device, + dtype=dtype, + attn_implementation="flash_attention_2", + ) self.tokenizer = Qwen3TTSTokenizer.from_pretrained(tokenizer_path) # Get supported speakers