This commit is contained in:
parent
5f7eb2b7ba
commit
30b865b0ff
|
|
@ -92,9 +92,9 @@ def load_model(device: str):
|
||||||
torch.cuda.device_count = lambda: 0
|
torch.cuda.device_count = lambda: 0
|
||||||
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
|
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
|
||||||
|
|
||||||
# 多 GPU 时才包 DataParallel
|
# # 多 GPU 时才包 DataParallel
|
||||||
if device.startswith("cuda") and torch.cuda.device_count() > 1:
|
# if device.startswith("cuda") and torch.cuda.device_count() > 1:
|
||||||
mdl = torch.nn.DataParallel(mdl)
|
# mdl = torch.nn.DataParallel(mdl)
|
||||||
|
|
||||||
return mdl, precision
|
return mdl, precision
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue