hailin 2025-08-05 16:35:42 +08:00
parent 04146a7f28
commit 53285eecad
1 changed file with 17 additions and 29 deletions


@@ -27,12 +27,9 @@ from FlagEmbedding import BGEM3FlagModel
 # Config
 # -----------------------------------------------------------------------------#
 MODEL_PATH = "model/bge-m3"  # change to your weight path as needed
-SAFE_MIN_FREE_MB = int(
-    os.getenv("SAFE_MIN_FREE_MB", "16384")
-)  # minimum free VRAM (MB) required at startup
 MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))  # bge-m3-large fp16 ≈ 16 GiB
 POST_LOAD_GAP_MB = 512
+SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # == 16896 MB
 
 # -----------------------------------------------------------------------------#
 # Logging
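
Note on the config change: SAFE_MIN_FREE_MB is no longer an independent environment knob but is derived from MODEL_VRAM_MB plus a fixed post-load gap, so overriding MODEL_VRAM_MB moves the startup threshold with it and the old SAFE_MIN_FREE_MB variable is no longer read. A minimal sketch of the resulting arithmetic, assuming the defaults shown above:

import os

MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))  # weights footprint (fp16)
POST_LOAD_GAP_MB = 512                                    # headroom required after loading
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB       # 16384 + 512 = 16896 MB
# e.g. MODEL_VRAM_MB=24576 gives SAFE_MIN_FREE_MB=25088, with no second variable to keep in sync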
@@ -110,7 +107,7 @@ def load_model(device: str):
 # -----------------------------------------------------------------------------#
 # Auto-select device (startup)
 # -----------------------------------------------------------------------------#
-def auto_select_and_load(min_free_mb: int = 4096):
+def auto_select_and_load() -> tuple:
     """
     1. Filter out GPUs whose free VRAM is < MODEL_VRAM_MB
     2. Try the remaining GPUs in descending order of free VRAM
@@ -121,47 +118,37 @@ def auto_select_and_load(min_free_mb: int = 4096):
         logger.info("No GPU detected → CPU")
         return (*load_model("cpu"), "cpu")
 
-    # Collect the free VRAM of every card
-    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
+    # Collect candidate cards as (free_MB, idx)
+    candidates = []
     for idx in range(torch.cuda.device_count()):
-        free, _ = _gpu_mem_info(idx)
-        free_mb = free // 2**20
-        candidates.append((free_mb, idx))
-
-    # Coarse pre-filter by min_free_mb first (usually 16384 MB)
-    candidates = [c for c in candidates if c[0] >= MODEL_VRAM_MB]
+        free_mb = _gpu_mem_info(idx)[0] // 2**20
+        if free_mb >= MODEL_VRAM_MB:  # must at least fit the weights
+            candidates.append((free_mb, idx))
 
     if not candidates:
         logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
         return (*load_model("cpu"), "cpu")
 
-    # Try GPUs from most free VRAM to least
-    candidates.sort(reverse=True)
-    for free_mb, idx in candidates:
+    # Most free VRAM first
+    for free_mb, idx in sorted(candidates, reverse=True):
         dev = f"cuda:{idx}"
         try:
             logger.info("Trying %s (free=%d MB)", dev, free_mb)
             mdl, prec = load_model(dev)
-            # Measure again after loading to ensure ≥ POST_LOAD_GAP_MB is still free
-            new_free, _ = _gpu_mem_info(idx)
-            new_free_mb = new_free // 2**20
-            if new_free_mb < POST_LOAD_GAP_MB:
+            remain_mb = _gpu_mem_info(idx)[0] // 2**20
+            if remain_mb < POST_LOAD_GAP_MB:
                 raise RuntimeError(
-                    f"post-load free {new_free_mb} MB < {POST_LOAD_GAP_MB} MB"
-                )
-            return mdl, prec, dev  # success
+                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+            return mdl, prec, dev  # GPU claimed successfully
         except RuntimeError as e:
             logger.warning("%s unusable (%s) → next", dev, e)
             torch.cuda.empty_cache()
-            continue
 
-    # No GPU qualified
     logger.warning("No suitable GPU left → CPU fallback")
     return (*load_model("cpu"), "cpu")
 
 # -----------------------------------------------------------------------------#
 # CLI
 # -----------------------------------------------------------------------------#
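
Both versions rely on a _gpu_mem_info(idx) helper defined outside this hunk; from its call sites it returns a (free_bytes, total_bytes) tuple for a CUDA device index. A minimal sketch of what such a helper could look like, assuming it simply wraps torch.cuda.mem_get_info (the actual implementation in the file may differ):

import torch

def _gpu_mem_info(idx: int) -> tuple[int, int]:
    # Return (free_bytes, total_bytes) for CUDA device `idx` (assumed behaviour).
    return torch.cuda.mem_get_info(idx)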
@@ -180,7 +167,7 @@ if args.device is not None:
     model, PRECISION = load_model(DEVICE)
 else:
     # Auto path with VRAM check
-    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)
+    model, PRECISION, DEVICE = auto_select_and_load()
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
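
The hunk context above references an args.device flag parsed in the CLI block; a plausible shape for that parser, offered purely as an assumption since the actual argument definitions sit outside this diff:

import argparse

parser = argparse.ArgumentParser(description="bge-m3 embedding service")  # description assumed
parser.add_argument(
    "--device",
    default=None,
    help='e.g. "cpu" or "cuda:0"; omit to auto-select by free VRAM',
)
args = parser.parse_args()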
@@ -195,7 +182,7 @@ logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
 async def warm_up():
     logger.info("Warm-up on %s", DEVICE)
     try:
-        _ = model.encode(["warmup"], return_dense=True)  # GPU builds a pool, CPU stays single-process
+        _ = model.encode(["warmup"], return_dense=True, num_processes=1)
     except Exception as e:
         logger.warning("Warm-up failed: %s — will back off at the first request instead", e)
@@ -216,7 +203,8 @@ def _encode(texts: List[str]):
     # ③ -------- worker call inside _encode() --------
     def _worker(t, q):
         try:
-            out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
+            # out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
+            out = model.encode(t, return_dense=True, num_processes=1)
             q.put(("ok", out))
         except Exception as e:
             q.put(("err", str(e)))