#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.
Launch examples:
python server.py # 自动选卡 / 自动降级
python server.py --device 1 # 固定用第 1 张 GPU
CUDA_VISIBLE_DEVICES=0,1 python server.py
"""
import argparse
import logging
import os
import sys
import time
from typing import List, Union
import multiprocessing as mp
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
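
# CUDA does not play well with fork(); force the "spawn" start method so any
# worker subprocess starts with a clean CUDA state.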
mp.set_start_method("spawn", force=True)
# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # change to your local weight path if needed
MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000")) # bge-m3-large fp32 ≈ 8 GiB
POST_LOAD_GAP_MB = 192
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB # == 8192 MB
# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
    import pynvml

    pynvml.nvmlInit()
    _USE_NVML = True
except Exception:
    _USE_NVML = False
    logger.warning("pynvml unavailable; probing VRAM via torch.cuda.mem_get_info() instead")

def _gpu_mem_info(idx: int) -> tuple[int, int]:
    """Return (free_bytes, total_bytes) for GPU `idx`."""
    if _USE_NVML:
        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem.free, mem.total
    # torch fallback
    torch.cuda.set_device(idx)
    free, total = torch.cuda.mem_get_info(idx)
    return free, total
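
# Illustrative use of _gpu_mem_info (values are hypothetical):
#   free_b, total_b = _gpu_mem_info(0)
#   logger.info("GPU0: %d MiB free / %d MiB total", free_b // 2**20, total_b // 2**20)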
# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#

def _choose_precision(device: str) -> str:
    """Return 'fp16' or 'fp32' (BGEM3FlagModel only exposes a use_fp16 switch)."""
    if not device.startswith("cuda"):
        return "fp32"
    major, _ = torch.cuda.get_device_capability(device)
    if major >= 7:  # Volta / Turing / Ampere / Hopper: fp16 tensor cores
        return "fp16"
    return "fp32"

def load_model(device: str):
    precision = _choose_precision(device)
    use_fp16 = precision == "fp16"
    logger.info("Loading BGEM3 on %s (%s)", device, precision)
    if device == "cpu":
        # CPU path: hide CUDA from this process entirely.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        torch.cuda.is_available = lambda: False
        torch.cuda.device_count = lambda: 0
    else:
        # GPU path: expose only the card this worker selected.
        # e.g. device == "cuda:3" -> only GPU 3 stays visible.
        # (CUDA_VISIBLE_DEVICES is read at CUDA init, so once torch has touched
        # CUDA this mainly affects subprocesses spawned after this point.)
        idx = device.split(":")[1]
        os.environ["CUDA_VISIBLE_DEVICES"] = idx
        # device_count now reports 1, so BGEM3 builds only one worker on this card.
        torch.cuda.device_count = lambda: 1
    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
    # No DataParallel wrapper: each worker runs on a single card.
    return mdl, precision
# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#

def auto_select_and_load() -> tuple:
    """
    1. Skip GPUs whose free VRAM is < MODEL_VRAM_MB.
    2. Try the remaining GPUs in descending order of free VRAM.
    3. Re-check after loading: if the remaining free VRAM < POST_LOAD_GAP_MB, treat it as a failure.
    4. If no GPU qualifies, fall back to CPU.
    """
    if not torch.cuda.is_available():
        logger.info("No GPU detected → CPU")
        return (*load_model("cpu"), "cpu")
    # Collect candidate cards as (free_MB, idx)
    candidates = []
    for idx in range(torch.cuda.device_count()):
        free_mb = _gpu_mem_info(idx)[0] // 2**20
        if free_mb >= MODEL_VRAM_MB:  # must at least fit the weights
            candidates.append((free_mb, idx))
    if not candidates:
        logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
        return (*load_model("cpu"), "cpu")
    # Try GPUs from most to least free VRAM
    for free_mb, idx in sorted(candidates, reverse=True):
        dev = f"cuda:{idx}"
        try:
            logger.info("Trying %s (free=%d MB)", dev, free_mb)
            mdl, prec = load_model(dev)
            remain_mb = _gpu_mem_info(idx)[0] // 2**20
            if remain_mb < POST_LOAD_GAP_MB:
                raise RuntimeError(
                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
            return mdl, prec, dev  # success
        except RuntimeError as e:
            logger.warning("%s unusable (%s) → next", dev, e)
            torch.cuda.empty_cache()
    logger.warning("No suitable GPU left → CPU fallback")
    return (*load_model("cpu"), "cpu")
# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
)
args, _ = parser.parse_known_args()

if args.device is not None:
    # Forced path
    if args.device.lower() == "cpu":
        DEVICE = "cpu"
    else:
        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
    model, PRECISION = load_model(DEVICE)
else:
    # Auto path with VRAM check
    model, PRECISION, DEVICE = auto_select_and_load()

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
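# Note: SAFE_MIN_FREE_MB is only logged for reference; the load-time checks in
# auto_select_and_load() compare MODEL_VRAM_MB and POST_LOAD_GAP_MB directly.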
logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)

def _warm_worker(t, q):
    try:
        _ = model.encode(t, return_dense=True)
        q.put("ok")
    except Exception as e:
        q.put(str(e))
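# _warm_worker is currently unused: the warm_up() hook below calls model.encode()
# in-process instead of going through a spawned subprocess.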

# ② -------- FastAPI startup warm-up --------
@app.on_event("startup")
def warm_up():
    logger.info("Warm-up on %s", DEVICE)
    try:
        _ = model.encode([
            "This is a warmup sentence used to initialize CUDA kernels and avoid latency spikes."
        ], return_dense=True)
        logger.info("Warm-up complete.")
    except Exception as e:
        logger.warning("Warm-up failed: %s", e)

class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"

# ③ -------- worker call for _encode() --------
def _worker(t, q):
    try:
        # out = model.encode(t, return_dense=True)  # safe on GPU or CPU
        out = model.encode(t, return_dense=True)
        q.put(("ok", out))
    except Exception as e:
        q.put(("err", str(e)))

CPU_MODEL_CACHE = None  # lazily-loaded CPU model, populated only after a GPU failure


def _encode(texts: List[str]):
    try:
        return model.encode(texts, return_dense=True)
    except RuntimeError as e:
        if "out of memory" in str(e).lower() or "cuda error" in str(e).lower():
            logger.warning("GPU OOM → fallback to CPU: %s", str(e))
            torch.cuda.empty_cache()
            global CPU_MODEL_CACHE
            if CPU_MODEL_CACHE is None:
                CPU_MODEL_CACHE, _ = load_model("cpu")
            return CPU_MODEL_CACHE.encode(texts, return_dense=True)
        raise

@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    texts = [request.input] if isinstance(request.input, str) else request.input
    # Token stats
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    )
    prompt_tokens = int(enc["attention_mask"].sum().item())
    try:
        output = _encode(texts)
        embeddings = output["dense_vecs"]
    except Exception as e:
        logger.exception("Embedding failed")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
            }
            for i, emb in enumerate(embeddings)
        ],
        "model": request.model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,
        },
        "device": DEVICE,
        "precision": PRECISION,
    }
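
# Example request (assuming the server is reachable on localhost:8000):
#   curl -s http://localhost:8000/v1/embeddings \
#        -H "Content-Type: application/json" \
#        -d '{"input": ["hello world"], "model": "text-embedding-bge-m3"}'
# The response mirrors the OpenAI embeddings schema, plus "device" and "precision" fields.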
# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        log_level="info",
        workers=1,  # multi-process → use gunicorn/uvicorn externally
    )
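
# For multi-process serving (one model copy per worker), a typical external launch
# might be, assuming this module is importable as server.py:
#   gunicorn server:app -k uvicorn.workers.UvicornWorker -w 2 -b 0.0.0.0:8000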