embed-bge-m3/app/main.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.
Launch examples:
    python main.py                               # auto-select GPU / automatic CPU fallback
    python main.py --device 1                    # pin to GPU 1
    CUDA_VISIBLE_DEVICES=0,1 python main.py
"""
import argparse
import logging
import os
from typing import List, Union
from threading import Lock
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # change this to your local weights path as needed
MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "4800"))  # bge-m3 weights: fp16 ≈ 2.4 GiB, fp32 ≈ 4.8 GiB
POST_LOAD_GAP_MB = 200
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # 5000 MB with the defaults
# Upper bounds on batch size and per-text length (guards against oversized payloads)
MAX_BATCH = int(os.getenv("MAX_BATCH", "1024"))
MAX_TEXT_LEN = int(os.getenv("MAX_TEXT_LEN", "200000"))
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "32")))
# Number of consecutive GPU failures before a reload is attempted (self-healing)
GPU_MAX_CONSEC_FAILS = int(os.getenv("GPU_MAX_CONSEC_FAILS", "3"))
# Cache for the CPU fallback model
CPU_MODEL_CACHE = None
# Consecutive GPU failure count; only used to trigger a reload, never the normal path
_GPU_FAIL_COUNT = 0
_reload_lock = Lock()
_READY = False
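# Illustrative override of the limits above via environment variables (the values
# here are examples only, not recommendations from this project):
#   MODEL_VRAM_MB=6000 MAX_BATCH=256 BATCH_SIZE=64 python main.py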
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-bge-m3"
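# Illustrative request bodies accepted by EmbeddingRequest (example values, not
# taken from the service itself):
#   {"input": "hello world", "model": "text-embedding-bge-m3"}
#   {"input": ["first text", "second text"]}      # "model" falls back to the default above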
# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
def _http_error(status: int, code: str, message: str):
"""统一错误体:便于日志/客户端识别。"""
raise HTTPException(status_code=status, detail={"code": code, "message": message})
# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
import pynvml
pynvml.nvmlInit()
_USE_NVML = True
except Exception:
_USE_NVML = False
logger.warning("pynvml 不可用,将使用 torch.cuda.mem_get_info() 探测显存")
def _gpu_mem_info(idx: int) -> tuple[int, int]:
"""Return (free_bytes, total_bytes) for GPU `idx`(全局序号)."""
if _USE_NVML:
handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
return mem.free, mem.total
    # torch fallback (note: if devices are masked for this process, only in-process indices are reachable)
torch.cuda.set_device(idx)
free, total = torch.cuda.mem_get_info(idx)
return free, total
# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision_by_idx(idx: int) -> str:
"""Return 'fp16' or 'fp32' by compute capability."""
try:
        major, _ = torch.cuda.get_device_capability(idx)  # idx is the global device index here
if major >= 8: # Ampere/Hopper
return "fp16"
if major >= 7: # Volta/Turing
return "fp16"
return "fp32"
except Exception:
return "fp32"
def load_model(device: str):
"""
device: "cpu""cuda:{global_idx}"
屏蔽后在本进程内统一使用 "cuda:0" 加载,避免 invalid device ordinal。
"""
if device == "cpu":
os.environ["CUDA_VISIBLE_DEVICES"] = ""
precision = "fp32"
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=False, device="cpu")
logger.info("Loading BGEM3 on cpu (fp32)")
return mdl, precision
    # Parse the global index and decide the precision before masking
idx = int(device.split(":")[1])
precision = _choose_precision_by_idx(idx)
use_fp16 = (precision == "fp16")
    # Only this card is exposed; inside the process it maps to cuda:0
mapped = "cuda:0" if device.startswith("cuda") else "cpu"
logger.info("Loading BGEM3 on %s (mapped=%s, %s)", device, mapped, precision)
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=mapped)
return mdl, precision
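# Illustrative calls (a sketch of the contract above, not an exhaustive spec):
#   load_model("cpu")     -> (model, "fp32"); CUDA masked off for this process
#   load_model("cuda:0")  -> precision chosen from the visible card's compute
#                            capability, model loaded on the in-process cuda:0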
# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load() -> tuple:
"""
    Use NVML only to pick a card, and set CUDA_VISIBLE_DEVICES before the first CUDA call.
    Selection rules:
    - skip cards whose free VRAM is < MODEL_VRAM_MB
    - try loading in descending order of free VRAM
    - after loading, re-check with NVML; if the remaining free VRAM is < POST_LOAD_GAP_MB, move on to the next card
    - if no card qualifies, fall back to CPU
"""
    # 1) No NVML: cannot filter by free VRAM safely → blindly try GPU 0 (mask it first); fall back to CPU on failure
if not _USE_NVML:
if "CUDA_VISIBLE_DEVICES" not in os.environ or os.environ["CUDA_VISIBLE_DEVICES"] == "":
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
try:
mdl, prec = load_model("cuda:0") # 进程内看到的就是单卡 0
return mdl, prec, "cuda:0"
except Exception as e:
logger.warning("No NVML or CUDA unusable (%s) → CPU fallback", e)
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
    # 2) NVML available: pick a card by free VRAM (never touching torch.cuda during selection)
try:
gpu_count = pynvml.nvmlDeviceGetCount()
except Exception as e:
logger.warning("NVML getCount failed (%s) → CPU", e)
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
candidates = []
for idx in range(gpu_count):
try:
            free_b, total_b = _gpu_mem_info(idx)  # NVML path
free_mb = free_b // 2**20
if free_mb >= MODEL_VRAM_MB:
candidates.append((free_mb, idx))
except Exception as e:
logger.warning("NVML query gpu %d failed: %s", idx, e)
if not candidates:
logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
    # 3) Try loading from the largest free VRAM down; mask the card before each attempt
for free_mb, idx in sorted(candidates, reverse=True):
try:
os.environ["CUDA_VISIBLE_DEVICES"] = str(idx) # **关键:先 MASK再触碰 torch**
dev_label = f"cuda:{idx}" # 对外标注用全局序号
mdl, prec = load_model("cuda:0") # 进程内实际就是 0 号
# 载入后用 NVML 复检剩余显存(仍按全局 idx
remain_mb = _gpu_mem_info(idx)[0] // 2**20
if remain_mb < POST_LOAD_GAP_MB:
raise RuntimeError(f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
return mdl, prec, dev_label
except Exception as e:
logger.warning("GPU %d unusable (%s) → next", idx, e)
            # Do not call torch.cuda.empty_cache() here; it could inadvertently initialize other devices
continue
    # 4) Nothing worked → CPU
logger.warning("No suitable GPU left → CPU fallback")
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
FORCE_DEVICE = os.getenv("FORCE_DEVICE")  # "cpu" or "0" / "1" / ...
parser = argparse.ArgumentParser()
parser.add_argument("--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection")
args, _ = parser.parse_known_args()
if FORCE_DEVICE is not None:
if FORCE_DEVICE.lower() == "cpu":
os.environ["CUDA_VISIBLE_DEVICES"] = ""
DEVICE = "cpu"
model, PRECISION = load_model("cpu")
else:
os.environ["CUDA_VISIBLE_DEVICES"] = str(int(FORCE_DEVICE)) # 先掩蔽
DEVICE = f"cuda:{int(FORCE_DEVICE)}" # 对外展示全局序号
model, PRECISION = load_model("cuda:0") # 进程内使用 0 号
elif args.device is not None:
if args.device.lower() == "cpu":
os.environ["CUDA_VISIBLE_DEVICES"] = ""
DEVICE = "cpu"
model, PRECISION = load_model("cpu")
else:
os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.device)) # 先掩蔽
DEVICE = f"cuda:{int(args.device)}"
model, PRECISION = load_model("cuda:0")
else:
model, PRECISION, DEVICE = auto_select_and_load()
# --- Global tokenizer (required by the token accounting in create_embedding below) ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
logger.info("Using SAFE_MIN_FREE_MB = %d MB, BATCH_SIZE = %d", SAFE_MIN_FREE_MB, BATCH_SIZE)
@app.on_event("startup")
def _warmup():
global _READY
try:
with torch.inference_mode():
try:
model.encode(["warmup sentence"], return_dense=True, batch_size=BATCH_SIZE)
except TypeError:
_ = _encode_chunked(model, ["warmup sentence"], max(1, min(BATCH_SIZE, 8)))
_READY = True
logger.info("Warm-up complete.")
except Exception as e:
_READY = False
logger.warning("Warm-up failed: %s", e)
@app.get("/ready")
def ready():
if not _READY:
        # 503 tells the orchestrator / HEALTHCHECK to keep waiting
raise HTTPException(status_code=503, detail={"code": "not_ready", "message": "warming up"})
return {"status": "ready"}
# ---------------- GPU self-healing: reload the model after repeated failures ---------------- #
def _maybe_reload_gpu_model():
"""
    Attempt a reload once consecutive GPU failures reach the threshold; a failed reload does not affect request handling (the CPU fallback keeps serving).
"""
global _GPU_FAIL_COUNT, model, PRECISION, DEVICE
if not str(DEVICE).startswith("cuda"):
return
if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
return
with _reload_lock:
        # Double-check to avoid concurrent duplicate reloads
if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
return
try:
logger.warning("GPU unhealthy (fail_count=%d) → reloading on %s", _GPU_FAIL_COUNT, DEVICE)
            model, PRECISION = load_model("cuda:0")  # the card is already masked, so in-process it is cuda:0
_GPU_FAIL_COUNT = 0
logger.info("GPU model reloaded on %s", DEVICE)
except Exception as e:
logger.exception("GPU reload failed: %s", e)
            # Keep serving via the CPU fallback path
def _encode_chunked(model_obj, texts: List[str], chunk: int):
"""不依赖 batch_size 形参的手动分块编码;返回结构与 BGEM3 一致。"""
chunk = max(1, int(chunk)) # guard
dense_all = []
for i in range(0, len(texts), chunk):
part = texts[i:i+chunk]
        out = model_obj.encode(part, return_dense=True)  # no batch_size argument
dense_all.extend(out["dense_vecs"])
return {"dense_vecs": dense_all}
def _encode(texts: List[str]):
"""
    Return (output_dict, served_device).
    Normally runs on the GPU; on OOM / CUDA errors it automatically falls back to CPU and bumps the self-healing failure counter.
"""
served = DEVICE
try:
try:
return model.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
except TypeError:
            # encode() does not accept a batch_size parameter → chunk manually
return _encode_chunked(model, texts, BATCH_SIZE), served
except RuntimeError as e:
emsg = str(e).lower()
if "out of memory" in emsg or "cuda error" in emsg:
logger.warning("GPU OOM/CUDA error → fallback to CPU: %s", str(e))
torch.cuda.empty_cache()
global CPU_MODEL_CACHE, _GPU_FAIL_COUNT
if CPU_MODEL_CACHE is None:
CPU_MODEL_CACHE, _ = load_model("cpu")
            _GPU_FAIL_COUNT += 1  # record the consecutive failure
            _maybe_reload_gpu_model()  # attempt self-healing once the threshold is reached
            served = "cpu"
            # On CPU, also try batch_size first, then fall back to chunking
try:
return CPU_MODEL_CACHE.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
except TypeError:
return _encode_chunked(CPU_MODEL_CACHE, texts, BATCH_SIZE), served
raise
# =============== Input normalization (reject stack-blowing inputs at the entry point) =============== #
def _normalize_inputs(raw) -> List[str]:
"""
    Behaviour constraints:
    - string: keep the original behaviour and wrap it as [str] ("" is allowed)
    - list  : must be list[str]; no strip, no dropping of blank items (original behaviour preserved)
    - empty list: 400 straight away (previously the HF tokenizer would raise IndexError)
    - non-string element: 400
    - oversized batch / text: 413 (defensive limit)
"""
if isinstance(raw, str):
texts = [raw]
elif isinstance(raw, list):
if len(raw) == 0:
_http_error(400, "empty_list", "`input` is an empty list.")
for i, x in enumerate(raw):
if not isinstance(x, str):
_http_error(400, "non_string_item", f"`input[{i}]` must be a string.")
texts = raw
else:
_http_error(400, "invalid_type", "`input` must be a string or list of strings.")
if len(texts) > MAX_BATCH:
_http_error(413, "batch_too_large", f"batch size {len(texts)} > MAX_BATCH={MAX_BATCH}.")
for i, t in enumerate(texts):
if t is not None and len(t) > MAX_TEXT_LEN:
_http_error(413, "text_too_long", f"`input[{i}]` exceeds MAX_TEXT_LEN={MAX_TEXT_LEN}.")
return texts
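# Illustrative outcomes of _normalize_inputs (derived from the rules above):
#   _normalize_inputs("hi")       -> ["hi"]
#   _normalize_inputs(["a", ""])  -> ["a", ""]       ("" stays allowed)
#   _normalize_inputs([])         -> HTTP 400, code "empty_list"
#   _normalize_inputs(["a", 1])   -> HTTP 400, code "non_string_item"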
# ------------------------------- Main route --------------------------------------#
@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # 1) Normalize inputs (reject stack-blowing inputs at the entry point; "" stays allowed as before)
texts = _normalize_inputs(request.input)
    # 2) Token accounting: a failure here must not break the main flow (count as 0)
try:
enc = tokenizer(
texts,
padding=True,
truncation=True,
max_length=8192,
return_tensors="pt",
)
prompt_tokens = int(enc["attention_mask"].sum().item())
except IndexError:
        # Belt and braces (should no longer happen after _normalize_inputs)
_http_error(400, "empty_after_tokenize", "`input` contains no encodable texts.")
except Exception as e:
logger.warning("token counting failed: %s", e)
prompt_tokens = 0
    # 3) Encode (GPU with CPU fallback; self-healing counter and reload)
try:
output, served_device = _encode(texts)
embeddings = output["dense_vecs"]
except HTTPException:
raise
except Exception as e:
logger.exception("Embedding failed")
_http_error(500, "embedding_failed", str(e))
    # 4) Return the result as-is (structure and fields unchanged)
return {
"object": "list",
"data": [
{
"object": "embedding",
"index": i,
"embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
}
for i, emb in enumerate(embeddings)
],
"model": request.model,
"usage": {
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens,
},
"device": served_device, # 返回真实服务设备
"precision": "fp32" if served_device == "cpu" else PRECISION,
}
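# Illustrative client call (a sketch only; assumes the server listens on
# localhost:8000 as configured in the entry point below and that `requests`
# is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/embeddings",
#       json={"input": ["hello world"]},
#       timeout=60,
#   )
#   vec = resp.json()["data"][0]["embedding"]   # list of floats for the first text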
# --------------------------- Health check (optional) --------------------------------#
@app.get("/health")
def healthz():
return {
"status": "ok",
"device": DEVICE,
"precision": PRECISION,
"gpu_fail_count": _GPU_FAIL_COUNT,
"batch_size": BATCH_SIZE,
"ready": _READY,
}
# -----------------------------------------------------------------------------#
# Entry-point for `python main.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"server:app",
host="0.0.0.0",
port=int(os.getenv("PORT", 8000)),
log_level="info",
workers=1, # multi-process → use gunicorn/uvicorn externally
)
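# For multi-process serving (an illustrative command, assuming gunicorn and the
# uvicorn worker class are installed; adjust the module path to wherever this
# file actually lives):
#   gunicorn -w 2 -k uvicorn.workers.UvicornWorker main:app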