This commit is contained in:
hailin 2025-08-05 15:41:26 +08:00
parent 5f7eb2b7ba
commit 30b865b0ff
1 changed files with 3 additions and 3 deletions

View File

@ -92,9 +92,9 @@ def load_model(device: str):
torch.cuda.device_count = lambda: 0
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
# 多 GPU 时才包 DataParallel
if device.startswith("cuda") and torch.cuda.device_count() > 1:
mdl = torch.nn.DataParallel(mdl)
# # 多 GPU 时才包 DataParallel
# if device.startswith("cuda") and torch.cuda.device_count() > 1:
# mdl = torch.nn.DataParallel(mdl)
return mdl, precision