diff --git a/app/main.py b/app/main.py
index 3e473fa..124b947 100644
--- a/app/main.py
+++ b/app/main.py
@@ -86,13 +86,10 @@ def load_model(device: str):
     use_fp16 = precision == "fp16"
     logger.info("Loading BGEM3 on %s (%s)", device, precision)
 
-    # ---------- Key: on the CPU path, mask out CUDA entirely first ----------
-    if device == "cpu":
-        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # later forks can't see the GPU either
+    if device == "cpu":  # ← CPU path only
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # strip CUDA out completely
         torch.cuda.is_available = lambda: False
         torch.cuda.device_count = lambda: 0
-    # ----------------------------------------------------------
-
     mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
 
     # wrap in DataParallel only when there are multiple GPUs
@@ -173,11 +170,12 @@
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 app = FastAPI()
 logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
 
+# ② -------- FastAPI startup warm-up --------
 @app.on_event("startup")
-async def warm_up_mp_pool():
+async def warm_up():
     logger.info("Warm-up on %s", DEVICE)
     try:
-        _ = model.encode(["warmup"], return_dense=True)  # don't pass num_processes
+        _ = model.encode(["warmup"], return_dense=True)  # GPU builds the pool; CPU stays single-process
     except Exception as e:
         logger.warning("Warm-up failed: %s — will fall back on the first request instead", e)
 
@@ -195,9 +193,10 @@ def _encode(texts: List[str]):
     2. If the subprocess hits OOM / a CUDA error → fall back to CPU within the same request
        Global state is never modified, so other concurrent requests are unaffected
     """
+    # ③ -------- worker call inside _encode() --------
     def _worker(t, q):
         try:
-            out = model.encode(t, return_dense=True)  # safe on both GPU and CPU paths
+            out = model.encode(t, return_dense=True)  # safe on GPU or CPU
             q.put(("ok", out))
         except Exception as e:
             q.put(("err", str(e)))
@@ -220,31 +219,6 @@ def _encode(texts: List[str]):
     raise RuntimeError("Subprocess exited abnormally with no result")
 
 
-# fallback_done = False  # prevent endless downgrade loop
-
-# def _encode(texts: List[str]):
-#     """Encode with single downgrade to CPU on OOM / CUDA failure."""
-#     global model, DEVICE, PRECISION, fallback_done
-
-#     try:
-#         return model.encode(texts, return_dense=True)
-
-#     except RuntimeError as err:
-#         is_oom = "out of memory" in str(err).lower()
-#         is_cuda_fail = "cuda error" in str(err).lower() or "device-side assert" in str(
-#             err
-#         ).lower()
-
-#         if (is_oom or is_cuda_fail) and not fallback_done:
-#             logger.error("GPU failure (%s). Falling back to CPU…", err)
-#             fallback_done = True
-#             torch.cuda.empty_cache()
-#             DEVICE = "cpu"
-#             model, PRECISION = load_model(DEVICE)
-#             return model.encode(texts, return_dense=True)
-
-#         raise  # second failure → propagate
-
 @app.post("/v1/embeddings")
 def create_embedding(request: EmbeddingRequest):
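
A note on the CPU branch in the first hunk: `CUDA_VISIBLE_DEVICES` is only read when a process initializes CUDA, so clearing it early hides the GPU from this process and from every worker it forks afterwards, while the two lambda overrides additionally force already-imported code that probes `torch.cuda` down its CPU path. A minimal standalone check of the combined effect, using nothing beyond the same three lines from the hunk:

```python
import os

# Must run before anything initializes CUDA in this process;
# child processes forked later inherit the empty value.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import torch

# Force libraries that probe torch.cuda down their CPU code paths,
# even if they check later in the process lifetime.
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0

assert not torch.cuda.is_available()
assert torch.cuda.device_count() == 0
print("CUDA masked for this process and its forks")
```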
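The third and fourth hunks show `_worker` and the terminal `raise`, but the lines that actually spawn the subprocess sit in the elided middle of `_encode`. Below is a minimal sketch of that isolation pattern, assuming the module-level `model` from this file; the `GPU_TIMEOUT_S` constant and the fork start method are illustrative guesses, not taken from the diff, and the per-request CPU fallback the docstring describes (which would catch the final `RuntimeError`) is likewise elided:

```python
import multiprocessing as mp
import queue
from typing import List

GPU_TIMEOUT_S = 120  # assumed; the real timeout is in the elided lines

def _encode(texts: List[str]):
    def _worker(t, q):
        try:
            out = model.encode(t, return_dense=True)  # module-level model from app/main.py
            q.put(("ok", out))
        except Exception as e:
            q.put(("err", str(e)))

    # fork start method: the _worker closure is inherited by the child,
    # so it does not need to be picklable (Unix only).
    ctx = mp.get_context("fork")
    q = ctx.Queue()
    proc = ctx.Process(target=_worker, args=(texts, q))
    proc.start()

    try:
        status, payload = q.get(timeout=GPU_TIMEOUT_S)
    except queue.Empty:
        proc.terminate()  # hung or crashed child: reap it and give up
        proc.join()
        raise RuntimeError("Subprocess exited abnormally with no result")

    proc.join()
    if status == "ok":
        return payload
    raise RuntimeError(payload)  # e.g. CUDA OOM; the caller may retry on CPU
```

The point of the pattern is that a CUDA OOM or device-side assert kills only the short-lived child; the server process and its global state stay intact, which is why other concurrent requests are unaffected.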
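For completeness, a request against the endpoint the last hunk touches. The `EmbeddingRequest` model is not shown in this diff, so the payload below is a guess that follows the OpenAI-style embeddings schema the `/v1/embeddings` path suggests:

```python
import requests

# Hypothetical payload: EmbeddingRequest's real fields are not in the diff.
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"input": ["hello world"], "model": "bge-m3"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```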