def load_model(device: str):
    """Load the BGEM3 embedding model onto *device*.

    Parameters
    ----------
    device : str
        Either ``"cpu"`` or a CUDA device string such as ``"cuda:3"``.
        A bare ``"cuda"`` is treated as ``"cuda:0"``.

    Returns
    -------
    tuple
        ``(model, precision)`` where ``precision`` is the string chosen by
        :func:`_choose_precision` (``"fp16"`` enables half precision).
    """
    precision = _choose_precision(device)
    use_fp16 = precision == "fp16"
    logger.info("Loading BGEM3 on %s (%s)", device, precision)

    if device == "cpu":
        # CPU path: hide CUDA entirely so the model never touches a GPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        torch.cuda.is_available = lambda: False
        torch.cuda.device_count = lambda: 0
    else:
        # GPU path: expose only the single card this worker was assigned,
        # e.g. device == "cuda:3" -> this process sees only physical GPU 3.
        # NOTE(review): CUDA_VISIBLE_DEVICES is read once at CUDA init time;
        # this only works if CUDA has not yet been initialized in this
        # process — confirm no earlier torch.cuda call happens.
        _, _, idx = device.partition(":")
        # Bare "cuda" (no index) previously raised IndexError via
        # split(":")[1]; default it to card 0.
        os.environ["CUDA_VISIBLE_DEVICES"] = idx or "0"
        # After masking, the chosen card is re-enumerated as device 0, so it
        # must be addressed as "cuda:0" from here on.  (Passing the physical
        # index, e.g. "cuda:3", is out of range once only one card is
        # visible.)
        device = "cuda:0"
        # device_count now reports 1, so BGEM3 spawns a single worker
        # bound to this one card.
        torch.cuda.device_count = lambda: 1

    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
    # No DataParallel wrapping: each worker runs on exactly one card.
    return mdl, precision