This commit is contained in:
hailin 2025-08-05 16:17:57 +08:00
parent c119771511
commit 04146a7f28
1 changed file with 31 additions and 17 deletions


@@ -31,6 +31,9 @@ SAFE_MIN_FREE_MB = int(
os.getenv("SAFE_MIN_FREE_MB", "16384")
) # minimum free VRAM (MB) required at startup
MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384")) # bge-m3-large fp16 ≈ 16 GiB
POST_LOAD_GAP_MB = 512 # free-VRAM headroom (MB) that must remain after the model is loaded
# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
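The hunk below calls a `_gpu_mem_info(idx)` helper that sits outside this diff. A minimal sketch of what it presumably does, assuming it is a thin wrapper over `torch.cuda.mem_get_info` (the real helper's body is not shown here):

import torch

def _gpu_mem_info(idx: int) -> tuple[int, int]:
    # Driver-level query: returns (free_bytes, total_bytes) for GPU `idx`.
    # Assumption: the real helper returns the same pair; the caller below
    # converts to MB with `free // 2**20`.
    return torch.cuda.mem_get_info(idx)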
@@ -109,42 +112,53 @@ def load_model(device: str):
# -----------------------------------------------------------------------------#
def auto_select_and_load(min_free_mb: int = 4096):
"""
1. Gather GPUs with free_mem ≥ min_free_mb
2. Sort by free_mem desc, attempt to load
3. All fail → CPU
Return (model, precision, device)
1. Filter out GPUs whose free VRAM is < MODEL_VRAM_MB
2. Try loading on each, in descending order of free VRAM
3. Re-check after loading; if the remaining free VRAM is < POST_LOAD_GAP_MB, treat it as a failure
4. If no GPU qualifies → CPU
Returns (model, precision, device)
"""
if not torch.cuda.is_available():
logger.info("No GPU detected → CPU")
return (*load_model("cpu"), "cpu")
# Build candidate list
# Collect each GPU's free VRAM
candidates: list[tuple[int, int]] = [] # (free_MB, idx)
for idx in range(torch.cuda.device_count()):
free, _ = _gpu_mem_info(idx)
free_mb = free // 2**20
candidates.append((free_mb, idx))
candidates = [c for c in candidates if c[0] >= min_free_mb]
# Coarse pre-filter: drop GPUs with free VRAM below MODEL_VRAM_MB (typically 16384 MB)
candidates = [c for c in candidates if c[0] >= MODEL_VRAM_MB]
if not candidates:
logger.warning("All GPUs free_mem < %d MB → CPU", min_free_mb)
logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
return (*load_model("cpu"), "cpu")
candidates.sort(reverse=True) # high free_mem first
# Try GPUs from most to least free VRAM
candidates.sort(reverse=True)
for free_mb, idx in candidates:
dev = f"cuda:{idx}"
try:
logger.info("Trying %s (free=%d MB)", dev, free_mb)
mdl, prec = load_model(dev)
return mdl, prec, dev
# Measure again after loading; make sure ≥ POST_LOAD_GAP_MB remains
new_free, _ = _gpu_mem_info(idx)
new_free_mb = new_free // 2**20
if new_free_mb < POST_LOAD_GAP_MB:
raise RuntimeError(
f"post-load free {new_free_mb} MB < {POST_LOAD_GAP_MB} MB"
)
return mdl, prec, dev # GPU claimed successfully
except RuntimeError as e:
if "out of memory" in str(e).lower():
logger.warning("%s OOM → next GPU", dev)
logger.warning("%s unusable (%s) → next", dev, e)
torch.cuda.empty_cache()
continue
raise # non-OOM error
logger.warning("All GPUs failed → CPU fallback")
# No GPU met the requirements
logger.warning("No suitable GPU left → CPU fallback")
return (*load_model("cpu"), "cpu")
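For reference, a minimal usage sketch of the selector after this change. Note the return order is (model, precision, device), matching the `(*load_model("cpu"), "cpu")` CPU fallback; the `encode` call is an assumption about the embedding model's API and is not part of this diff:

model, precision, device = auto_select_and_load()
logger.info("model ready on %s (%s)", device, precision)
# Assumption: bge-m3 style embedding models expose encode(); adjust to the actual API.
vectors = model.encode(["hello world"])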