From 53285eecadc5588cbfadacfc18faa27296bccc19 Mon Sep 17 00:00:00 2001
From: hailin
Date: Tue, 5 Aug 2025 16:35:42 +0800
Subject: [PATCH] Derive SAFE_MIN_FREE_MB and simplify GPU auto-selection

---
 app/main.py | 46 +++++++++++++++++-----------------------------
 1 file changed, 17 insertions(+), 29 deletions(-)

diff --git a/app/main.py b/app/main.py
index 18cd4c0..f86c632 100644
--- a/app/main.py
+++ b/app/main.py
@@ -27,12 +27,9 @@ from FlagEmbedding import BGEM3FlagModel
 # Config
 # -----------------------------------------------------------------------------#
 MODEL_PATH = "model/bge-m3"  # change to your own weights path if needed
-SAFE_MIN_FREE_MB = int(
-    os.getenv("SAFE_MIN_FREE_MB", "16384")
-)  # minimum free VRAM required at startup, MB
-
 MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))  # bge-m3-large fp16 ≈ 16 GiB
 POST_LOAD_GAP_MB = 512
+SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # == 16896 MB by default
 
 # -----------------------------------------------------------------------------#
 # Logging
@@ -110,7 +107,7 @@ def load_model(device: str):
 # -----------------------------------------------------------------------------#
 # Auto-select device (startup)
 # -----------------------------------------------------------------------------#
-def auto_select_and_load(min_free_mb: int = 4096):
+def auto_select_and_load() -> tuple:
     """
     1. Filter out GPUs whose free VRAM is < MODEL_VRAM_MB
     2. Try loading on the remaining GPUs in descending order of free VRAM
@@ -121,47 +118,37 @@ def auto_select_and_load(min_free_mb: int = 4096):
         logger.info("No GPU detected → CPU")
         return (*load_model("cpu"), "cpu")
 
-    # collect free VRAM of every GPU
-    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
+    # collect candidate GPUs as (free_MB, idx)
+    candidates = []
     for idx in range(torch.cuda.device_count()):
-        free, _ = _gpu_mem_info(idx)
-        free_mb = free // 2**20
-        candidates.append((free_mb, idx))
+        free_mb = _gpu_mem_info(idx)[0] // 2**20
+        if free_mb >= MODEL_VRAM_MB:  # must at least fit the weights
+            candidates.append((free_mb, idx))
 
-    # coarse pre-filter by min_free_mb first (usually 16384 MB)
-    candidates = [c for c in candidates if c[0] >= MODEL_VRAM_MB]
     if not candidates:
         logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
         return (*load_model("cpu"), "cpu")
 
-    # try GPUs from most to least free VRAM
-    candidates.sort(reverse=True)
-    for free_mb, idx in candidates:
+    # most free VRAM first
+    for free_mb, idx in sorted(candidates, reverse=True):
         dev = f"cuda:{idx}"
         try:
             logger.info("Trying %s (free=%d MB)", dev, free_mb)
             mdl, prec = load_model(dev)
-            # measure again after loading; at least POST_LOAD_GAP_MB must remain
-            new_free, _ = _gpu_mem_info(idx)
-            new_free_mb = new_free // 2**20
-            if new_free_mb < POST_LOAD_GAP_MB:
+            remain_mb = _gpu_mem_info(idx)[0] // 2**20
+            if remain_mb < POST_LOAD_GAP_MB:
                 raise RuntimeError(
-                    f"post-load free {new_free_mb} MB < {POST_LOAD_GAP_MB} MB"
-                )
-
-            return mdl, prec, dev  # GPU claimed successfully
+                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+            return mdl, prec, dev  # success
         except RuntimeError as e:
             logger.warning("%s unusable (%s) → next", dev, e)
             torch.cuda.empty_cache()
-            continue
 
-    # no suitable GPU left
     logger.warning("No suitable GPU left → CPU fallback")
     return (*load_model("cpu"), "cpu")
 
-
 # -----------------------------------------------------------------------------#
 # CLI
 # -----------------------------------------------------------------------------#
@@ -180,7 +167,7 @@ if args.device is not None:
     model, PRECISION = load_model(DEVICE)
 else:
     # Auto path with VRAM check
-    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)
+    model, PRECISION, DEVICE = auto_select_and_load()
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 
@@ -195,7 +182,7 @@ logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
 async def warm_up():
     logger.info("Warm-up on %s", DEVICE)
     try:
-        _ = model.encode(["warmup"], return_dense=True)  # GPU builds a worker pool; CPU runs single-process
+        _ = model.encode(["warmup"], return_dense=True, num_processes=1)
     except Exception as e:
         logger.warning("Warm-up failed: %s — will back off on the first request", e)
 
@@ -216,7 +203,8 @@ def _encode(texts: List[str]):
     # ③ -------- worker call inside _encode() --------
     def _worker(t, q):
         try:
-            out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
+            # out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
+            out = model.encode(t, return_dense=True, num_processes=1)
             q.put(("ok", out))
         except Exception as e:
             q.put(("err", str(e)))
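
Note: the hunks above read GPU memory through a module-level helper
_gpu_mem_info(idx) that this patch does not touch; from its call sites it is
expected to return (free_bytes, total_bytes) for a CUDA device index. A
minimal sketch of such a helper, assuming it simply wraps
torch.cuda.mem_get_info (illustration only, not part of the patch):

    import torch

    def _gpu_mem_info(idx: int) -> tuple[int, int]:
        # torch.cuda.mem_get_info returns (free, total) in bytes
        # for the given CUDA device index
        return torch.cuda.mem_get_info(idx)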