hailin 2025-08-05 15:06:23 +08:00
parent 357deccf86
commit 6a2ddde60b
1 changed file with 13 additions and 31 deletions


@@ -82,26 +82,24 @@ def _choose_precision(device: str) -> str:
 def load_model(device: str):
     """Instantiate model on `device`; return (model, precision)."""
     precision = _choose_precision(device)
     use_fp16 = precision == "fp16"
     logger.info("Loading BGEM3 on %s (%s)", device, precision)
-    # ---------- Key: when running on CPU, mask out CUDA entirely first ----------
-    if device == "cpu":
-        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # later forks will not see the GPU either
-        torch.cuda.is_available = lambda: False
-        torch.cuda.device_count = lambda: 0
-    # ----------------------------------------------------------
     mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
+    if device == "cpu":
+        # Hide the GPU so that later torch / BGEM3 calls no longer detect CUDA
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    # Simple DataParallel for multi-GPU inference (wrap only when >1 GPU)
     if device.startswith("cuda") and torch.cuda.device_count() > 1:
         logger.info(
             "Wrapping model with torch.nn.DataParallel (%d GPUs)",
             torch.cuda.device_count(),
         )
         mdl = torch.nn.DataParallel(mdl)
     return mdl, precision
 # -----------------------------------------------------------------------------#
 # Auto-select device (startup)
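The hunk above stops monkey-patching torch.cuda.is_available / torch.cuda.device_count and relies on CUDA_VISIBLE_DEVICES alone, set after the model is loaded so that subsequent torch calls and any forked workers see no GPU. A minimal sketch (hypothetical, not part of this patch) of why setting the variable in the parent process suffices: child processes inherit the environment, so torch initialized in a child finds no devices.

# Sketch: CUDA_VISIBLE_DEVICES set in the parent is inherited by children.
# Illustrative only; assumes torch is installed, and `child` is a made-up name.
import multiprocessing as mp
import os

def child() -> None:
    import torch  # first torch import happens after the env var is set
    print("child sees CUDA:", torch.cuda.is_available())  # -> False

if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # inherited by every child process
    p = mp.Process(target=child)
    p.start()
    p.join()

The DataParallel wrap is untouched by this: on the cuda path the env var is never cleared, so torch.cuda.device_count() still reports the real GPUs.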
@@ -177,20 +175,11 @@ logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
 @app.on_event("startup")
 async def warm_up_mp_pool():
     logger.info("Warm-up on %s", DEVICE)
     try:
-        if DEVICE.startswith("cuda"):
-            logger.info("Warm-up (GPU) → building the multi-process pool")
-            _ = model.encode(["warmup"], return_dense=True)
-        else:
-            logger.info("Warm-up (CPU) → single-process init")
-            # ── Temporarily make the library "believe" there is no GPU ──
-            orig_cnt = torch.cuda.device_count
-            torch.cuda.device_count = lambda: 0
-            _ = model.encode(["warmup"], return_dense=True)  # num_processes not passed
-            torch.cuda.device_count = orig_cnt
-            # ──────────────────────────────────────────────────────
+        _ = model.encode(["warmup"], return_dense=True)  # num_processes not passed
     except Exception as e:
-        logger.warning("Warm-up failed: %s, will fall back again on the first request", e)
+        logger.warning("Warm-up failed: %s — will fall back again on the first request", e)
@@ -208,14 +197,7 @@ def _encode(texts: List[str]):
     """
     def _worker(t, q):
         try:
-            if DEVICE.startswith("cuda"):
-                out = model.encode(t, return_dense=True)  # runs fine on GPU
-            else:
-                # Temporarily mask the GPU: single-process CPU inference
-                orig_cnt = torch.cuda.device_count
-                torch.cuda.device_count = lambda: 0
-                out = model.encode(t, return_dense=True)  # num_processes not passed
-                torch.cuda.device_count = orig_cnt
+            out = model.encode(t, return_dense=True)  # safe on both GPU and CPU paths
             q.put(("ok", out))
         except Exception as e:
             q.put(("err", str(e)))