Auto-detect GPU, VRAM, CPU

This commit is contained in:
hailin 2025-08-04 20:29:15 +08:00
parent cb54502fae
commit 627b4179a6
3 changed files with 305 additions and 46 deletions


@@ -1,67 +1,258 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.
Launch examples:
python server.py # 自动选卡 / 自动降级
python server.py --device 1 # 固定用第 1 张 GPU
CUDA_VISIBLE_DEVICES=0,1 python server.py
"""
import argparse
import logging
import os
import sys
import time
from typing import List, Union

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # change to the path of your model weights if needed
SAFE_MIN_FREE_MB = int(
    os.getenv("SAFE_MIN_FREE_MB", "4096")
)  # minimum free VRAM (MB) required at startup
# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
    import pynvml
    pynvml.nvmlInit()
    _USE_NVML = True
except Exception:
    _USE_NVML = False
    logger.warning("pynvml unavailable; probing VRAM with torch.cuda.mem_get_info() instead")
def _gpu_mem_info(idx: int) -> tuple[int, int]:
    """Return (free_bytes, total_bytes) for GPU `idx`."""
    if _USE_NVML:
        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem.free, mem.total
    # torch fallback
    torch.cuda.set_device(idx)
    free, total = torch.cuda.mem_get_info(idx)
    return free, total
# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision(device: str) -> str:
    """Return 'fp16' or 'fp32' for `device`."""
    if not device.startswith("cuda"):
        return "fp32"
    major, _ = torch.cuda.get_device_capability(device)
    # fp16 Tensor Cores are available from Volta/Turing (SM 7.0) onwards,
    # including Ampere/Hopper. bf16 would additionally require SM 8.0+ and
    # is not supported by the loader below, so it is never selected here.
    if major >= 7:
        return "fp16"
    return "fp32"
def load_model(device: str):
    """Instantiate the model on `device`; return (model, precision)."""
    precision = _choose_precision(device)
    use_fp16 = precision == "fp16"
    logger.info("Loading BGEM3 on %s (%s)", device, precision)
    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
    # NOTE: BGEM3FlagModel is not a torch.nn.Module, so wrapping it in
    # torch.nn.DataParallel would hide its .encode() method. For multi-GPU
    # serving, run one worker process per GPU instead.
    if device.startswith("cuda") and torch.cuda.device_count() > 1:
        logger.info(
            "%d GPUs visible; serving on %s only", torch.cuda.device_count(), device
        )
    return mdl, precision
# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load(min_free_mb: int = 4096):
    """
    1. Gather GPUs with free_mem >= min_free_mb
    2. Sort by free_mem (descending) and attempt to load
    3. If all fail, fall back to CPU
    Return (model, precision, device)
    """
    if not torch.cuda.is_available():
        logger.info("No GPU detected → CPU")
        return (*load_model("cpu"), "cpu")

    # Build candidate list
    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
    for idx in range(torch.cuda.device_count()):
        free, _ = _gpu_mem_info(idx)
        free_mb = free // 2**20
        candidates.append((free_mb, idx))
    candidates = [c for c in candidates if c[0] >= min_free_mb]
    if not candidates:
        logger.warning("All GPUs have free_mem < %d MB → CPU", min_free_mb)
        return (*load_model("cpu"), "cpu")

    candidates.sort(reverse=True)  # highest free_mem first
    for free_mb, idx in candidates:
        dev = f"cuda:{idx}"
        try:
            logger.info("Trying %s (free=%d MB)", dev, free_mb)
            mdl, prec = load_model(dev)
            return mdl, prec, dev
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                logger.warning("%s OOM → next GPU", dev)
                torch.cuda.empty_cache()
                continue
            raise  # non-OOM error

    logger.warning("All GPUs failed → CPU fallback")
    return (*load_model("cpu"), "cpu")
# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
)
args, _ = parser.parse_known_args()

if args.device is not None:
    # Forced path
    if args.device.lower() == "cpu":
        DEVICE = "cpu"
    else:
        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
    model, PRECISION = load_model(DEVICE)
else:
    # Auto path with VRAM check
    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)
# Tokenizer is used only for token accounting in the response
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"
fallback_done = False  # prevent an endless downgrade loop


def _encode(texts: List[str]):
    """Encode with a single downgrade to CPU on OOM / CUDA failure."""
    global model, DEVICE, PRECISION, fallback_done
    try:
        return model.encode(texts, return_dense=True)
    except RuntimeError as err:
        is_oom = "out of memory" in str(err).lower()
        is_cuda_fail = (
            "cuda error" in str(err).lower()
            or "device-side assert" in str(err).lower()
        )
        if (is_oom or is_cuda_fail) and not fallback_done:
            logger.error("GPU failure (%s). Falling back to CPU…", err)
            fallback_done = True
            torch.cuda.empty_cache()
            DEVICE = "cpu"
            model, PRECISION = load_model(DEVICE)
            return model.encode(texts, return_dense=True)
        raise  # second failure → propagate
@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalize the input to a list of texts
    texts = [request.input] if isinstance(request.input, str) else request.input

    # Token stats: the number of 1s in attention_mask is the real token count
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    )
    prompt_tokens = int(enc["attention_mask"].sum().item())

    try:
        output = _encode(texts)
        embeddings = output["dense_vecs"]
    except Exception as e:
        logger.exception("Embedding failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
            }
            for i, emb in enumerate(embeddings)
        ],
        "model": request.model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,  # embeddings produce no completion tokens
        },
        "device": DEVICE,
        "precision": PRECISION,
    }
# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        log_level="info",
        workers=1,  # multi-process → use gunicorn/uvicorn externally
    )
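For reference, here is a minimal client sketch for the endpoint above. It is not part of the commit: it assumes the server is reachable at http://localhost:8000 and that the `requests` package is installed on the client side.

import requests  # assumed to be installed in the client environment

BASE_URL = "http://localhost:8000"  # adjust to where server.py is running

payload = {
    "input": ["hello world", "another sentence"],
    "model": "text-embedding-bge-m3",
}
resp = requests.post(f"{BASE_URL}/v1/embeddings", json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()

# One dense vector per input text, plus token usage and the device/precision
# the server actually used for this request.
vectors = [item["embedding"] for item in body["data"]]
print(len(vectors), len(vectors[0]))
print(body["usage"], body.get("device"), body.get("precision"))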

app/main.py.old Normal file

@@ -0,0 +1,67 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Union
import torch
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
# Auto-detect device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Enable fp16 acceleration on GPU, otherwise use fp32
USE_FP16 = DEVICE == "cuda"

# Load model and tokenizer
MODEL_PATH = "model/bge-m3"
model = BGEM3FlagModel(MODEL_PATH, use_fp16=USE_FP16, device=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
app = FastAPI()
class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"
@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalize the input to a list
    texts = [request.input] if isinstance(request.input, str) else request.input
    try:
        # Count tokens with the tokenizer first
        encoding = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors="pt"
        )
        # The number of 1s in attention_mask is the real token count (padding excluded)
        mask = encoding["attention_mask"]
        # prompt_tokens = sum of all input tokens
        prompt_tokens = int(mask.sum().item())
        total_tokens = prompt_tokens  # embeddings produce no extra tokens
        # Generate dense vectors
        output = model.encode(texts, return_dense=True)
        embeddings = output["dense_vecs"]
        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.tolist() if hasattr(emb, "tolist") else emb
                }
                for idx, emb in enumerate(embeddings)
            ],
            "model": request.model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@@ -4,3 +4,4 @@ torch
transformers
datasets
peft
pynvml