feat: TTS on 2 GPUs (cuda:2,cuda:3) for faster inference
This commit is contained in:
parent
b75e813c03
commit
17923f3bdc
|
|
@ -42,7 +42,7 @@ TTS:
|
||||||
type: qwen3_tts
|
type: qwen3_tts
|
||||||
model_path: /home/ZeroStack/xiaozhi/Qwen3-TTS-12Hz-1.7B-CustomVoice
|
model_path: /home/ZeroStack/xiaozhi/Qwen3-TTS-12Hz-1.7B-CustomVoice
|
||||||
tokenizer_path: /home/ZeroStack/xiaozhi/Qwen3-TTS-Tokenizer-12Hz
|
tokenizer_path: /home/ZeroStack/xiaozhi/Qwen3-TTS-Tokenizer-12Hz
|
||||||
device: cuda:2
|
device: cuda:2,cuda:3
|
||||||
dtype: bfloat16
|
dtype: bfloat16
|
||||||
speaker: uncle_fu
|
speaker: uncle_fu
|
||||||
language: Chinese
|
language: Chinese
|
||||||
|
|
|
||||||
|
|
@ -37,12 +37,24 @@ class TTSProvider(TTSProviderBase):
|
||||||
)
|
)
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
self.model = Qwen3TTSModel.from_pretrained(
|
# Use multiple GPUs if specified (e.g. "cuda:2,cuda:3")
|
||||||
model_path,
|
if "," in device:
|
||||||
device_map=device,
|
gpu_ids = [d.strip().replace("cuda:", "") for d in device.split(",")]
|
||||||
dtype=dtype,
|
max_memory = {int(g): "22GiB" for g in gpu_ids}
|
||||||
attn_implementation="flash_attention_2",
|
self.model = Qwen3TTSModel.from_pretrained(
|
||||||
)
|
model_path,
|
||||||
|
device_map="auto",
|
||||||
|
max_memory=max_memory,
|
||||||
|
dtype=dtype,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = Qwen3TTSModel.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
device_map=device,
|
||||||
|
dtype=dtype,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
)
|
||||||
self.tokenizer = Qwen3TTSTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = Qwen3TTSTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
|
||||||
# Get supported speakers
|
# Get supported speakers
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue