embed-bge-m3/app/main.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.
Launch examples:
python server.py               # auto-select GPU / automatic fallback
python server.py --device 1    # pin to GPU index 1
CUDA_VISIBLE_DEVICES=0,1 python server.py
"""
import argparse
import logging
import os
from typing import List, Union
from threading import Lock
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # change this to your local weights path as needed
MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000")) # bge-m3-large fp32 ≈ 8 GiB
POST_LOAD_GAP_MB = 192
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB # == 8192 MB
# Upper bounds on request batch size and per-text length (guards against abnormally large payloads)
MAX_BATCH = int(os.getenv("MAX_BATCH", "1024"))
MAX_TEXT_LEN = int(os.getenv("MAX_TEXT_LEN", "200000"))
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "32")))
# Number of consecutive GPU failures before attempting a reload (self-healing)
GPU_MAX_CONSEC_FAILS = int(os.getenv("GPU_MAX_CONSEC_FAILS", "3"))
# Cache for the CPU fallback model
CPU_MODEL_CACHE = None
# Consecutive GPU failure counter; only used to trigger a reload, never touches the normal path
_GPU_FAIL_COUNT = 0
_reload_lock = Lock()
_READY = False
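# The limits above can be tuned per deployment through environment variables; the values
# below are purely illustrative, not recommendations:
#   MODEL_VRAM_MB=16000 BATCH_SIZE=64 MAX_BATCH=256 python server.py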
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-bge-m3"
# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
def _http_error(status: int, code: str, message: str):
"""统一错误体:便于日志/客户端识别。"""
raise HTTPException(status_code=status, detail={"code": code, "message": message})
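# For reference, an error raised through _http_error() reaches the client wrapped in
# FastAPI's standard HTTPException envelope, roughly:
#   {"detail": {"code": "empty_list", "message": "`input` is an empty list."}}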
# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
import pynvml
pynvml.nvmlInit()
_USE_NVML = True
except Exception:
_USE_NVML = False
logger.warning("pynvml 不可用,将使用 torch.cuda.mem_get_info() 探测显存")
def _gpu_mem_info(idx: int) -> tuple[int, int]:
"""Return (free_bytes, total_bytes) for GPU `idx`(全局序号)."""
if _USE_NVML:
handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
return mem.free, mem.total
    # torch fallback: note that if CUDA_VISIBLE_DEVICES already masks devices, only in-process indices are reachable
torch.cuda.set_device(idx)
free, total = torch.cuda.mem_get_info(idx)
return free, total
# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision_by_idx(idx: int) -> str:
"""Return 'fp16' or 'fp32' by compute capability."""
try:
        major, _ = torch.cuda.get_device_capability(idx)  # idx here is the global device index
        if major >= 7:  # Volta / Turing / Ampere / Hopper all support fp16 well
            return "fp16"
        return "fp32"
except Exception:
return "fp16" if torch.cuda.is_available() else "fp32"
def load_model(device: str):
"""
device: "cpu""cuda:{global_idx}"
屏蔽后在本进程内统一使用 "cuda:0" 加载,避免 invalid device ordinal。
"""
if device == "cpu":
os.environ["CUDA_VISIBLE_DEVICES"] = ""
precision = "fp32"
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=False, device="cpu")
logger.info("Loading BGEM3 on cpu (fp32)")
return mdl, precision
    # Parse the global index and decide the precision before masking
idx = int(device.split(":")[1])
precision = _choose_precision_by_idx(idx)
use_fp16 = (precision == "fp16")
    # Expose only this card; inside the process it maps to cuda:0
os.environ["CUDA_VISIBLE_DEVICES"] = str(idx)
mapped = "cuda:0"
logger.info("Loading BGEM3 on %s (mapped=%s, %s)", device, mapped, precision)
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=mapped)
return mdl, precision
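# Example (illustrative): load_model("cuda:1") picks the precision from GPU 1's compute
# capability, sets CUDA_VISIBLE_DEVICES=1 so only that card is visible, and then loads
# the weights as "cuda:0" from the process's point of view.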
# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load() -> tuple:
"""
1. 过滤掉空闲显存 < MODEL_VRAM_MB 的 GPU
2. 按空闲显存降序依次尝试加载
3. 载入后再次检查:若剩余 < POST_LOAD_GAP_MB → 视为失败
4. 若全部 GPU 不满足 → CPU
"""
if not torch.cuda.is_available():
logger.info("No GPU detected → CPU")
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
    # Collect candidate GPUs as (free_MB, idx)
candidates = []
for idx in range(torch.cuda.device_count()):
free_mb = _gpu_mem_info(idx)[0] // 2**20
if free_mb >= MODEL_VRAM_MB:
candidates.append((free_mb, idx))
if not candidates:
logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
    # Try candidates from most to least free VRAM
for free_mb, idx in sorted(candidates, reverse=True):
dev = f"cuda:{idx}"
try:
logger.info("Trying %s (free=%d MB)", dev, free_mb)
mdl, prec = load_model(dev)
            # Post-load headroom check: with NVML use the global idx; without NVML, use in-process index 0
if _USE_NVML:
remain_mb = _gpu_mem_info(idx)[0] // 2**20
else:
remain_mb = _gpu_mem_info(0)[0] // 2**20
if remain_mb < POST_LOAD_GAP_MB:
raise RuntimeError(f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
return mdl, prec, dev
except RuntimeError as e:
logger.warning("%s unusable (%s) → next", dev, e)
torch.cuda.empty_cache()
logger.warning("No suitable GPU left → CPU fallback")
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
FORCE_DEVICE = os.getenv("FORCE_DEVICE")  # "cpu" or "0" / "1" / ...
parser = argparse.ArgumentParser()
parser.add_argument("--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection")
args, _ = parser.parse_known_args()
if FORCE_DEVICE is not None:
if FORCE_DEVICE.lower() == "cpu":
DEVICE = "cpu"
else:
DEVICE = f"cuda:{int(FORCE_DEVICE)}" if torch.cuda.is_available() else "cpu"
model, PRECISION = load_model(DEVICE)
elif args.device is not None:
if args.device.lower() == "cpu":
DEVICE = "cpu"
else:
DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
model, PRECISION = load_model(DEVICE)
else:
model, PRECISION, DEVICE = auto_select_and_load()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
logger.info("Using SAFE_MIN_FREE_MB = %d MB, BATCH_SIZE = %d", SAFE_MIN_FREE_MB, BATCH_SIZE)
@app.on_event("startup")
def _warmup():
global _READY
try:
        # Try warming up with batch_size; fall back if it is not supported
try:
model.encode(["warmup sentence"], return_dense=True, batch_size=BATCH_SIZE)
except TypeError:
_ = _encode_chunked(model, ["warmup sentence"], max(1, min(BATCH_SIZE, 8)))
_READY = True
logger.info("Warm-up complete.")
except Exception as e:
_READY = False
logger.warning("Warm-up failed: %s", e)
@app.get("/ready")
def ready():
if not _READY:
        # 503 keeps the orchestrator / HEALTHCHECK waiting
raise HTTPException(status_code=503, detail={"code": "not_ready", "message": "warming up"})
return {"status": "ready"}
# -------------- GPU self-healing: reload the model after consecutive failures -------------- #
def _maybe_reload_gpu_model():
"""
当 GPU 连挂达到阈值时尝试重载;失败不影响请求流程(继续用 CPU 兜底)。
"""
global _GPU_FAIL_COUNT, model, PRECISION, DEVICE
if not str(DEVICE).startswith("cuda"):
return
if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
return
with _reload_lock:
        # Double-check to avoid concurrent duplicate reloads
if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
return
try:
logger.warning("GPU unhealthy (fail_count=%d) → reloading on %s", _GPU_FAIL_COUNT, DEVICE)
model, PRECISION = load_model(DEVICE)
_GPU_FAIL_COUNT = 0
logger.info("GPU model reloaded on %s", DEVICE)
except Exception as e:
logger.exception("GPU reload failed: %s", e)
            # Keep the CPU fallback path working
def _encode_chunked(model_obj, texts: List[str], chunk: int):
"""不依赖 batch_size 形参的手动分块编码;返回结构与 BGEM3 一致。"""
chunk = max(1, int(chunk)) # guard
dense_all = []
for i in range(0, len(texts), chunk):
part = texts[i:i+chunk]
        out = model_obj.encode(part, return_dense=True)  # no batch_size kwarg
dense_all.extend(out["dense_vecs"])
return {"dense_vecs": dense_all}
def _encode(texts: List[str]):
"""
返回 (output_dict, served_device)
正常走 GPU若 OOM/CUDA 错误,自动切 CPU 兜底并触发自愈计数。
"""
served = DEVICE
try:
try:
return model.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
except TypeError:
            # encode() does not accept a batch_size kwarg → chunk manually
return _encode_chunked(model, texts, BATCH_SIZE), served
except RuntimeError as e:
emsg = str(e).lower()
if "out of memory" in emsg or "cuda error" in emsg:
logger.warning("GPU OOM/CUDA error → fallback to CPU: %s", str(e))
torch.cuda.empty_cache()
global CPU_MODEL_CACHE, _GPU_FAIL_COUNT
if CPU_MODEL_CACHE is None:
CPU_MODEL_CACHE, _ = load_model("cpu")
            _GPU_FAIL_COUNT += 1         # count the consecutive failure
            _maybe_reload_gpu_model()    # attempt self-healing once the threshold is reached
            served = "cpu"
            # On CPU, also try batch_size first, then fall back to chunking
try:
return CPU_MODEL_CACHE.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
except TypeError:
return _encode_chunked(CPU_MODEL_CACHE, texts, BATCH_SIZE), served
raise
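# Fallback flow of _encode(), summarized: try the GPU model (batch_size first, then chunked);
# on OOM / CUDA errors, load and cache a CPU copy, bump _GPU_FAIL_COUNT, possibly trigger a
# reload, and answer the current request from the CPU model. Other RuntimeErrors re-raise.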
# =============== Input normalization (stop crash-inducing inputs at the entry point) =============== #
def _normalize_inputs(raw) -> List[str]:
"""
行为约束:
- 字符串:保持原行为,转为 [str](允许 ""
- 列表 :要求 list[str];不 strip、不丢弃空白保持原行为
- 空列表:直接 400以前会让 HF tokenizer 抛 IndexError
- 非字符串元素400
- 批量/文本超大413防御性限制
"""
if isinstance(raw, str):
texts = [raw]
elif isinstance(raw, list):
if len(raw) == 0:
_http_error(400, "empty_list", "`input` is an empty list.")
for i, x in enumerate(raw):
if not isinstance(x, str):
_http_error(400, "non_string_item", f"`input[{i}]` must be a string.")
texts = raw
else:
_http_error(400, "invalid_type", "`input` must be a string or list of strings.")
if len(texts) > MAX_BATCH:
_http_error(413, "batch_too_large", f"batch size {len(texts)} > MAX_BATCH={MAX_BATCH}.")
for i, t in enumerate(texts):
if t is not None and len(t) > MAX_TEXT_LEN:
_http_error(413, "text_too_long", f"`input[{i}]` exceeds MAX_TEXT_LEN={MAX_TEXT_LEN}.")
return texts
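# Examples of what _normalize_inputs() accepts and rejects (illustrative payloads):
#   "hello"                   -> ["hello"]           (single string is wrapped)
#   ["a", "b"]                -> ["a", "b"]
#   []                        -> 400 empty_list
#   ["a", 1]                  -> 400 non_string_item
#   ["x" * (MAX_TEXT_LEN+1)]  -> 413 text_too_long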
# ------------------------------- Main route ----------------------------------#
@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # 1) Normalize inputs (stop crash-inducing payloads at the entry point; "" is still allowed, as before)
texts = _normalize_inputs(request.input)
    # 2) Token counting: failures do not affect the main flow (counted as 0)
try:
enc = tokenizer(
texts,
padding=True,
truncation=True,
max_length=8192,
return_tensors="pt",
)
prompt_tokens = int(enc["attention_mask"].sum().item())
except IndexError:
        # Belt and braces (should not happen anymore)
_http_error(400, "empty_after_tokenize", "`input` contains no encodable texts.")
except Exception as e:
logger.warning("token counting failed: %s", e)
prompt_tokens = 0
    # 3) Encode (GPU with CPU fallback; self-healing counter and reload)
try:
output, served_device = _encode(texts)
embeddings = output["dense_vecs"]
except HTTPException:
raise
except Exception as e:
logger.exception("Embedding failed")
_http_error(500, "embedding_failed", str(e))
    # 4) Return unchanged (structure/fields stay the same)
return {
"object": "list",
"data": [
{
"object": "embedding",
"index": i,
"embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
}
for i, emb in enumerate(embeddings)
],
"model": request.model,
"usage": {
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens,
},
"device": served_device, # 返回真实服务设备
"precision": "fp32" if served_device == "cpu" else PRECISION,
}
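# Example request against this route (illustrative; adjust host/port to your deployment):
#   curl -s http://localhost:8000/v1/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"input": ["hello world"], "model": "text-embedding-bge-m3"}'
# The response follows the OpenAI embeddings shape, plus the extra "device" and
# "precision" fields returned above.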
# --------------------------- Health check (optional) -------------------------#
@app.get("/health")
def healthz():
return {
"status": "ok",
"device": DEVICE,
"precision": PRECISION,
"gpu_fail_count": _GPU_FAIL_COUNT,
"batch_size": BATCH_SIZE,
"ready": _READY,
}
# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"server:app",
host="0.0.0.0",
port=int(os.getenv("PORT", 8000)),
log_level="info",
workers=1, # multi-process → use gunicorn/uvicorn externally
)
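# For multi-process serving, run the app through an external process manager instead of the
# single-worker call above. A sketch, assuming gunicorn is installed and the module name
# matches the "server:app" string used above:
#   gunicorn server:app -k uvicorn.workers.UvicornWorker -w 2 -b 0.0.0.0:8000
# Note that every worker loads its own copy of the model, so budget VRAM accordingly.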