feat: TTS on 2 GPUs (cuda:2,cuda:3) for faster inference
This commit is contained in:
parent
b75e813c03
commit
17923f3bdc
|
|
@ -42,7 +42,7 @@ TTS:
|
|||
type: qwen3_tts
|
||||
model_path: /home/ZeroStack/xiaozhi/Qwen3-TTS-12Hz-1.7B-CustomVoice
|
||||
tokenizer_path: /home/ZeroStack/xiaozhi/Qwen3-TTS-Tokenizer-12Hz
|
||||
device: cuda:2
|
||||
device: cuda:2,cuda:3
|
||||
dtype: bfloat16
|
||||
speaker: uncle_fu
|
||||
language: Chinese
|
||||
|
|
|
|||
|
|
@ -37,6 +37,18 @@ class TTSProvider(TTSProviderBase):
|
|||
)
|
||||
t0 = time.time()
|
||||
|
||||
# Use multiple GPUs if specified (e.g. "cuda:2,cuda:3")
|
||||
if "," in device:
|
||||
gpu_ids = [d.strip().replace("cuda:", "") for d in device.split(",")]
|
||||
max_memory = {int(g): "22GiB" for g in gpu_ids}
|
||||
self.model = Qwen3TTSModel.from_pretrained(
|
||||
model_path,
|
||||
device_map="auto",
|
||||
max_memory=max_memory,
|
||||
dtype=dtype,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
else:
|
||||
self.model = Qwen3TTSModel.from_pretrained(
|
||||
model_path,
|
||||
device_map=device,
|
||||
|
|
|
|||
Loading…
Reference in New Issue