This commit is contained in:
parent 5f7eb2b7ba
commit 30b865b0ff
@@ -92,9 +92,9 @@ def load_model(device: str):
     torch.cuda.device_count = lambda: 0
     mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
 
-    # Only wrap in DataParallel when multiple GPUs are available
-    if device.startswith("cuda") and torch.cuda.device_count() > 1:
-        mdl = torch.nn.DataParallel(mdl)
+    # # Only wrap in DataParallel when multiple GPUs are available
+    # if device.startswith("cuda") and torch.cuda.device_count() > 1:
+    #     mdl = torch.nn.DataParallel(mdl)
 
     return mdl, precision
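For context, a minimal sketch of how the full load_model helper could read after this commit. MODEL_PATH, the use_fp16/precision derivation, and the import location are assumptions; only the lines shown in the hunk above come from the repository.

# Sketch only: MODEL_PATH and the use_fp16/precision logic are assumed,
# since they sit outside the visible hunk.
import torch
from FlagEmbedding import BGEM3FlagModel

MODEL_PATH = "BAAI/bge-m3"  # hypothetical value; the real path is defined elsewhere in the repo


def load_model(device: str):
    # Assumed convention: fp16 only when running on a CUDA device.
    use_fp16 = device.startswith("cuda")
    precision = "fp16" if use_fp16 else "fp32"

    # Report zero CUDA devices so nothing downstream attempts multi-GPU splitting.
    torch.cuda.device_count = lambda: 0

    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)

    # Multi-GPU DataParallel wrapping stays disabled, matching the commit above:
    # if device.startswith("cuda") and torch.cuda.device_count() > 1:
    #     mdl = torch.nn.DataParallel(mdl)

    return mdl, precision

Patching device_count to return 0 and commenting out the DataParallel branch both push the loader down the single-device path, so the FlagEmbedding model is returned without any DataParallel wrapper.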