hailin 2025-08-05 16:35:42 +08:00
parent 04146a7f28
commit 53285eecad
1 changed file with 17 additions and 29 deletions


@@ -27,12 +27,9 @@ from FlagEmbedding import BGEM3FlagModel
 # Config
 # -----------------------------------------------------------------------------#
 MODEL_PATH = "model/bge-m3"  # change this to your model weights path if needed
-SAFE_MIN_FREE_MB = int(
-    os.getenv("SAFE_MIN_FREE_MB", "16384")
-)  # minimum free VRAM (MB) required at startup
+MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))  # bge-m3-large fp16 ≈ 16 GiB
+POST_LOAD_GAP_MB = 512
+SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # == 16896 MB
 # -----------------------------------------------------------------------------#
 # Logging
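
The startup budget is now derived rather than read directly: the weights allowance plus a fixed post-load gap. A quick sketch of the arithmetic, using only the module-level names introduced above (MODEL_VRAM_MB stays overridable through the environment; nothing here is new API):

import os

MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))  # fp16 weights budget
POST_LOAD_GAP_MB = 512                                    # head-room required after loading
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB

assert SAFE_MIN_FREE_MB == 16384 + 512 == 16896  # matches the "== 16896 MB" comment
# e.g. exporting MODEL_VRAM_MB=20480 would raise the requirement to 20992 MB

Note that the old SAFE_MIN_FREE_MB environment variable no longer has any effect after this change.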
@@ -110,7 +107,7 @@ def load_model(device: str):
 # -----------------------------------------------------------------------------#
 # Auto-select device (startup)
 # -----------------------------------------------------------------------------#
-def auto_select_and_load(min_free_mb: int = 4096):
+def auto_select_and_load() -> tuple:
     """
     1. Filter out GPUs whose free VRAM is < MODEL_VRAM_MB
     2. Try the remaining GPUs in descending order of free VRAM
@@ -121,47 +118,37 @@ def auto_select_and_load(min_free_mb: int = 4096):
         logger.info("No GPU detected → CPU")
         return (*load_model("cpu"), "cpu")
-    # collect each card's free VRAM
-    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
+    # collect candidate cards as (free_MB, idx)
+    candidates = []
     for idx in range(torch.cuda.device_count()):
-        free, _ = _gpu_mem_info(idx)
-        free_mb = free // 2**20
+        free_mb = _gpu_mem_info(idx)[0] // 2**20
         if free_mb >= MODEL_VRAM_MB:  # must at least fit the weights
             candidates.append((free_mb, idx))
-    # coarse pre-filter by min_free_mb (usually 16384 MB)
-    candidates = [c for c in candidates if c[0] >= MODEL_VRAM_MB]
     if not candidates:
         logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
         return (*load_model("cpu"), "cpu")
-    # try GPUs from most to least free VRAM
-    candidates.sort(reverse=True)
-    for free_mb, idx in candidates:
+    # highest free VRAM first
+    for free_mb, idx in sorted(candidates, reverse=True):
         dev = f"cuda:{idx}"
         try:
             logger.info("Trying %s (free=%d MB)", dev, free_mb)
             mdl, prec = load_model(dev)
             # re-check after loading: at least POST_LOAD_GAP_MB must remain free
-            new_free, _ = _gpu_mem_info(idx)
-            new_free_mb = new_free // 2**20
-            if new_free_mb < POST_LOAD_GAP_MB:
+            remain_mb = _gpu_mem_info(idx)[0] // 2**20
+            if remain_mb < POST_LOAD_GAP_MB:
                 raise RuntimeError(
-                    f"post-load free {new_free_mb} MB < {POST_LOAD_GAP_MB} MB"
-                )
-            return mdl, prec, dev  # GPU claimed successfully
+                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+            return mdl, prec, dev  # success
         except RuntimeError as e:
             logger.warning("%s unusable (%s) → next", dev, e)
             torch.cuda.empty_cache()
             continue
     # no GPU qualifies
     logger.warning("No suitable GPU left → CPU fallback")
     return (*load_model("cpu"), "cpu")
 # -----------------------------------------------------------------------------#
 # CLI
 # -----------------------------------------------------------------------------#
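
Both the old and the new loop lean on a _gpu_mem_info(idx) helper that sits outside this diff. Judging from the call sites (the old free, _ = _gpu_mem_info(idx) unpacking and the new _gpu_mem_info(idx)[0] // 2**20), it returns (free_bytes, total_bytes) for one GPU. A minimal sketch of such a helper, assuming it simply wraps torch.cuda.mem_get_info:

import torch

def _gpu_mem_info(idx: int) -> tuple:
    # Assumed shape of the helper used above (not part of this commit):
    # returns (free_bytes, total_bytes) for GPU `idx`.
    return torch.cuda.mem_get_info(torch.device(f"cuda:{idx}"))

Dividing the first element by 2**20 converts bytes to MiB, which is what the MB thresholds above are compared against.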
@@ -180,7 +167,7 @@ if args.device is not None:
     model, PRECISION = load_model(DEVICE)
 else:
     # Auto path with VRAM check
-    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)
+    model, PRECISION, DEVICE = auto_select_and_load()
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
@@ -195,7 +182,7 @@ logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
 async def warm_up():
     logger.info("Warm-up on %s", DEVICE)
     try:
-        _ = model.encode(["warmup"], return_dense=True)  # GPU builds a pool, CPU stays single-process
+        _ = model.encode(["warmup"], return_dense=True, num_processes=1)
     except Exception as e:
         logger.warning("Warm-up failed: %s — will fall back on the first real request", e)
@@ -216,7 +203,8 @@ def _encode(texts: List[str]):
 # ③ -------- worker called from _encode() --------
 def _worker(t, q):
     try:
-        out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
+        # out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
+        out = model.encode(t, return_dense=True, num_processes=1)
         q.put(("ok", out))
     except Exception as e:
         q.put(("err", str(e)))
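
The _worker signature and the ("ok" / "err", payload) tuples imply that _encode() pushes the encode call into a child process and reads the result back over a queue. That driver code is outside this hunk, so the sketch below only illustrates the pattern; the name _encode_via_worker and the timeout are invented for the example, not taken from the commit:

import multiprocessing as mp
from typing import List

_ENCODE_TIMEOUT_S = 120  # hypothetical timeout, not from the commit

def _encode_via_worker(texts: List[str]):
    # Run _worker(texts, q) in a child process and wait for its ("ok" | "err", payload) reply.
    q = mp.Queue()
    p = mp.Process(target=_worker, args=(texts, q), daemon=True)
    p.start()
    try:
        status, payload = q.get(timeout=_ENCODE_TIMEOUT_S)
    except Exception:
        p.terminate()
        raise RuntimeError("encode worker timed out or died")
    finally:
        p.join(timeout=5)
    if status == "err":
        raise RuntimeError(payload)
    return payload

Forking a process that already holds a CUDA context can be fragile, which is presumably why the worker now pins num_processes=1 inside the child instead of letting the library build its own pool (the removed comment "GPU builds a pool, CPU single-process" points at that behaviour).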