#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.

Launch examples:
    python server.py                           # auto-select GPU / auto-fallback
    python server.py --device 1                # pin to GPU 1
    CUDA_VISIBLE_DEVICES=0,1 python server.py
"""
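
# Example client call once the server is running (illustrative; assumes the default
# 0.0.0.0:8000 bind and the /v1/embeddings route defined below):
#   curl -X POST http://localhost:8000/v1/embeddings \
#        -H "Content-Type: application/json" \
#        -d '{"input": ["hello world"], "model": "text-embedding-bge-m3"}'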

import argparse
import logging
import os
import sys
import time
from typing import List, Union
import multiprocessing as mp

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # change to your local weights path if needed
SAFE_MIN_FREE_MB = int(
    os.getenv("SAFE_MIN_FREE_MB", "16384")
)  # minimum free VRAM (MB) required at startup
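# The threshold can be overridden without editing the file (illustrative):
#   SAFE_MIN_FREE_MB=8192 python server.py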

# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
    import pynvml

    pynvml.nvmlInit()
    _USE_NVML = True
except Exception:
    _USE_NVML = False
    logger.warning("pynvml unavailable; using torch.cuda.mem_get_info() to probe VRAM")


def _gpu_mem_info(idx: int) -> tuple[int, int]:
    """Return (free_bytes, total_bytes) for GPU `idx`."""
    if _USE_NVML:
        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem.free, mem.total
    # torch fallback
    torch.cuda.set_device(idx)
    free, total = torch.cuda.mem_get_info(idx)
    return free, total
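
# Example (illustrative): log free/total VRAM for every visible GPU, e.g. at startup:
#   for _i in range(torch.cuda.device_count()):
#       _free, _total = _gpu_mem_info(_i)
#       logger.info("GPU %d: %d / %d MB free", _i, _free // 2**20, _total // 2**20)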


# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision(device: str) -> str:
    """Return 'fp16' | 'fp32' (BGEM3FlagModel only exposes a use_fp16 switch)."""
    if not device.startswith("cuda"):
        return "fp32"

    major, _ = torch.cuda.get_device_capability(device)
    if major >= 8:  # Ampere / Hopper
        return "fp16"
    # Pre-Ampere cards stay in fp32; bf16 is not available below compute capability 8.x
    # and FlagEmbedding does not expose a bf16 switch anyway.
    return "fp32"


def load_model(device: str):
    precision = _choose_precision(device)
    use_fp16 = precision == "fp16"
    logger.info("Loading BGEM3 on %s (%s)", device, precision)

    if device == "cpu":
        # CPU path: hide CUDA from this process entirely
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        torch.cuda.is_available = lambda: False
        torch.cuda.device_count = lambda: 0
    else:
        # GPU path: expose only the card selected for this worker,
        # e.g. device == "cuda:3" → the current process only sees GPU 3
        idx = device.split(":")[1]
        os.environ["CUDA_VISIBLE_DEVICES"] = idx
        # device_count now reports 1, so BGEM3 spawns a single sub-process on this card
        torch.cuda.device_count = lambda: 1

    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)

    # No DataParallel wrapper; each worker runs on a single card
    return mdl, precision
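
# e.g. (illustrative): load_model("cuda:0") returns (model, "fp16") on an Ampere-or-newer
# card, while load_model("cpu") returns (model, "fp32").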

# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load(min_free_mb: int = 4096):
    """
    1. Gather GPUs with free_mem ≥ min_free_mb
    2. Sort by free_mem desc, attempt to load
    3. All fail → CPU
    Return (model, precision, device)
    """
    if not torch.cuda.is_available():
        logger.info("No GPU detected → CPU")
        return (*load_model("cpu"), "cpu")

    # Build candidate list
    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
    for idx in range(torch.cuda.device_count()):
        free, _ = _gpu_mem_info(idx)
        free_mb = free // 2**20
        candidates.append((free_mb, idx))

    candidates = [c for c in candidates if c[0] >= min_free_mb]
    if not candidates:
        logger.warning("All GPUs free_mem < %d MB → CPU", min_free_mb)
        return (*load_model("cpu"), "cpu")

    candidates.sort(reverse=True)  # highest free_mem first
    for free_mb, idx in candidates:
        dev = f"cuda:{idx}"
        try:
            logger.info("Trying %s (free=%d MB)", dev, free_mb)
            mdl, prec = load_model(dev)
            return mdl, prec, dev
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                logger.warning("%s OOM → next GPU", dev)
                torch.cuda.empty_cache()
                continue
            raise  # non-OOM error

    logger.warning("All GPUs failed → CPU fallback")
    return (*load_model("cpu"), "cpu")
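
# Worked example (illustrative): with 30000 MB free on GPU 1, 8000 MB free on GPU 0,
# and min_free_mb=16384, only cuda:1 passes the filter and is tried first; if it OOMs
# while loading, the loop exhausts and the CPU fallback is used.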


# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
)
args, _ = parser.parse_known_args()

if args.device is not None:
    # Forced path
    if args.device.lower() == "cpu":
        DEVICE = "cpu"
    else:
        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
    model, PRECISION = load_model(DEVICE)
else:
    # Auto path with VRAM check
    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)


# ② -------- FastAPI startup warm-up --------
@app.on_event("startup")
async def warm_up():
    logger.info("Warm-up on %s", DEVICE)
    try:
        _ = model.encode(["warmup"], return_dense=True)  # builds the GPU pool; single process on CPU
    except Exception as e:
        logger.warning("Warm-up failed: %s — will fall back on the first real request instead", e)


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"


def _encode(texts: List[str]):
    """
    Per-request flow:
      1. Run GPU inference in a sub-process; on success → return.
      2. If the sub-process hits an OOM / CUDA error → fall back to CPU within the same request.
    Global state is never mutated, so other concurrent requests are unaffected.
    """
    # ③ -------- worker invoked inside _encode() --------
    def _worker(t, q):
        try:
            out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
            q.put(("ok", out))
        except Exception as e:
            q.put(("err", str(e)))

    q = mp.Queue()
    p = mp.Process(target=_worker, args=(texts, q))
    p.start()
    p.join(timeout=60)
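
    # Note: if the worker is still running when the 60 s join timeout expires, the queue
    # stays empty and the RuntimeError at the end of this function is raised; the stray
    # worker process is not terminated here.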
    if not q.empty():
        status, payload = q.get()
        if status == "ok":
            return payload
        if "out of memory" in payload.lower() or "cuda error" in payload.lower():
            logger.warning("GPU OOM → this request falls back to CPU: %s", payload)
            torch.cuda.empty_cache()
            cpu_model, _ = load_model("cpu")
            return cpu_model.encode(texts, return_dense=True)
        raise RuntimeError(payload)

    raise RuntimeError("Worker sub-process exited abnormally without returning a result")
@app.post("/v1/embeddings")
|
||
def create_embedding(request: EmbeddingRequest):
|
||
texts = [request.input] if isinstance(request.input, str) else request.input
|
||
|
||
# Token stats
|
||
enc = tokenizer(
|
||
texts,
|
||
padding=True,
|
||
truncation=True,
|
||
max_length=8192,
|
||
return_tensors="pt",
|
||
)
|
||
prompt_tokens = int(enc["attention_mask"].sum().item())
|
||
|
||
try:
|
||
output = _encode(texts)
|
||
embeddings = output["dense_vecs"]
|
||
except Exception as e:
|
||
logger.exception("Embedding failed")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
return {
|
||
"object": "list",
|
||
"data": [
|
||
{
|
||
"object": "embedding",
|
||
"index": i,
|
||
"embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
|
||
}
|
||
for i, emb in enumerate(embeddings)
|
||
],
|
||
"model": request.model,
|
||
"usage": {
|
||
"prompt_tokens": prompt_tokens,
|
||
"total_tokens": prompt_tokens,
|
||
},
|
||
"device": DEVICE,
|
||
"precision": PRECISION,
|
||
}
|
||
|
||
|
||
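
# Example exchange (illustrative; vector values and token counts are placeholders):
#   POST /v1/embeddings   {"input": "hello", "model": "text-embedding-bge-m3"}
#   → {"object": "list",
#      "data": [{"object": "embedding", "index": 0, "embedding": [0.01, ...]}],
#      "model": "text-embedding-bge-m3",
#      "usage": {"prompt_tokens": 3, "total_tokens": 3},
#      "device": "cuda:0", "precision": "fp16"}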


# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        log_level="info",
        workers=1,  # multi-process → use gunicorn/uvicorn externally
    )
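
# Multi-process serving is left to an external process manager, e.g. (illustrative;
# note that each worker re-runs the module-level model loading above):
#   gunicorn server:app -k uvicorn.workers.UvicornWorker -w 2 -b 0.0.0.0:8000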