diff --git a/app/main.py b/app/main.py
index 868a642..3e473fa 100644
--- a/app/main.py
+++ b/app/main.py
@@ -82,26 +82,24 @@ def _choose_precision(device: str) -> str:
 def load_model(device: str):
-    """Instantiate model on `device`; return (model, precision)."""
     precision = _choose_precision(device)
     use_fp16 = precision == "fp16"
     logger.info("Loading BGEM3 on %s (%s)", device, precision)
 
+    # ---------- Key: if running on CPU, mask out CUDA entirely first ----------
+    if device == "cpu":
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # all subsequent forks see no GPU either
+        torch.cuda.is_available = lambda: False
+        torch.cuda.device_count = lambda: 0
+    # ----------------------------------------------------------
+
     mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
 
-    if device == "cpu":
-        # Hide the GPU so that torch / BGEM3 no longer recognize CUDA
-        os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
-    # Simple DataParallel for multi-GPU inference
+    # Only wrap in DataParallel when there are multiple GPUs
     if device.startswith("cuda") and torch.cuda.device_count() > 1:
-        logger.info(
-            "Wrapping model with torch.nn.DataParallel (%d GPUs)",
-            torch.cuda.device_count(),
-        )
         mdl = torch.nn.DataParallel(mdl)
 
-    return mdl, precision
+    return mdl, precision
 
 # -----------------------------------------------------------------------------#
 # Auto-select device (startup)
@@ -177,20 +175,11 @@ logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
 @app.on_event("startup")
 async def warm_up_mp_pool():
+    logger.info("Warm-up on %s", DEVICE)
     try:
-        if DEVICE.startswith("cuda"):
-            logger.info("Warm-up (GPU) → build the multi-process pool")
-            _ = model.encode(["warmup"], return_dense=True)
-        else:
-            logger.info("Warm-up (CPU) → single-process init")
-            # ── Temporarily make the library "think" there is no GPU ──
-            orig_cnt = torch.cuda.device_count
-            torch.cuda.device_count = lambda: 0
-            _ = model.encode(["warmup"], return_dense=True)  # do not pass num_processes
-            torch.cuda.device_count = orig_cnt
-            # ──────────────────────────────────────────────────────
+        _ = model.encode(["warmup"], return_dense=True)  # do not pass num_processes
     except Exception as e:
-        logger.warning("Warm-up failed: %s —— fall back at the first request", e)
+        logger.warning("Warm-up failed: %s — fall back at the first request", e)
 
@@ -208,14 +197,7 @@ def _encode(texts: List[str]):
     """
     def _worker(t, q):
         try:
-            if DEVICE.startswith("cuda"):
-                out = model.encode(t, return_dense=True)  # normal GPU path
-            else:
-                # Temporarily hide the GPU; single-process CPU inference
-                orig_cnt = torch.cuda.device_count
-                torch.cuda.device_count = lambda: 0
-                out = model.encode(t, return_dense=True)  # do not pass num_processes
-                torch.cuda.device_count = orig_cnt
+            out = model.encode(t, return_dense=True)  # safe on both the GPU and CPU paths
             q.put(("ok", out))
         except Exception as e:
             q.put(("err", str(e)))
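
Note (illustrative, not part of the patch): the CPU-masking block in load_model must run before BGEM3FlagModel is instantiated, because torch reads CUDA_VISIBLE_DEVICES lazily at first CUDA use, and child processes inherit the environment, so every later fork also sees zero GPUs. A minimal standalone sketch of the same trick, assuming nothing in the process has touched CUDA yet:

    import os

    # Must happen before any CUDA initialization; forked children inherit it.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    import torch

    # Cover code paths that probe torch.cuda directly instead of the env var.
    torch.cuda.is_available = lambda: False
    torch.cuda.device_count = lambda: 0

    assert not torch.cuda.is_available()
    assert torch.cuda.device_count() == 0

Overriding the two probes is belt-and-braces: the env var alone suffices for anything started afterwards, but a library that already initialized CUDA in the same process would otherwise still see the devices.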
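
The trimmed _worker keeps the queue protocol unchanged: it puts ("ok", out) on success and ("err", str(e)) on failure. A hedged sketch of one way to drive such a worker with a timeout; the thread-based runner, the queue type, and the 120 s default are assumptions, not shown in the patch:

    import queue
    import threading

    def run_with_timeout(worker, texts, timeout_s=120.0):
        # Run worker(texts, q) in a daemon thread so a hung encode call
        # cannot block the caller (or interpreter shutdown) forever.
        q = queue.Queue()
        t = threading.Thread(target=worker, args=(texts, q), daemon=True)
        t.start()
        try:
            status, payload = q.get(timeout=timeout_s)
        except queue.Empty:
            raise TimeoutError("encode worker produced no result in time")
        if status == "err":
            raise RuntimeError(payload)
        return payload

Inside _encode this would be invoked as run_with_timeout(_worker, texts); passing the worker in as an argument keeps the sketch self-contained.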