diff --git a/app/main.py b/app/main.py
index 41bc936..2f48691 100644
--- a/app/main.py
+++ b/app/main.py
@@ -217,29 +217,18 @@ def _worker(t, q):
 def _encode(texts: List[str]):
-    """
-    Per-request flow:
-    1. Run GPU inference in a subprocess; on success → return the result
-    2. If the subprocess hits OOM / a CUDA error → fall back to CPU within the same request
-    Never mutates global state, so other in-flight requests are unaffected
-    """
-    q = mp.Queue()
-    p = mp.Process(target=_worker, args=(texts, q))
-    p.start()
-    p.join(timeout=60)
-
-    if not q.empty():
-        status, payload = q.get()
-        if status == "ok":
-            return payload
-        if "out of memory" in payload.lower() or "cuda error" in payload.lower():
-            logger.warning("GPU OOM → falling back to CPU for this request: %s", payload)
+    try:
+        return model.encode(texts, return_dense=True)
+    except RuntimeError as e:
+        if "out of memory" in str(e).lower() or "cuda error" in str(e).lower():
+            logger.warning("GPU OOM → fallback to CPU: %s", str(e))
             torch.cuda.empty_cache()
-            cpu_model, _ = load_model("cpu")
-            return cpu_model.encode(texts, return_dense=True)
-        raise RuntimeError(payload)
+            global CPU_MODEL_CACHE
+            if CPU_MODEL_CACHE is None:
+                CPU_MODEL_CACHE, _ = load_model("cpu")
+            return CPU_MODEL_CACHE.encode(texts, return_dense=True)
+        raise
-    raise RuntimeError("worker subprocess exited abnormally with no result")
 
 @app.post("/v1/embeddings")
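
Note: the new branch reads a module-level CPU_MODEL_CACHE that must be declared before _encode runs, and the lazy `is None` check can race when several concurrent requests hit the OOM path at once. Below is a minimal sketch of the assumed supporting declarations, with a hypothetical _CPU_MODEL_LOCK added as a double-checked guard; only CPU_MODEL_CACHE and load_model come from this diff, everything else is an assumption:

    import threading

    # Module-level cache: stays None until the first GPU failure forces a CPU load.
    CPU_MODEL_CACHE = None
    # Hypothetical lock so concurrent fallbacks don't call load_model("cpu") twice.
    _CPU_MODEL_LOCK = threading.Lock()

    def _get_cpu_model():
        """Lazily load and cache the CPU model, safe under concurrent requests."""
        global CPU_MODEL_CACHE
        if CPU_MODEL_CACHE is None:
            with _CPU_MODEL_LOCK:
                if CPU_MODEL_CACHE is None:  # re-check after acquiring the lock
                    CPU_MODEL_CACHE, _ = load_model("cpu")
        return CPU_MODEL_CACHE

With a helper like this, the except branch in _encode could call _get_cpu_model().encode(texts, return_dense=True) instead of touching the global directly; this is a sketch of one way to close the race, not part of the change itself.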