This commit is contained in:
parent 2dd0928d6e
commit 0500e81f1c

Dockerfile (20 lines changed)
@@ -1,18 +1,21 @@
 FROM python:3.10-slim

 # Install system dependencies
-RUN apt-get update && apt-get install -y gcc libglib2.0-0 && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    gcc libglib2.0-0 curl && rm -rf /var/lib/apt/lists/*

 # Set the working directory
 WORKDIR /app

 # Install Python dependencies
 COPY requirements.txt .
-RUN pip install --upgrade pip && pip install -r requirements.txt
+RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt

 # Install the local FlagEmbedding source tree
 COPY FlagEmbedding /opt/FlagEmbedding
-RUN pip install --no-deps --upgrade -e /opt/FlagEmbedding
+RUN pip install --no-cache-dir --no-deps /opt/FlagEmbedding

 # Copy application code and model weights
 COPY app /app/app
@@ -26,9 +29,16 @@ EXPOSE 8001

 # Configure the PyTorch CUDA allocator to split large blocks, reducing fragmentation (optional but recommended)
 ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32
+ENV TOKENIZERS_PARALLELISM=false
+ENV HF_HUB_DISABLE_TELEMETRY=1 TRANSFORMERS_NO_ADVISORY_WARNINGS=1

+# Health check: leave time for startup and warm-up (tune start-period to your model size)
+HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
+    CMD curl -fsS http://127.0.0.1:8001/ready >/dev/null || exit 1

 # Start: Gunicorn with 1 worker; each worker is an independent process
 CMD ["gunicorn","app.main:app", \
      "-k","uvicorn.workers.UvicornWorker", \
-     "-w", "1", \
-     "-b", "0.0.0.0:8001"]
+     "-w","1","-b","0.0.0.0:8001", \
+     "--timeout","120","--graceful-timeout","30", \
+     "--max-requests","1000","--max-requests-jitter","200"]
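Note: the new HEALTHCHECK simply polls the /ready endpoint added in app/main.py and treats anything other than HTTP 200 as unhealthy. As a rough illustration only (port 8001 and the /ready path come from this Dockerfile; the script itself is an assumption, not part of the commit), an equivalent probe in Python:

# Hypothetical stand-in for the curl-based HEALTHCHECK: succeed only when
# GET /ready returns HTTP 200, i.e. after the startup warm-up has finished.
import sys
import urllib.error
import urllib.request

URL = "http://127.0.0.1:8001/ready"  # host/port/path taken from the Dockerfile

def probe(url: str = URL, timeout: float = 5.0) -> bool:
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200          # warm-up finished
    except (urllib.error.HTTPError, urllib.error.URLError):
        return False                           # 503 while warming up, or unreachable

if __name__ == "__main__":
    sys.exit(0 if probe() else 1)              # mirrors `curl -fsS ... || exit 1`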

app/main.py (287 lines changed)
@@ -12,10 +12,8 @@ Launch examples:
 import argparse
 import logging
 import os
-import sys
-import time
 from typing import List, Union
-import multiprocessing as mp
+from threading import Lock

 import torch
 from fastapi import FastAPI, HTTPException
@@ -23,8 +21,6 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer
 from FlagEmbedding import BGEM3FlagModel

-mp.set_start_method("spawn", force=True)
-
 # -----------------------------------------------------------------------------#
 # Config
 # -----------------------------------------------------------------------------#
@@ -33,6 +29,27 @@ MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000"))  # bge-m3-large fp32
 POST_LOAD_GAP_MB = 192
 SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # == 8192 MB

+# Upper bounds on request batch size and per-item length (defends against oversized payloads)
+MAX_BATCH = int(os.getenv("MAX_BATCH", "1024"))
+MAX_TEXT_LEN = int(os.getenv("MAX_TEXT_LEN", "200000"))
+BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "32")))
+
+# Number of consecutive GPU failures before attempting a reload (self-healing)
+GPU_MAX_CONSEC_FAILS = int(os.getenv("GPU_MAX_CONSEC_FAILS", "3"))
+
+# Cache for the CPU fallback model
+CPU_MODEL_CACHE = None
+
+# Counter of consecutive GPU failures; only used to trigger a reload, never affects the normal path
+_GPU_FAIL_COUNT = 0
+_reload_lock = Lock()
+
+_READY = False
+
+class EmbeddingRequest(BaseModel):
+    input: Union[str, List[str]]
+    model: str = "text-embedding-bge-m3"
+
 # -----------------------------------------------------------------------------#
 # Logging
 # -----------------------------------------------------------------------------#
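Note: all of the new limits are plain environment-variable reads evaluated at import time, so they can be tuned per container without code changes. A small sketch with hypothetical overrides (the variable names match the diff; the values are made up):

import os

# Hypothetical overrides; in the container these would come from ENV lines or `docker run -e`.
os.environ["MAX_BATCH"] = "256"
os.environ["BATCH_SIZE"] = "0"          # deliberately invalid to show the clamp

MAX_BATCH = int(os.getenv("MAX_BATCH", "1024"))
MAX_TEXT_LEN = int(os.getenv("MAX_TEXT_LEN", "200000"))
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "32")))   # clamped to at least 1, as in the diff

print(MAX_BATCH, MAX_TEXT_LEN, BATCH_SIZE)  # 256 200000 1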
@@ -42,6 +59,10 @@ logging.basicConfig(
     format="%(asctime)s [%(levelname)s] %(message)s",
 )

+def _http_error(status: int, code: str, message: str):
+    """Uniform error body: easier for logs and clients to identify."""
+    raise HTTPException(status_code=status, detail={"code": code, "message": message})
+
 # -----------------------------------------------------------------------------#
 # GPU memory helpers (NVML → torch fallback)
 # -----------------------------------------------------------------------------#
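Note on the error shape: FastAPI serializes HTTPException.detail verbatim under a top-level "detail" key, so every error raised through _http_error() arrives as {"detail": {"code": ..., "message": ...}}. A standalone sketch, not part of the commit (the demo app and route are made up for illustration):

# Minimal demo of the {"code": ..., "message": ...} error body.
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

demo = FastAPI()

def _http_error(status: int, code: str, message: str):
    raise HTTPException(status_code=status, detail={"code": code, "message": message})

@demo.get("/boom")
def boom():
    _http_error(400, "empty_list", "`input` is an empty list.")

resp = TestClient(demo).get("/boom")
print(resp.status_code)   # 400
print(resp.json())        # {'detail': {'code': 'empty_list', 'message': '`input` is an empty list.'}}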
@@ -54,56 +75,55 @@ except Exception:
     _USE_NVML = False
     logger.warning("pynvml unavailable; using torch.cuda.mem_get_info() to probe GPU memory")


 def _gpu_mem_info(idx: int) -> tuple[int, int]:
-    """Return (free_bytes, total_bytes) for GPU `idx`."""
+    """Return (free_bytes, total_bytes) for GPU `idx` (global index)."""
     if _USE_NVML:
         handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
         mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
         return mem.free, mem.total
-    # torch fallback
+    # torch fallback (note: once the process is masked, only in-process indices are reachable)
     torch.cuda.set_device(idx)
     free, total = torch.cuda.mem_get_info(idx)
     return free, total


 # -----------------------------------------------------------------------------#
 # Precision & model loader
 # -----------------------------------------------------------------------------#
-def _choose_precision(device: str) -> str:
-    """Return 'fp16'|'bf16'|'fp32'."""
-    if not device.startswith("cuda"):
-        return "fp32"
-
-    major, _ = torch.cuda.get_device_capability(device)
-    if major >= 8:  # Ampere/Hopper
-        return "fp16"
-    if major >= 7:
-        return "bf16"
-    return "fp32"
+def _choose_precision_by_idx(idx: int) -> str:
+    """Return 'fp16' or 'fp32' by compute capability."""
+    try:
+        major, _ = torch.cuda.get_device_capability(idx)  # idx is the global index here
+
+        if major >= 8:  # Ampere/Hopper
+            return "fp16"
+        if major >= 7:  # Volta/Turing
+            return "fp16"
+        return "fp32"
+    except Exception:
+        return "fp16" if torch.cuda.is_available() else "fp32"

 def load_model(device: str):
-    precision = _choose_precision(device)
-    use_fp16 = precision == "fp16"
-    logger.info("Loading BGEM3 on %s (%s)", device, precision)
-
+    """
+    device: "cpu" or "cuda:{global_idx}"
+    After masking, this process always loads on "cuda:0" to avoid an invalid device ordinal.
+    """
     if device == "cpu":
-        # CPU path: detach CUDA entirely
         os.environ["CUDA_VISIBLE_DEVICES"] = ""
-        torch.cuda.is_available = lambda: False
-        torch.cuda.device_count = lambda: 0
-    else:
-        # GPU path: expose only the card selected for this worker
-        # e.g. device == "cuda:3" → the current process only sees GPU 3
-        idx = device.split(":")[1]
-        os.environ["CUDA_VISIBLE_DEVICES"] = idx
-        # device_count now returns 1; BGEM3 creates a single subprocess on this card
-        torch.cuda.device_count = lambda: 1
-
-    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
-
-    # No DataParallel wrapper any more; one card per worker is enough
+        precision = "fp32"
+        mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=False, device="cpu")
+        logger.info("Loading BGEM3 on cpu (fp32)")
+        return mdl, precision
+
+    # Resolve the global index and decide precision before masking
+    idx = int(device.split(":")[1])
+    precision = _choose_precision_by_idx(idx)
+    use_fp16 = (precision == "fp16")
+
+    # Expose only this card; inside the process it maps to cuda:0
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(idx)
+    mapped = "cuda:0"
+
+    logger.info("Loading BGEM3 on %s (mapped=%s, %s)", device, mapped, precision)
+    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=mapped)
     return mdl, precision

 # -----------------------------------------------------------------------------#
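Note: the reworked loader decides precision from the global GPU index first, then masks the process down to that single card and addresses it as cuda:0, which is what avoids the "invalid device ordinal" error once BGEM3 spins up its workers. A stripped-down sketch of the same pattern (the function name here is illustrative, and the mask has to be applied before any CUDA context is created):

import os

def mask_to_single_gpu(global_idx: int) -> str:
    """Expose only one physical GPU to this process and return its in-process name.

    Must run before the first CUDA call; afterwards the chosen card is always ordinal 0.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(global_idx)
    return "cuda:0"

if __name__ == "__main__":
    mapped = mask_to_single_gpu(3)     # e.g. physical GPU 3 picked by the auto-selector
    print("model would be loaded on", mapped)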
@@ -118,18 +138,20 @@ def auto_select_and_load() -> tuple:
     """
     if not torch.cuda.is_available():
         logger.info("No GPU detected → CPU")
-        return (*load_model("cpu"), "cpu")
+        mdl, prec = load_model("cpu")
+        return mdl, prec, "cpu"

     # Collect candidate cards as (free_MB, idx)
     candidates = []
     for idx in range(torch.cuda.device_count()):
         free_mb = _gpu_mem_info(idx)[0] // 2**20
-        if free_mb >= MODEL_VRAM_MB:  # must at least fit the weights
+        if free_mb >= MODEL_VRAM_MB:
             candidates.append((free_mb, idx))

     if not candidates:
         logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
-        return (*load_model("cpu"), "cpu")
+        mdl, prec = load_model("cpu")
+        return mdl, prec, "cpu"

     # Highest free VRAM first
     for free_mb, idx in sorted(candidates, reverse=True):
@@ -138,104 +160,180 @@ def auto_select_and_load() -> tuple:
             logger.info("Trying %s (free=%d MB)", dev, free_mb)
             mdl, prec = load_model(dev)

-            remain_mb = _gpu_mem_info(idx)[0] // 2**20
-            if remain_mb < POST_LOAD_GAP_MB:
-                raise RuntimeError(
-                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
-
-            return mdl, prec, dev  # success
+            # Post-load headroom check: with NVML use the global idx; without NVML use in-process index 0
+            if _USE_NVML:
+                remain_mb = _gpu_mem_info(idx)[0] // 2**20
+            else:
+                remain_mb = _gpu_mem_info(0)[0] // 2**20
+
+            if remain_mb < POST_LOAD_GAP_MB:
+                raise RuntimeError(f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+            return mdl, prec, dev
         except RuntimeError as e:
             logger.warning("%s unusable (%s) → next", dev, e)
             torch.cuda.empty_cache()

     logger.warning("No suitable GPU left → CPU fallback")
-    return (*load_model("cpu"), "cpu")
+    mdl, prec = load_model("cpu")
+    return mdl, prec, "cpu"

 # -----------------------------------------------------------------------------#
 # CLI
 # -----------------------------------------------------------------------------#
+FORCE_DEVICE = os.getenv("FORCE_DEVICE")  # "cpu" or "0" / "1" / ...
+
 parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
-)
+parser.add_argument("--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection")
 args, _ = parser.parse_known_args()

-if args.device is not None:
-    # Forced path
+if FORCE_DEVICE is not None:
+    if FORCE_DEVICE.lower() == "cpu":
+        DEVICE = "cpu"
+    else:
+        DEVICE = f"cuda:{int(FORCE_DEVICE)}" if torch.cuda.is_available() else "cpu"
+    model, PRECISION = load_model(DEVICE)
+elif args.device is not None:
     if args.device.lower() == "cpu":
         DEVICE = "cpu"
     else:
         DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
     model, PRECISION = load_model(DEVICE)
 else:
-    # Auto path with VRAM check
     model, PRECISION, DEVICE = auto_select_and_load()

 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)


 # -----------------------------------------------------------------------------#
 # FastAPI
 # -----------------------------------------------------------------------------#
 app = FastAPI()
-logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
+logger.info("Using SAFE_MIN_FREE_MB = %d MB, BATCH_SIZE = %d", SAFE_MIN_FREE_MB, BATCH_SIZE)


-def _warm_worker(t, q):
-    try:
-        _ = model.encode(t, return_dense=True)
-        q.put("ok")
-    except Exception as e:
-        q.put(str(e))
-
-
-# ② -------- FastAPI startup warm-up --------
-@app.on_event("startup")
-def warm_up():
-    logger.info("Warm-up on %s", DEVICE)
-    try:
-        _ = model.encode([
-            "This is a warmup sentence used to initialize CUDA kernels and avoid latency spikes."
-        ], return_dense=True)
-        logger.info("Warm-up complete.")
-    except Exception as e:
-        logger.warning("Warm-up failed: %s", e)
-
-
-class EmbeddingRequest(BaseModel):
-    input: Union[str, List[str]]
-    model: str = "text-embedding-bge-m3"
-
-
-# ③ -------- worker invoked from _encode() --------
-def _worker(t, q):
-    try:
-        # out = model.encode(t, return_dense=True)  # safe on GPU or CPU
-        out = model.encode(t, return_dense=True)
-        q.put(("ok", out))
-    except Exception as e:
-        q.put(("err", str(e)))
-
-
-def _encode(texts: List[str]):
-    try:
-        return model.encode(texts, return_dense=True)
-    except RuntimeError as e:
-        if "out of memory" in str(e).lower() or "cuda error" in str(e).lower():
-            logger.warning("GPU OOM → fallback to CPU: %s", str(e))
-            torch.cuda.empty_cache()
-            global CPU_MODEL_CACHE
-            if CPU_MODEL_CACHE is None:
-                CPU_MODEL_CACHE, _ = load_model("cpu")
-            return CPU_MODEL_CACHE.encode(texts, return_dense=True)
-        raise
-
-
-@app.post("/v1/embeddings")
-def create_embedding(request: EmbeddingRequest):
-    texts = [request.input] if isinstance(request.input, str) else request.input
-
-    # Token stats
-    enc = tokenizer(
-        texts,
-        padding=True,
+@app.on_event("startup")
+def _warmup():
+    global _READY
+    try:
+        # Try warming up with batch_size; fall back if it is not supported
+        try:
+            model.encode(["warmup sentence"], return_dense=True, batch_size=BATCH_SIZE)
+        except TypeError:
+            _ = _encode_chunked(model, ["warmup sentence"], max(1, min(BATCH_SIZE, 8)))
+        _READY = True
+        logger.info("Warm-up complete.")
+    except Exception as e:
+        _READY = False
+        logger.warning("Warm-up failed: %s", e)
+
+
+@app.get("/ready")
+def ready():
+    if not _READY:
+        # 503 keeps the orchestrator / HEALTHCHECK waiting
+        raise HTTPException(status_code=503, detail={"code": "not_ready", "message": "warming up"})
+    return {"status": "ready"}
+
+
+# ---------------------- GPU self-healing: reload the model after consecutive failures ---------------------- #
+def _maybe_reload_gpu_model():
+    """
+    Try a reload once the GPU has failed GPU_MAX_CONSEC_FAILS times in a row;
+    a failed reload does not affect request handling (the CPU fallback keeps serving).
+    """
+    global _GPU_FAIL_COUNT, model, PRECISION, DEVICE
+    if not str(DEVICE).startswith("cuda"):
+        return
+    if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
+        return
+    with _reload_lock:
+        # Double-check to avoid concurrent duplicate reloads
+        if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
+            return
+        try:
+            logger.warning("GPU unhealthy (fail_count=%d) → reloading on %s", _GPU_FAIL_COUNT, DEVICE)
+            model, PRECISION = load_model(DEVICE)
+            _GPU_FAIL_COUNT = 0
+            logger.info("GPU model reloaded on %s", DEVICE)
+        except Exception as e:
+            logger.exception("GPU reload failed: %s", e)
+            # The CPU fallback path keeps working
+
+
+def _encode_chunked(model_obj, texts: List[str], chunk: int):
+    """Manual chunked encoding that does not rely on a batch_size parameter; output matches BGEM3's structure."""
+    chunk = max(1, int(chunk))  # guard
+    dense_all = []
+    for i in range(0, len(texts), chunk):
+        part = texts[i:i+chunk]
+        out = model_obj.encode(part, return_dense=True)  # no batch_size passed
+        dense_all.extend(out["dense_vecs"])
+    return {"dense_vecs": dense_all}
+
+
+def _encode(texts: List[str]):
+    """
+    Return (output_dict, served_device).
+    Normally served on GPU; on OOM / CUDA errors it falls back to CPU and bumps the self-healing counter.
+    """
+    served = DEVICE
+    try:
+        try:
+            return model.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
+        except TypeError:
+            # encode() does not accept a batch_size parameter → chunk manually
+            return _encode_chunked(model, texts, BATCH_SIZE), served
+    except RuntimeError as e:
+        emsg = str(e).lower()
+        if "out of memory" in emsg or "cuda error" in emsg:
+            logger.warning("GPU OOM/CUDA error → fallback to CPU: %s", str(e))
+            torch.cuda.empty_cache()
+            global CPU_MODEL_CACHE, _GPU_FAIL_COUNT
+            if CPU_MODEL_CACHE is None:
+                CPU_MODEL_CACHE, _ = load_model("cpu")
+            _GPU_FAIL_COUNT += 1          # record the consecutive failure
+            _maybe_reload_gpu_model()     # attempt self-healing once the threshold is reached
+            served = "cpu"
+            # On CPU, also try batch_size first, then fall back to chunking
+            try:
+                return CPU_MODEL_CACHE.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
+            except TypeError:
+                return _encode_chunked(CPU_MODEL_CACHE, texts, BATCH_SIZE), served
+        raise
+
+
+# =============== Input normalization (reject stack-blowing inputs at the door) =============== #
+def _normalize_inputs(raw) -> List[str]:
+    """
+    Behavioural constraints:
+    - string     : keep the old behaviour and wrap as [str] ("" is still allowed)
+    - list       : must be list[str]; no strip, no dropping of blanks (old behaviour preserved)
+    - empty list : 400 immediately (previously the HF tokenizer raised IndexError)
+    - non-string item: 400
+    - oversized batch/text: 413 (defensive limit)
+    """
+    if isinstance(raw, str):
+        texts = [raw]
+    elif isinstance(raw, list):
+        if len(raw) == 0:
+            _http_error(400, "empty_list", "`input` is an empty list.")
+        for i, x in enumerate(raw):
+            if not isinstance(x, str):
+                _http_error(400, "non_string_item", f"`input[{i}]` must be a string.")
+        texts = raw
+    else:
+        _http_error(400, "invalid_type", "`input` must be a string or list of strings.")

+    if len(texts) > MAX_BATCH:
+        _http_error(413, "batch_too_large", f"batch size {len(texts)} > MAX_BATCH={MAX_BATCH}.")
+    for i, t in enumerate(texts):
+        if t is not None and len(t) > MAX_TEXT_LEN:
+            _http_error(413, "text_too_long", f"`input[{i}]` exceeds MAX_TEXT_LEN={MAX_TEXT_LEN}.")
+    return texts
+
+
+# ------------------------------- Main route --------------------------------------#
+@app.post("/v1/embeddings")
+def create_embedding(request: EmbeddingRequest):
+    # 1) Normalize inputs (reject stack-blowing inputs at the door; "" stays allowed as before)
+    texts = _normalize_inputs(request.input)
+
+    # 2) Token stats: failures do not affect the main flow (count as 0)
+    try:
+        enc = tokenizer(
+            texts,
+            padding=True,
@@ -244,14 +342,24 @@ def create_embedding(request: EmbeddingRequest):
             return_tensors="pt",
         )
         prompt_tokens = int(enc["attention_mask"].sum().item())
+    except IndexError:
+        # Belt and braces (in theory this can no longer happen)
+        _http_error(400, "empty_after_tokenize", "`input` contains no encodable texts.")
+    except Exception as e:
+        logger.warning("token counting failed: %s", e)
+        prompt_tokens = 0

+    # 3) Encode (GPU with CPU fallback; self-healing counter and reload)
     try:
-        output = _encode(texts)
+        output, served_device = _encode(texts)
         embeddings = output["dense_vecs"]
+    except HTTPException:
+        raise
     except Exception as e:
         logger.exception("Embedding failed")
-        raise HTTPException(status_code=500, detail=str(e))
+        _http_error(500, "embedding_failed", str(e))

+    # 4) Return as before (structure and fields unchanged)
     return {
         "object": "list",
         "data": [
@@ -267,10 +375,21 @@ def create_embedding(request: EmbeddingRequest):
             "prompt_tokens": prompt_tokens,
             "total_tokens": prompt_tokens,
         },
-        "device": DEVICE,
-        "precision": PRECISION,
+        "device": served_device,  # report the device that actually served the request
+        "precision": "fp32" if served_device == "cpu" else PRECISION,
     }

+# --------------------------- Health check (optional) --------------------------------#
+@app.get("/health")
+def healthz():
+    return {
+        "status": "ok",
+        "device": DEVICE,
+        "precision": PRECISION,
+        "gpu_fail_count": _GPU_FAIL_COUNT,
+        "batch_size": BATCH_SIZE,
+        "ready": _READY,
+    }

 # -----------------------------------------------------------------------------#
 # Entry-point for `python server.py`
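Note: from a client's point of view the request keeps the same OpenAI-style shape, while the response now reports the device that actually served it. A hypothetical call against a locally running container (URL, payload, and field access are assumptions drawn from the diff, not part of the commit):

# Hypothetical client call; assumes the service is reachable on localhost:8001.
import json
import urllib.request

payload = {"input": ["hello world", "embedding test"], "model": "text-embedding-bge-m3"}
req = urllib.request.Request(
    "http://127.0.0.1:8001/v1/embeddings",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req, timeout=120) as resp:
    body = json.load(resp)

print(len(body["data"]), "embeddings returned")
print("served by:", body["device"], "precision:", body["precision"])
print("usage:", body["usage"])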
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+BGEM3 inference server (FastAPI) with robust GPU / CPU management.
+
+Launch examples:
+    python server.py              # auto-select a GPU / degrade automatically
+    python server.py --device 1   # pin to GPU 1
+    CUDA_VISIBLE_DEVICES=0,1 python server.py
+"""
+
+import argparse
+import logging
+import os
+import sys
+import time
+from typing import List, Union
+import multiprocessing as mp
+
+import torch
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from transformers import AutoTokenizer
+from FlagEmbedding import BGEM3FlagModel
+
+mp.set_start_method("spawn", force=True)
+
+# -----------------------------------------------------------------------------#
+# Config
+# -----------------------------------------------------------------------------#
+MODEL_PATH = "model/bge-m3"  # change to your weight path as needed
+MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000"))  # bge-m3-large fp32 ≈ 8 GiB
+POST_LOAD_GAP_MB = 192
+SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # == 8192 MB
+
+# -----------------------------------------------------------------------------#
+# Logging
+# -----------------------------------------------------------------------------#
+logger = logging.getLogger("bge-m3-server")
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+)
+
+# -----------------------------------------------------------------------------#
+# GPU memory helpers (NVML → torch fallback)
+# -----------------------------------------------------------------------------#
+try:
+    import pynvml
+
+    pynvml.nvmlInit()
+    _USE_NVML = True
+except Exception:
+    _USE_NVML = False
+    logger.warning("pynvml unavailable; using torch.cuda.mem_get_info() to probe GPU memory")
+
+
+def _gpu_mem_info(idx: int) -> tuple[int, int]:
+    """Return (free_bytes, total_bytes) for GPU `idx`."""
+    if _USE_NVML:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
+        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        return mem.free, mem.total
+    # torch fallback
+    torch.cuda.set_device(idx)
+    free, total = torch.cuda.mem_get_info(idx)
+    return free, total
+
+
+# -----------------------------------------------------------------------------#
+# Precision & model loader
+# -----------------------------------------------------------------------------#
+def _choose_precision(device: str) -> str:
+    """Return 'fp16'|'bf16'|'fp32'."""
+    if not device.startswith("cuda"):
+        return "fp32"
+
+    major, _ = torch.cuda.get_device_capability(device)
+    if major >= 8:  # Ampere/Hopper
+        return "fp16"
+    if major >= 7:
+        return "bf16"
+    return "fp32"
+
+
+def load_model(device: str):
+    precision = _choose_precision(device)
+    use_fp16 = precision == "fp16"
+    logger.info("Loading BGEM3 on %s (%s)", device, precision)
+
+    if device == "cpu":
+        # CPU path: detach CUDA entirely
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
+        torch.cuda.is_available = lambda: False
+        torch.cuda.device_count = lambda: 0
+    else:
+        # GPU path: expose only the card selected for this worker
+        # e.g. device == "cuda:3" → the current process only sees GPU 3
+        idx = device.split(":")[1]
+        os.environ["CUDA_VISIBLE_DEVICES"] = idx
+        # device_count now returns 1; BGEM3 creates a single subprocess on this card
+        torch.cuda.device_count = lambda: 1
+
+    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
+
+    # No DataParallel wrapper; one card per worker is enough
+    return mdl, precision
+
+# -----------------------------------------------------------------------------#
+# Auto-select device (startup)
+# -----------------------------------------------------------------------------#
+def auto_select_and_load() -> tuple:
+    """
+    1. Filter out GPUs whose free memory is < MODEL_VRAM_MB
+    2. Try loading in descending order of free memory
+    3. Check again after loading: remaining < POST_LOAD_GAP_MB → treat as failure
+    4. If no GPU qualifies → CPU
+    """
+    if not torch.cuda.is_available():
+        logger.info("No GPU detected → CPU")
+        return (*load_model("cpu"), "cpu")
+
+    # Collect candidate cards as (free_MB, idx)
+    candidates = []
+    for idx in range(torch.cuda.device_count()):
+        free_mb = _gpu_mem_info(idx)[0] // 2**20
+        if free_mb >= MODEL_VRAM_MB:  # must at least fit the weights
+            candidates.append((free_mb, idx))
+
+    if not candidates:
+        logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
+        return (*load_model("cpu"), "cpu")
+
+    # Highest free VRAM first
+    for free_mb, idx in sorted(candidates, reverse=True):
+        dev = f"cuda:{idx}"
+        try:
+            logger.info("Trying %s (free=%d MB)", dev, free_mb)
+            mdl, prec = load_model(dev)
+
+            remain_mb = _gpu_mem_info(idx)[0] // 2**20
+            if remain_mb < POST_LOAD_GAP_MB:
+                raise RuntimeError(
+                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+
+            return mdl, prec, dev  # success
+        except RuntimeError as e:
+            logger.warning("%s unusable (%s) → next", dev, e)
+            torch.cuda.empty_cache()
+
+    logger.warning("No suitable GPU left → CPU fallback")
+    return (*load_model("cpu"), "cpu")
+
+# -----------------------------------------------------------------------------#
+# CLI
+# -----------------------------------------------------------------------------#
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
+)
+args, _ = parser.parse_known_args()
+
+if args.device is not None:
+    # Forced path
+    if args.device.lower() == "cpu":
+        DEVICE = "cpu"
+    else:
+        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
+    model, PRECISION = load_model(DEVICE)
+else:
+    # Auto path with VRAM check
+    model, PRECISION, DEVICE = auto_select_and_load()
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+
+# -----------------------------------------------------------------------------#
+# FastAPI
+# -----------------------------------------------------------------------------#
+app = FastAPI()
+logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
+
+
+def _warm_worker(t, q):
+    try:
+        _ = model.encode(t, return_dense=True)
+        q.put("ok")
+    except Exception as e:
+        q.put(str(e))
+
+# ② -------- FastAPI startup warm-up --------
+@app.on_event("startup")
+def warm_up():
+    logger.info("Warm-up on %s", DEVICE)
+    try:
+        _ = model.encode([
+            "This is a warmup sentence used to initialize CUDA kernels and avoid latency spikes."
+        ], return_dense=True)
+        logger.info("Warm-up complete.")
+    except Exception as e:
+        logger.warning("Warm-up failed: %s", e)
+
+
+class EmbeddingRequest(BaseModel):
+    input: Union[str, List[str]]
+    model: str = "text-embedding-bge-m3"
+
+
+# ③ -------- worker invoked from _encode() --------
+def _worker(t, q):
+    try:
+        # out = model.encode(t, return_dense=True)  # safe on GPU or CPU
+        out = model.encode(t, return_dense=True)
+        q.put(("ok", out))
+    except Exception as e:
+        q.put(("err", str(e)))
+
+
+def _encode(texts: List[str]):
+    try:
+        return model.encode(texts, return_dense=True)
+    except RuntimeError as e:
+        if "out of memory" in str(e).lower() or "cuda error" in str(e).lower():
+            logger.warning("GPU OOM → fallback to CPU: %s", str(e))
+            torch.cuda.empty_cache()
+            global CPU_MODEL_CACHE
+            if CPU_MODEL_CACHE is None:
+                CPU_MODEL_CACHE, _ = load_model("cpu")
+            return CPU_MODEL_CACHE.encode(texts, return_dense=True)
+        raise
+
+
+@app.post("/v1/embeddings")
+def create_embedding(request: EmbeddingRequest):
+    texts = [request.input] if isinstance(request.input, str) else request.input
+
+    # Token stats
+    enc = tokenizer(
+        texts,
+        padding=True,
+        truncation=True,
+        max_length=8192,
+        return_tensors="pt",
+    )
+    prompt_tokens = int(enc["attention_mask"].sum().item())
+
+    try:
+        output = _encode(texts)
+        embeddings = output["dense_vecs"]
+    except Exception as e:
+        logger.exception("Embedding failed")
+        raise HTTPException(status_code=500, detail=str(e))
+
+    return {
+        "object": "list",
+        "data": [
+            {
+                "object": "embedding",
+                "index": i,
+                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
+            }
+            for i, emb in enumerate(embeddings)
+        ],
+        "model": request.model,
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "total_tokens": prompt_tokens,
+        },
+        "device": DEVICE,
+        "precision": PRECISION,
+    }
+
+
+# -----------------------------------------------------------------------------#
+# Entry-point for `python server.py`
+# -----------------------------------------------------------------------------#
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(
+        "server:app",
+        host="0.0.0.0",
+        port=int(os.getenv("PORT", 8000)),
+        log_level="info",
+        workers=1,  # for multi-process serving, use gunicorn/uvicorn externally
+    )