diff --git a/app/main.py b/app/main.py index 124b947..759fd15 100644 --- a/app/main.py +++ b/app/main.py @@ -92,9 +92,9 @@ def load_model(device: str): torch.cuda.device_count = lambda: 0 mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device) - # 多 GPU 时才包 DataParallel - if device.startswith("cuda") and torch.cuda.device_count() > 1: - mdl = torch.nn.DataParallel(mdl) + # # 多 GPU 时才包 DataParallel + # if device.startswith("cuda") and torch.cuda.device_count() > 1: + # mdl = torch.nn.DataParallel(mdl) return mdl, precision