auto detect GPU, VRAM, CPU
This commit is contained in:
parent cb54502fae
commit 627b4179a6

251  app/main.py
@@ -1,67 +1,258 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.

Launch examples:
    python server.py                          # auto-select GPU / auto-fallback
    python server.py --device 1               # pin to GPU 1
    CUDA_VISIBLE_DEVICES=0,1 python server.py
"""

import argparse
import logging
import os
import sys
import time
from typing import List, Union

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # change this to your weights path if needed
SAFE_MIN_FREE_MB = int(
    os.getenv("SAFE_MIN_FREE_MB", "4096")
)  # minimum free VRAM (MB) required at startup

# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
    import pynvml

    pynvml.nvmlInit()
    _USE_NVML = True
except Exception:
    _USE_NVML = False
    logger.warning("pynvml unavailable; probing VRAM via torch.cuda.mem_get_info()")


def _gpu_mem_info(idx: int) -> tuple[int, int]:
    """Return (free_bytes, total_bytes) for GPU `idx`."""
    if _USE_NVML:
        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem.free, mem.total
    # torch fallback
    torch.cuda.set_device(idx)
    free, total = torch.cuda.mem_get_info(idx)
    return free, total


# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision(device: str) -> str:
    """Return 'fp16' | 'bf16' | 'fp32'."""
    if not device.startswith("cuda"):
        return "fp32"

    major, _ = torch.cuda.get_device_capability(device)
    if major >= 8:  # Ampere/Hopper
        return "fp16"
    if major >= 7:
        return "bf16"
    return "fp32"


def load_model(device: str):
    """Instantiate model on `device`; return (model, precision)."""
    precision = _choose_precision(device)
    use_fp16 = precision == "fp16"
    logger.info("Loading BGEM3 on %s (%s)", device, precision)

    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)

    # Simple DataParallel for multi-GPU inference
    if device.startswith("cuda") and torch.cuda.device_count() > 1:
        logger.info(
            "Wrapping model with torch.nn.DataParallel (%d GPUs)",
            torch.cuda.device_count(),
        )
        mdl = torch.nn.DataParallel(mdl)
    return mdl, precision


# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load(min_free_mb: int = 4096):
    """
    1. Gather GPUs with free_mem ≥ min_free_mb
    2. Sort by free_mem desc, attempt to load
    3. All fail → CPU
    Return (model, precision, device)
    """
    if not torch.cuda.is_available():
        logger.info("No GPU detected → CPU")
        return (*load_model("cpu"), "cpu")

    # Build candidate list
    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
    for idx in range(torch.cuda.device_count()):
        free, _ = _gpu_mem_info(idx)
        free_mb = free // 2**20
        candidates.append((free_mb, idx))

    candidates = [c for c in candidates if c[0] >= min_free_mb]
    if not candidates:
        logger.warning("All GPUs free_mem < %d MB → CPU", min_free_mb)
        return (*load_model("cpu"), "cpu")

    candidates.sort(reverse=True)  # highest free_mem first
    for free_mb, idx in candidates:
        dev = f"cuda:{idx}"
        try:
            logger.info("Trying %s (free=%d MB)", dev, free_mb)
            mdl, prec = load_model(dev)
            return mdl, prec, dev
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                logger.warning("%s OOM → next GPU", dev)
                torch.cuda.empty_cache()
                continue
            raise  # non-OOM error

    logger.warning("All GPUs failed → CPU fallback")
    return (*load_model("cpu"), "cpu")


# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
)
args, _ = parser.parse_known_args()

if args.device is not None:
    # Forced path
    if args.device.lower() == "cpu":
        DEVICE = "cpu"
    else:
        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
    model, PRECISION = load_model(DEVICE)
else:
    # Auto path with VRAM check
    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"


fallback_done = False  # prevent endless downgrade loop


def _encode(texts: List[str]):
    """Encode with a single downgrade to CPU on OOM / CUDA failure."""
    global model, DEVICE, PRECISION, fallback_done

    try:
        return model.encode(texts, return_dense=True)
    except RuntimeError as err:
        is_oom = "out of memory" in str(err).lower()
        is_cuda_fail = "cuda error" in str(err).lower() or "device-side assert" in str(
            err
        ).lower()

        if (is_oom or is_cuda_fail) and not fallback_done:
            logger.error("GPU failure (%s). Falling back to CPU…", err)
            fallback_done = True
            torch.cuda.empty_cache()
            DEVICE = "cpu"
            model, PRECISION = load_model(DEVICE)
            return model.encode(texts, return_dense=True)

        raise  # second failure → propagate


@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalize the input to a list
    texts = [request.input] if isinstance(request.input, str) else request.input

    # Token stats: the number of 1s in attention_mask is the real token count
    # (padding excluded)
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    )
    prompt_tokens = int(enc["attention_mask"].sum().item())

    try:
        output = _encode(texts)
        embeddings = output["dense_vecs"]
    except Exception as e:
        logger.exception("Embedding failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
            }
            for i, emb in enumerate(embeddings)
        ],
        "model": request.model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,  # embeddings produce no extra tokens
        },
        "device": DEVICE,
        "precision": PRECISION,
    }


# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        log_level="info",
        workers=1,  # multi-process → use gunicorn/uvicorn externally
    )
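The endpoint above keeps the OpenAI-style response shape and adds "device" and "precision" fields. Not part of the commit: a minimal client sketch for a quick check, assuming the server is running locally on the default PORT 8000 and using only the standard library.

# Hypothetical client for the /v1/embeddings endpoint above.
# Assumes the server is reachable at http://localhost:8000 (the default PORT).
import json
import urllib.request

payload = {
    "input": ["hello world", "BGE-M3 embedding test"],
    "model": "text-embedding-bge-m3",
}
req = urllib.request.Request(
    "http://localhost:8000/v1/embeddings",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read().decode("utf-8"))

# "device" and "precision" are the extra fields this commit adds to the response.
print(body["usage"], body["device"], body["precision"])
print(len(body["data"][0]["embedding"]))  # dense vector size (1024 for BGE-M3)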
@@ -0,0 +1,67 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Union
import torch
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# Auto-detect the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Enable fp16 for speed on GPU, otherwise use fp32
USE_FP16 = DEVICE == "cuda"

# Load the model and tokenizer
MODEL_PATH = "model/bge-m3"
model = BGEM3FlagModel(MODEL_PATH, use_fp16=USE_FP16, device=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

app = FastAPI()


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"


@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalize the input to a list
    texts = [request.input] if isinstance(request.input, str) else request.input

    try:
        # Count tokens with the tokenizer first
        encoding = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors="pt"
        )
        # The number of 1s in attention_mask is the real token count (padding excluded)
        mask = encoding["attention_mask"]
        # prompt_tokens = sum of all input tokens
        prompt_tokens = int(mask.sum().item())
        total_tokens = prompt_tokens  # embeddings produce no extra tokens

        # Generate dense vectors
        output = model.encode(texts, return_dense=True)
        embeddings = output["dense_vecs"]

        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.tolist() if hasattr(emb, "tolist") else emb
                }
                for idx, emb in enumerate(embeddings)
            ],
            "model": request.model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens
            }
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -4,3 +4,4 @@ torch
transformers
datasets
peft
pynvml
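pynvml is the only new dependency. Not part of the commit: a minimal standalone sketch to verify it can see the GPUs, mirroring the NVML calls used in `_gpu_mem_info` above.

# Hypothetical sanity check for the new pynvml dependency.
import pynvml

pynvml.nvmlInit()
for idx in range(pynvml.nvmlDeviceGetCount()):
    handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU {idx}: free={mem.free // 2**20} MB, total={mem.total // 2**20} MB")
pynvml.nvmlShutdown()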