parent 357deccf86
commit 6a2ddde60b

app/main.py · 42 changed lines
@@ -82,26 +82,24 @@ def _choose_precision(device: str) -> str:
 def load_model(device: str):
     """Instantiate model on `device`; return (model, precision)."""
     precision = _choose_precision(device)
     use_fp16 = precision == "fp16"
     logger.info("Loading BGEM3 on %s (%s)", device, precision)
 
+    # ---------- Key point: when running on CPU, hide CUDA entirely first ----------
+    if device == "cpu":
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # forks started later cannot see the GPU either
+        torch.cuda.is_available = lambda: False
+        torch.cuda.device_count = lambda: 0
+    # ----------------------------------------------------------
+
     mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
 
-    if device == "cpu":
-        # Mask the GPU so that torch / BGEM3 no longer recognise CUDA afterwards
-        os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
-    # Simple DataParallel for multi-GPU inference
+    # Wrap in DataParallel only when there are multiple GPUs
     if device.startswith("cuda") and torch.cuda.device_count() > 1:
         logger.info(
             "Wrapping model with torch.nn.DataParallel (%d GPUs)",
             torch.cuda.device_count(),
         )
         mdl = torch.nn.DataParallel(mdl)
     return mdl, precision
 
 
 # -----------------------------------------------------------------------------#
 # Auto-select device (startup)
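Note: the point of the relocated block is ordering. Hiding CUDA has to happen before BGEM3FlagModel is constructed and before any worker processes are forked, otherwise they can still see the GPU. Below is a minimal, self-contained sketch of that masking trick in isolation; it is not part of this commit, and the helper name force_cpu_only is invented for illustration.

    import os
    import torch

    def force_cpu_only() -> None:
        """Hide CUDA from torch and from any process forked after this call."""
        os.environ["CUDA_VISIBLE_DEVICES"] = ""   # forked workers inherit this environment
        torch.cuda.is_available = lambda: False   # later "is there a GPU?" checks answer no
        torch.cuda.device_count = lambda: 0       # libraries that count devices see zero

    force_cpu_only()
    print(torch.cuda.is_available(), torch.cuda.device_count())  # False 0, even on a GPU box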
@@ -177,20 +175,11 @@ logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
 @app.on_event("startup")
 async def warm_up_mp_pool():
+    logger.info("Warm-up on %s", DEVICE)
     try:
-        if DEVICE.startswith("cuda"):
-            logger.info("Warm-up (GPU) → build the multiprocess pool")
-            _ = model.encode(["warmup"], return_dense=True)
-        else:
-            logger.info("Warm-up (CPU) → single-process initialisation")
-            # ── temporarily make the library "think" there is no GPU ──
-            orig_cnt = torch.cuda.device_count
-            torch.cuda.device_count = lambda: 0
-            _ = model.encode(["warmup"], return_dense=True)  # do not pass num_processes
-            torch.cuda.device_count = orig_cnt
-            # ──────────────────────────────────────────────────────
+        _ = model.encode(["warmup"], return_dense=True)  # do not pass num_processes
     except Exception as e:
-        logger.warning("Warm-up failed: %s —— will back off again on the first request", e)
+        logger.warning("Warm-up failed: %s — will back off again on the first request", e)
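For context, the warm-up now exercises the same encode path on either device. A hedged sketch of what that call is expected to return, based on the FlagEmbedding BGE-M3 API as generally documented (the dict key and the 1024-dim dense width are assumptions to verify against the installed version):

    out = model.encode(["warmup"], return_dense=True)   # returns a dict, not a bare array
    dense = out["dense_vecs"]                            # expected shape: (1, 1024) for one input text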
@@ -208,14 +197,7 @@ def _encode(texts: List[str]):
     """
     def _worker(t, q):
         try:
-            if DEVICE.startswith("cuda"):
-                out = model.encode(t, return_dense=True)  # GPU path runs as usual
-            else:
-                # Temporarily mask the GPU; single-process CPU inference
-                orig_cnt = torch.cuda.device_count
-                torch.cuda.device_count = lambda: 0
-                out = model.encode(t, return_dense=True)  # do not pass num_processes
-                torch.cuda.device_count = orig_cnt
+            out = model.encode(t, return_dense=True)  # safe on both the GPU and CPU paths
             q.put(("ok", out))
         except Exception as e:
             q.put(("err", str(e)))
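This hunk only touches the inner _worker; the enclosing _encode (not shown in the diff) presumably runs it in a child process and reads the ("ok"/"err", payload) tuple off the queue. A hypothetical driver sketch along those lines, with the timeout value and error handling invented for illustration and assuming a fork-based multiprocessing start method so the nested _worker can be passed directly:

    import multiprocessing as mp

    def _encode_via_worker(texts, timeout_s: float = 120.0):
        """Run _worker in a child process and return its payload (illustrative only)."""
        q = mp.Queue()
        p = mp.Process(target=_worker, args=(texts, q), daemon=True)
        p.start()
        p.join(timeout_s)
        if p.is_alive():                      # worker hung: kill it and surface an error
            p.terminate()
            raise RuntimeError("encode worker timed out")
        status, payload = q.get()
        if status == "err":
            raise RuntimeError(payload)
        return payload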