hailin 2025-08-10 22:38:20 +08:00
parent 2dd0928d6e
commit 0500e81f1c
3 changed files with 522 additions and 106 deletions


@@ -1,18 +1,21 @@
FROM python:3.10-slim
# Install system dependencies
RUN apt-get update && apt-get install -y gcc libglib2.0-0 && rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc libglib2.0-0 curl && rm -rf /var/lib/apt/lists/*
# Set the working directory
WORKDIR /app
# Install Python dependencies
COPY requirements.txt .
RUN pip install --upgrade pip && pip install -r requirements.txt
RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt
# Install the local FlagEmbedding source
COPY FlagEmbedding /opt/FlagEmbedding
RUN pip install --no-deps --upgrade -e /opt/FlagEmbedding
RUN pip install --no-cache-dir --no-deps /opt/FlagEmbedding
# Copy application code and model weights
COPY app /app/app
@@ -26,9 +29,16 @@ EXPOSE 8001
# New: configure the PyTorch CUDA allocator split size to reduce fragmentation (optional but recommended)
ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32
ENV TOKENIZERS_PARALLELISM=false
ENV HF_HUB_DISABLE_TELEMETRY=1 TRANSFORMERS_NO_ADVISORY_WARNINGS=1
# Health check: allow time for startup and warm-up (tune start-period to your model size)
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
CMD curl -fsS http://127.0.0.1:8001/ready >/dev/null || exit 1
# Start: Gunicorn + 1 worker; each worker is a separate process
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"-w", "1", \
"-b", "0.0.0.0:8001"]
CMD ["gunicorn","app.main:app", \
"-k","uvicorn.workers.UvicornWorker", \
"-w","1","-b","0.0.0.0:8001", \
"--timeout","120","--graceful-timeout","30", \
"--max-requests","1000","--max-requests-jitter","200"]


@@ -12,10 +12,8 @@ Launch examples:
import argparse
import logging
import os
import sys
import time
from typing import List, Union
import multiprocessing as mp
from threading import Lock
import torch
from fastapi import FastAPI, HTTPException
@@ -23,8 +21,6 @@ from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
mp.set_start_method("spawn", force=True)
# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
@@ -33,6 +29,27 @@ MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000")) # bge-m3-large fp32
POST_LOAD_GAP_MB = 192
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB # == 8192 MB
# Caps on request batch size and per-item text length (guards against abnormally large payloads)
MAX_BATCH = int(os.getenv("MAX_BATCH", "1024"))
MAX_TEXT_LEN = int(os.getenv("MAX_TEXT_LEN", "200000"))
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "32")))
# Number of consecutive GPU failures before attempting a reload (self-healing)
GPU_MAX_CONSEC_FAILS = int(os.getenv("GPU_MAX_CONSEC_FAILS", "3"))
# Cache for the CPU fallback model
CPU_MODEL_CACHE = None
# Tracks consecutive GPU failures; used only to trigger a reload, never affects the normal path
_GPU_FAIL_COUNT = 0
_reload_lock = Lock()
_READY = False
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-bge-m3"
# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
@@ -42,6 +59,10 @@ logging.basicConfig(
format="%(asctime)s [%(levelname)s] %(message)s",
)
def _http_error(status: int, code: str, message: str):
"""统一错误体:便于日志/客户端识别。"""
raise HTTPException(status_code=status, detail={"code": code, "message": message})
# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
@@ -54,56 +75,55 @@ except Exception:
_USE_NVML = False
logger.warning("pynvml 不可用,将使用 torch.cuda.mem_get_info() 探测显存")
def _gpu_mem_info(idx: int) -> tuple[int, int]:
"""Return (free_bytes, total_bytes) for GPU `idx`."""
"""Return (free_bytes, total_bytes) for GPU `idx`(全局序号)."""
if _USE_NVML:
handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
return mem.free, mem.total
# torch fallback
# torch fallback (note: once devices are masked, only in-process indices are accessible)
torch.cuda.set_device(idx)
free, total = torch.cuda.mem_get_info(idx)
return free, total
# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision(device: str) -> str:
"""Return 'fp16'|'bf16'|'fp32'."""
if not device.startswith("cuda"):
return "fp32"
major, _ = torch.cuda.get_device_capability(device)
def _choose_precision_by_idx(idx: int) -> str:
"""Return 'fp16' or 'fp32' by compute capability."""
try:
major, _ = torch.cuda.get_device_capability(idx) # idx is the global index here
if major >= 8: # Ampere/Hopper
return "fp16"
if major >= 7:
return "bf16"
if major >= 7: # Volta/Turing
return "fp16"
return "fp32"
except Exception:
return "fp16" if torch.cuda.is_available() else "fp32"
def load_model(device: str):
precision = _choose_precision(device)
use_fp16 = precision == "fp16"
logger.info("Loading BGEM3 on %s (%s)", device, precision)
"""
device: "cpu" "cuda:{global_idx}"
屏蔽后在本进程内统一使用 "cuda:0" 加载避免 invalid device ordinal
"""
if device == "cpu":
# CPU path: drop CUDA entirely
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0
else:
# GPU path: expose only the card selected for this worker
# e.g. device == "cuda:3" → only let the current process see GPU 3
idx = device.split(":")[1]
os.environ["CUDA_VISIBLE_DEVICES"] = idx
# device_count now returns 1; BGEM3 will only create one sub-process on this card
torch.cuda.device_count = lambda: 1
precision = "fp32"
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=False, device="cpu")
logger.info("Loading BGEM3 on cpu (fp32)")
return mdl, precision
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
# Resolve the global index and decide precision before masking
idx = int(device.split(":")[1])
precision = _choose_precision_by_idx(idx)
use_fp16 = (precision == "fp16")
# No longer wrapped in DataParallel; one card per worker is enough
# Expose only this card; inside the process it maps to cuda:0
os.environ["CUDA_VISIBLE_DEVICES"] = str(idx)
mapped = "cuda:0"
logger.info("Loading BGEM3 on %s (mapped=%s, %s)", device, mapped, precision)
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=mapped)
return mdl, precision
# -----------------------------------------------------------------------------#
@@ -118,18 +138,20 @@ def auto_select_and_load() -> tuple:
"""
if not torch.cuda.is_available():
logger.info("No GPU detected → CPU")
return (*load_model("cpu"), "cpu")
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
# Collect candidate cards as (free_MB, idx)
candidates = []
for idx in range(torch.cuda.device_count()):
free_mb = _gpu_mem_info(idx)[0] // 2**20
if free_mb >= MODEL_VRAM_MB: # must at least fit the weights
if free_mb >= MODEL_VRAM_MB:
candidates.append((free_mb, idx))
if not candidates:
logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
return (*load_model("cpu"), "cpu")
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
# Sort by free VRAM, high to low
for free_mb, idx in sorted(candidates, reverse=True):
@@ -138,104 +160,180 @@ def auto_select_and_load() -> tuple:
logger.info("Trying %s (free=%d MB)", dev, free_mb)
mdl, prec = load_model(dev)
# Post-load headroom check: with NVML use the global idx; without NVML, use in-process index 0
if _USE_NVML:
remain_mb = _gpu_mem_info(idx)[0] // 2**20
if remain_mb < POST_LOAD_GAP_MB:
raise RuntimeError(
f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
else:
remain_mb = _gpu_mem_info(0)[0] // 2**20
return mdl, prec, dev # success
if remain_mb < POST_LOAD_GAP_MB:
raise RuntimeError(f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
return mdl, prec, dev
except RuntimeError as e:
logger.warning("%s unusable (%s) → next", dev, e)
torch.cuda.empty_cache()
logger.warning("No suitable GPU left → CPU fallback")
return (*load_model("cpu"), "cpu")
mdl, prec = load_model("cpu")
return mdl, prec, "cpu"
# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
FORCE_DEVICE = os.getenv("FORCE_DEVICE") # "cpu" or "0" / "1" / ...
parser = argparse.ArgumentParser()
parser.add_argument(
"--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
)
parser.add_argument("--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection")
args, _ = parser.parse_known_args()
if args.device is not None:
# Forced path
if FORCE_DEVICE is not None:
if FORCE_DEVICE.lower() == "cpu":
DEVICE = "cpu"
else:
DEVICE = f"cuda:{int(FORCE_DEVICE)}" if torch.cuda.is_available() else "cpu"
model, PRECISION = load_model(DEVICE)
elif args.device is not None:
if args.device.lower() == "cpu":
DEVICE = "cpu"
else:
DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
model, PRECISION = load_model(DEVICE)
else:
# Auto path with VRAM check
model, PRECISION, DEVICE = auto_select_and_load()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
logger.info("Using SAFE_MIN_FREE_MB = %d MB, BATCH_SIZE = %d", SAFE_MIN_FREE_MB, BATCH_SIZE)
def _warm_worker(t, q):
try:
_ = model.encode(t, return_dense=True)
q.put("ok")
except Exception as e:
q.put(str(e))
# ② -------- FastAPI startup warm-up --------
@app.on_event("startup")
def warm_up():
logger.info("Warm-up on %s", DEVICE)
def _warmup():
global _READY
try:
_ = model.encode([
"This is a warmup sentence used to initialize CUDA kernels and avoid latency spikes."
], return_dense=True)
# Try warming up with batch_size; fall back if it is not supported
try:
model.encode(["warmup sentence"], return_dense=True, batch_size=BATCH_SIZE)
except TypeError:
_ = _encode_chunked(model, ["warmup sentence"], max(1, min(BATCH_SIZE, 8)))
_READY = True
logger.info("Warm-up complete.")
except Exception as e:
_READY = False
logger.warning("Warm-up failed: %s", e)
@app.get("/ready")
def ready():
if not _READY:
# 503 so the orchestrator/HEALTHCHECK keeps waiting
raise HTTPException(status_code=503, detail={"code": "not_ready", "message": "warming up"})
return {"status": "ready"}
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-bge-m3"
# ③ -------- worker call inside _encode() --------
def _worker(t, q):
# ---------------------- GPU self-healing: reload the model after consecutive failures ---------------------- #
def _maybe_reload_gpu_model():
"""
When consecutive GPU failures hit the threshold, try a reload; a failed reload does not break request handling, which keeps using the CPU fallback.
"""
global _GPU_FAIL_COUNT, model, PRECISION, DEVICE
if not str(DEVICE).startswith("cuda"):
return
if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
return
with _reload_lock:
# Double-check to avoid concurrent duplicate reloads
if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
return
try:
# out = model.encode(t, return_dense=True) # safe on either GPU or CPU
out = model.encode(t, return_dense=True)
q.put(("ok", out))
logger.warning("GPU unhealthy (fail_count=%d) → reloading on %s", _GPU_FAIL_COUNT, DEVICE)
model, PRECISION = load_model(DEVICE)
_GPU_FAIL_COUNT = 0
logger.info("GPU model reloaded on %s", DEVICE)
except Exception as e:
q.put(("err", str(e)))
logger.exception("GPU reload failed: %s", e)
# 保持 CPU 兜底路径继续工作
def _encode_chunked(model_obj, texts: List[str], chunk: int):
"""不依赖 batch_size 形参的手动分块编码;返回结构与 BGEM3 一致。"""
chunk = max(1, int(chunk)) # guard
dense_all = []
for i in range(0, len(texts), chunk):
part = texts[i:i+chunk]
out = model_obj.encode(part, return_dense=True) # no batch_size passed
dense_all.extend(out["dense_vecs"])
return {"dense_vecs": dense_all}
def _encode(texts: List[str]):
"""
Returns (output_dict, served_device).
Normally runs on the GPU; on OOM/CUDA errors it automatically falls back to CPU and bumps the self-healing counter.
"""
served = DEVICE
try:
return model.encode(texts, return_dense=True)
try:
return model.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
except TypeError:
# encode() does not accept a batch_size parameter → chunk manually
return _encode_chunked(model, texts, BATCH_SIZE), served
except RuntimeError as e:
if "out of memory" in str(e).lower() or "cuda error" in str(e).lower():
logger.warning("GPU OOM → fallback to CPU: %s", str(e))
emsg = str(e).lower()
if "out of memory" in emsg or "cuda error" in emsg:
logger.warning("GPU OOM/CUDA error → fallback to CPU: %s", str(e))
torch.cuda.empty_cache()
global CPU_MODEL_CACHE
global CPU_MODEL_CACHE, _GPU_FAIL_COUNT
if CPU_MODEL_CACHE is None:
CPU_MODEL_CACHE, _ = load_model("cpu")
return CPU_MODEL_CACHE.encode(texts, return_dense=True)
_GPU_FAIL_COUNT += 1 # record another consecutive failure
_maybe_reload_gpu_model() # try to self-heal once the threshold is hit
served = "cpu"
# On CPU too, try batch_size first, then fall back to chunking
try:
return CPU_MODEL_CACHE.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
except TypeError:
return _encode_chunked(CPU_MODEL_CACHE, texts, BATCH_SIZE), served
raise
# =============== Input normalization (reject stack-busting payloads at the entry point) =============== #
def _normalize_inputs(raw) -> List[str]:
"""
Behavioral constraints:
- String: keep the original behavior, wrap as [str]; "" is allowed
- List: must be list[str]; items are not stripped and blanks are not dropped (original behavior kept)
- Empty list: 400 immediately (previously this made the HF tokenizer raise IndexError)
- Non-string element: 400
- Oversized batch/text: 413 (defensive limit)
"""
if isinstance(raw, str):
texts = [raw]
elif isinstance(raw, list):
if len(raw) == 0:
_http_error(400, "empty_list", "`input` is an empty list.")
for i, x in enumerate(raw):
if not isinstance(x, str):
_http_error(400, "non_string_item", f"`input[{i}]` must be a string.")
texts = raw
else:
_http_error(400, "invalid_type", "`input` must be a string or list of strings.")
if len(texts) > MAX_BATCH:
_http_error(413, "batch_too_large", f"batch size {len(texts)} > MAX_BATCH={MAX_BATCH}.")
for i, t in enumerate(texts):
if t is not None and len(t) > MAX_TEXT_LEN:
_http_error(413, "text_too_long", f"`input[{i}]` exceeds MAX_TEXT_LEN={MAX_TEXT_LEN}.")
return texts
# ------------------------------- Main route --------------------------------------#
@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
texts = [request.input] if isinstance(request.input, str) else request.input
# 1) Normalize inputs (reject stack-busting payloads at the entry point; "" remains allowed as before)
texts = _normalize_inputs(request.input)
# Token stats
# 2) Token counting: failures do not affect the main flow (counted as 0)
try:
enc = tokenizer(
texts,
padding=True,
@@ -244,14 +342,24 @@ def create_embedding(request: EmbeddingRequest):
return_tensors="pt",
)
prompt_tokens = int(enc["attention_mask"].sum().item())
except IndexError:
# Double safety net (in theory this can no longer happen)
_http_error(400, "empty_after_tokenize", "`input` contains no encodable texts.")
except Exception as e:
logger.warning("token counting failed: %s", e)
prompt_tokens = 0
# 3) Encode (GPU → CPU fallback; self-healing counter and reload)
try:
output = _encode(texts)
output, served_device = _encode(texts)
embeddings = output["dense_vecs"]
except HTTPException:
raise
except Exception as e:
logger.exception("Embedding failed")
raise HTTPException(status_code=500, detail=str(e))
_http_error(500, "embedding_failed", str(e))
# 4) Return as-is (structure/fields unchanged)
return {
"object": "list",
"data": [
@@ -267,10 +375,21 @@ def create_embedding(request: EmbeddingRequest):
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens,
},
"device": DEVICE,
"precision": PRECISION,
"device": served_device, # 返回真实服务设备
"precision": "fp32" if served_device == "cpu" else PRECISION,
}
# --------------------------- Health check (optional) --------------------------------#
@app.get("/health")
def healthz():
return {
"status": "ok",
"device": DEVICE,
"precision": PRECISION,
"gpu_fail_count": _GPU_FAIL_COUNT,
"batch_size": BATCH_SIZE,
"ready": _READY,
}
# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
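
The route above accepts either a string or a list of strings and reports the device and precision that actually served the request. A minimal client sketch against this API, assuming the service is reachable at 127.0.0.1:8001 (hypothetical address) and using only the standard library:

# embed_client.py: sketch of a caller for /v1/embeddings; BASE_URL is an assumption.
import json
import urllib.request

BASE_URL = "http://127.0.0.1:8001"

def embed(texts):
    """POST texts to /v1/embeddings and return (vectors, device, precision)."""
    payload = json.dumps({"input": texts, "model": "text-embedding-bge-m3"}).encode("utf-8")
    req = urllib.request.Request(
        f"{BASE_URL}/v1/embeddings",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=120) as resp:
        body = json.load(resp)
    # The response mirrors the handler: data[i].embedding plus usage, device, precision.
    vectors = [item["embedding"] for item in body["data"]]
    return vectors, body["device"], body["precision"]

if __name__ == "__main__":
    vecs, device, precision = embed(["hello world", "你好,世界"])
    print(len(vecs), len(vecs[0]), device, precision)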

app/main.py.ok (new file, 287 lines)

@@ -0,0 +1,287 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.
Launch examples:
python server.py # auto-select a GPU / auto-degrade
python server.py --device 1 # pin to GPU 1
CUDA_VISIBLE_DEVICES=0,1 python server.py
"""
import argparse
import logging
import os
import sys
import time
from typing import List, Union
import multiprocessing as mp
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
mp.set_start_method("spawn", force=True)
# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3" # 按需改成你的权重路径
MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000")) # bge-m3-large fp32 ≈ 8 GiB
POST_LOAD_GAP_MB = 192
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB # == 8192 MB
# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
import pynvml
pynvml.nvmlInit()
_USE_NVML = True
except Exception:
_USE_NVML = False
logger.warning("pynvml 不可用,将使用 torch.cuda.mem_get_info() 探测显存")
def _gpu_mem_info(idx: int) -> tuple[int, int]:
"""Return (free_bytes, total_bytes) for GPU `idx`."""
if _USE_NVML:
handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
return mem.free, mem.total
# torch fallback
torch.cuda.set_device(idx)
free, total = torch.cuda.mem_get_info(idx)
return free, total
# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision(device: str) -> str:
"""Return 'fp16'|'bf16'|'fp32'."""
if not device.startswith("cuda"):
return "fp32"
major, _ = torch.cuda.get_device_capability(device)
if major >= 8: # Ampere/Hopper
return "fp16"
if major >= 7:
return "bf16"
return "fp32"
def load_model(device: str):
precision = _choose_precision(device)
use_fp16 = precision == "fp16"
logger.info("Loading BGEM3 on %s (%s)", device, precision)
if device == "cpu":
# CPU path: drop CUDA entirely
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0
else:
# GPU path: expose only the card selected for this worker
# e.g. device == "cuda:3" → only let the current process see GPU 3
idx = device.split(":")[1]
os.environ["CUDA_VISIBLE_DEVICES"] = idx
# device_count now returns 1; BGEM3 will only create one sub-process on this card
torch.cuda.device_count = lambda: 1
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
# No longer wrapped in DataParallel; one card per worker is enough
return mdl, precision
# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load() -> tuple:
"""
1. Filter out GPUs whose free VRAM < MODEL_VRAM_MB
2. Try loading in descending order of free VRAM
3. Re-check after loading: if the remaining free VRAM < POST_LOAD_GAP_MB → treat as failure
4. If no GPU qualifies → fall back to CPU
"""
if not torch.cuda.is_available():
logger.info("No GPU detected → CPU")
return (*load_model("cpu"), "cpu")
# Collect candidate cards as (free_MB, idx)
candidates = []
for idx in range(torch.cuda.device_count()):
free_mb = _gpu_mem_info(idx)[0] // 2**20
if free_mb >= MODEL_VRAM_MB: # must at least fit the weights
candidates.append((free_mb, idx))
if not candidates:
logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
return (*load_model("cpu"), "cpu")
# Sort by free VRAM, high to low
for free_mb, idx in sorted(candidates, reverse=True):
dev = f"cuda:{idx}"
try:
logger.info("Trying %s (free=%d MB)", dev, free_mb)
mdl, prec = load_model(dev)
remain_mb = _gpu_mem_info(idx)[0] // 2**20
if remain_mb < POST_LOAD_GAP_MB:
raise RuntimeError(
f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
return mdl, prec, dev # success
except RuntimeError as e:
logger.warning("%s unusable (%s) → next", dev, e)
torch.cuda.empty_cache()
logger.warning("No suitable GPU left → CPU fallback")
return (*load_model("cpu"), "cpu")
# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
parser = argparse.ArgumentParser()
parser.add_argument(
"--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
)
args, _ = parser.parse_known_args()
if args.device is not None:
# Forced path
if args.device.lower() == "cpu":
DEVICE = "cpu"
else:
DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
model, PRECISION = load_model(DEVICE)
else:
# Auto path with VRAM check
model, PRECISION, DEVICE = auto_select_and_load()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
def _warm_worker(t, q):
try:
_ = model.encode(t, return_dense=True)
q.put("ok")
except Exception as e:
q.put(str(e))
# ② -------- FastAPI startup warm-up --------
@app.on_event("startup")
def warm_up():
logger.info("Warm-up on %s", DEVICE)
try:
_ = model.encode([
"This is a warmup sentence used to initialize CUDA kernels and avoid latency spikes."
], return_dense=True)
logger.info("Warm-up complete.")
except Exception as e:
logger.warning("Warm-up failed: %s", e)
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-bge-m3"
# ③ -------- worker call inside _encode() --------
def _worker(t, q):
try:
# out = model.encode(t, return_dense=True) # safe on either GPU or CPU
out = model.encode(t, return_dense=True)
q.put(("ok", out))
except Exception as e:
q.put(("err", str(e)))
def _encode(texts: List[str]):
try:
return model.encode(texts, return_dense=True)
except RuntimeError as e:
if "out of memory" in str(e).lower() or "cuda error" in str(e).lower():
logger.warning("GPU OOM → fallback to CPU: %s", str(e))
torch.cuda.empty_cache()
global CPU_MODEL_CACHE
if CPU_MODEL_CACHE is None:
CPU_MODEL_CACHE, _ = load_model("cpu")
return CPU_MODEL_CACHE.encode(texts, return_dense=True)
raise
@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
texts = [request.input] if isinstance(request.input, str) else request.input
# Token stats
enc = tokenizer(
texts,
padding=True,
truncation=True,
max_length=8192,
return_tensors="pt",
)
prompt_tokens = int(enc["attention_mask"].sum().item())
try:
output = _encode(texts)
embeddings = output["dense_vecs"]
except Exception as e:
logger.exception("Embedding failed")
raise HTTPException(status_code=500, detail=str(e))
return {
"object": "list",
"data": [
{
"object": "embedding",
"index": i,
"embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
}
for i, emb in enumerate(embeddings)
],
"model": request.model,
"usage": {
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens,
},
"device": DEVICE,
"precision": PRECISION,
}
# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"server:app",
host="0.0.0.0",
port=int(os.getenv("PORT", 8000)),
log_level="info",
workers=1, # multi-process → use gunicorn/uvicorn externally
)
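
The new _encode()/_encode_chunked() pair in app/main.py relies on a try-batch_size-then-chunk pattern: pass batch_size when encode() accepts it, otherwise split the batch manually and merge dense_vecs. A standalone sketch of that pattern with a stub encoder (the stub is illustrative only, not BGEM3FlagModel):

# fallback_pattern.py: sketch of the batch_size / manual-chunking fallback.
from typing import List

class StubEncoder:
    """Stand-in for an encoder whose encode() lacks a batch_size parameter."""
    def encode(self, texts: List[str], return_dense: bool = True):
        return {"dense_vecs": [[float(len(t))] for t in texts]}

def encode_with_fallback(model, texts: List[str], batch_size: int = 32):
    try:
        # Preferred path: let the model batch internally.
        return model.encode(texts, return_dense=True, batch_size=batch_size)
    except TypeError:
        # encode() rejects batch_size → chunk manually, same output shape.
        step = max(1, batch_size)
        dense = []
        for i in range(0, len(texts), step):
            dense.extend(model.encode(texts[i:i + step], return_dense=True)["dense_vecs"])
        return {"dense_vecs": dense}

if __name__ == "__main__":
    out = encode_with_fallback(StubEncoder(), ["a", "bb", "ccc"], batch_size=2)
    print(out["dense_vecs"])  # [[1.0], [2.0], [3.0]]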