hailin 2025-08-05 16:17:57 +08:00
parent c119771511
commit 04146a7f28
1 changed file with 31 additions and 17 deletions


@@ -31,6 +31,9 @@ SAFE_MIN_FREE_MB = int(
     os.getenv("SAFE_MIN_FREE_MB", "16384")
 )  # minimum free VRAM (MB) required at startup
+MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))  # bge-m3-large fp16 ≈ 16 GiB
+POST_LOAD_GAP_MB = 512  # headroom (MB) that must remain free after the model is loaded
 # -----------------------------------------------------------------------------#
 # Logging
 # -----------------------------------------------------------------------------#
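Note: the selection logic in the next hunk probes each card through `_gpu_mem_info`, which is defined elsewhere in this file and is not touched by this commit. A minimal sketch of what it is assumed to do, namely wrap `torch.cuda.mem_get_info` and return `(free_bytes, total_bytes)` so callers can convert to MB with `// 2**20`:

import torch

def _gpu_mem_info(idx: int) -> tuple[int, int]:
    # Assumed helper (not part of this diff): report device-wide free and
    # total memory in bytes for GPU `idx` via torch.cuda.mem_get_info.
    return torch.cuda.mem_get_info(idx)
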
@@ -109,42 +112,53 @@ def load_model(device: str):
 # -----------------------------------------------------------------------------#
 def auto_select_and_load(min_free_mb: int = 4096):
     """
-    1. Gather GPUs with free_mem ≥ min_free_mb
-    2. Sort by free_mem desc, attempt to load
-    3. All fail → CPU
-    Return (model, device, precision)
+    1. Filter out GPUs whose free VRAM is < MODEL_VRAM_MB
+    2. Try the remaining GPUs in descending order of free VRAM
+    3. Re-check after loading: if less than POST_LOAD_GAP_MB remains free, treat it as a failure
+    4. If no GPU qualifies → CPU
     """
     if not torch.cuda.is_available():
         logger.info("No GPU detected → CPU")
         return (*load_model("cpu"), "cpu")
-    # Build candidate list
+    # Collect each card's free VRAM
     candidates: list[tuple[int, int]] = []  # (free_MB, idx)
     for idx in range(torch.cuda.device_count()):
         free, _ = _gpu_mem_info(idx)
         free_mb = free // 2**20
         candidates.append((free_mb, idx))
-    candidates = [c for c in candidates if c[0] >= min_free_mb]
+    # Coarse pre-filter by MODEL_VRAM_MB (typically 16384 MB)
+    candidates = [c for c in candidates if c[0] >= MODEL_VRAM_MB]
     if not candidates:
-        logger.warning("All GPUs free_mem < %d MB → CPU", min_free_mb)
+        logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
         return (*load_model("cpu"), "cpu")
-    candidates.sort(reverse=True)  # high free_mem first
+    # Try GPUs from the highest to the lowest free VRAM
+    candidates.sort(reverse=True)
     for free_mb, idx in candidates:
         dev = f"cuda:{idx}"
         try:
             logger.info("Trying %s (free=%d MB)", dev, free_mb)
             mdl, prec = load_model(dev)
-            return mdl, prec, dev
-        except RuntimeError as e:
-            if "out of memory" in str(e).lower():
-                logger.warning("%s OOM → next GPU", dev)
-                torch.cuda.empty_cache()
-                continue
-            raise  # non-OOM error
-    logger.warning("All GPUs failed → CPU fallback")
+            # Measure again after loading: ensure ≥ POST_LOAD_GAP_MB still remains free
+            new_free, _ = _gpu_mem_info(idx)
+            new_free_mb = new_free // 2**20
+            if new_free_mb < POST_LOAD_GAP_MB:
+                raise RuntimeError(
+                    f"post-load free {new_free_mb} MB < {POST_LOAD_GAP_MB} MB"
+                )
+            return mdl, prec, dev  # GPU successfully claimed
+        except RuntimeError as e:
+            logger.warning("%s unusable (%s) → next", dev, e)
+            torch.cuda.empty_cache()
+            continue
+    # No GPU qualified
+    logger.warning("No suitable GPU left → CPU fallback")
     return (*load_model("cpu"), "cpu")
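For reference, a hypothetical caller sketch, assuming (as the hunk shows) that `load_model` returns a `(model, precision)` pair, so both the GPU and CPU branches of `auto_select_and_load` yield a `(model, precision, device)` triple:

# Hypothetical usage at service start-up (names as in the hunks above):
model, precision, device = auto_select_and_load()
logger.info("Embedding model ready on %s (%s)", device, precision)

The post-load re-check is the substance of this commit: a card that passes the MODEL_VRAM_MB pre-filter but is left with less than POST_LOAD_GAP_MB of headroom once the weights are resident is now treated as unusable, and the loop moves on to the next candidate instead of serving from an almost-full GPU.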