From 04146a7f2816260ecc6f8a6bfa41c72131379f5b Mon Sep 17 00:00:00 2001
From: hailin
Date: Tue, 5 Aug 2025 16:17:57 +0800
Subject: [PATCH] Require MODEL_VRAM_MB free VRAM and post-load headroom when
 selecting a GPU

---
 app/main.py | 48 +++++++++++++++++++++++++++++++-----------------
 1 file changed, 31 insertions(+), 17 deletions(-)

diff --git a/app/main.py b/app/main.py
index f3c3518..18cd4c0 100644
--- a/app/main.py
+++ b/app/main.py
@@ -31,6 +31,9 @@
 SAFE_MIN_FREE_MB = int(
     os.getenv("SAFE_MIN_FREE_MB", "16384")
 )  # minimum free VRAM required at startup, in MB
+MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))  # bge-m3-large fp16 ≈ 16 GiB
+POST_LOAD_GAP_MB = 512
+
 # -----------------------------------------------------------------------------#
 # Logging
 # -----------------------------------------------------------------------------#
@@ -109,42 +112,53 @@ def load_model(device: str):
 # -----------------------------------------------------------------------------#
 def auto_select_and_load(min_free_mb: int = 4096):
     """
-    1. Gather GPUs with free_mem ≥ min_free_mb
-    2. Sort by free_mem desc, attempt to load
-    3. All fail → CPU
-    Return (model, device, precision)
+    1. Drop GPUs whose free VRAM < MODEL_VRAM_MB
+    2. Try the remaining GPUs in descending order of free VRAM
+    3. Re-check after loading: free VRAM left < POST_LOAD_GAP_MB → treat as failure
+    4. If no GPU qualifies → CPU
     """
     if not torch.cuda.is_available():
         logger.info("No GPU detected → CPU")
         return (*load_model("cpu"), "cpu")
 
-    # Build candidate list
-    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
+    # Collect each GPU's free VRAM
+    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
     for idx in range(torch.cuda.device_count()):
         free, _ = _gpu_mem_info(idx)
         free_mb = free // 2**20
         candidates.append((free_mb, idx))
 
-    candidates = [c for c in candidates if c[0] >= min_free_mb]
+    # Coarse filter: keep only GPUs with at least MODEL_VRAM_MB free (16384 MB by default)
+    candidates = [c for c in candidates if c[0] >= MODEL_VRAM_MB]
     if not candidates:
-        logger.warning("All GPUs free_mem < %d MB → CPU", min_free_mb)
+        logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
         return (*load_model("cpu"), "cpu")
 
-    candidates.sort(reverse=True)  # high free_mem first
+    # Try GPUs from highest to lowest free VRAM
+    candidates.sort(reverse=True)
     for free_mb, idx in candidates:
         dev = f"cuda:{idx}"
         try:
             logger.info("Trying %s (free=%d MB)", dev, free_mb)
             mdl, prec = load_model(dev)
-            return mdl, prec, dev
-        except RuntimeError as e:
-            if "out of memory" in str(e).lower():
-                logger.warning("%s OOM → next GPU", dev)
-                torch.cuda.empty_cache()
-                continue
-            raise  # non-OOM error
-    logger.warning("All GPUs failed → CPU fallback")
+
+            # Measure again after loading; require at least POST_LOAD_GAP_MB of headroom
+            new_free, _ = _gpu_mem_info(idx)
+            new_free_mb = new_free // 2**20
+            if new_free_mb < POST_LOAD_GAP_MB:
+                raise RuntimeError(
+                    f"post-load free {new_free_mb} MB < {POST_LOAD_GAP_MB} MB"
+                )
+
+            return mdl, prec, dev  # GPU acquired successfully
+
+        except RuntimeError as e:
+            logger.warning("%s unusable (%s) → next", dev, e)
+            torch.cuda.empty_cache()
+            continue
+
+    # No GPU met the requirements
+    logger.warning("No suitable GPU left → CPU fallback")
     return (*load_model("cpu"), "cpu")
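
Note (not part of the diff): `_gpu_mem_info` is referenced by both the old and
the new code but is not shown in this hunk. It is assumed to return per-device
(free, total) memory in bytes, roughly a thin wrapper around
`torch.cuda.mem_get_info` as in this sketch:

    import torch

    def _gpu_mem_info(idx: int) -> tuple[int, int]:
        # torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the given device
        return torch.cuda.mem_get_info(idx)

The patch then converts the free bytes to MB with `free // 2**20` before
comparing against MODEL_VRAM_MB and POST_LOAD_GAP_MB.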
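
Note (not part of the diff): the return shape is unchanged by this patch; both
the GPU path (`return mdl, prec, dev`) and the CPU fallback
(`return (*load_model("cpu"), "cpu")`) yield (model, precision, device). A
call-site sketch, with illustrative variable names not taken from app/main.py:

    # device is "cuda:N" when a GPU passes both the pre-load and post-load checks,
    # otherwise "cpu".
    model, precision, device = auto_select_and_load()
    logger.info("Embedding model ready on %s (%s)", device, precision)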