auto detect GPU, VRAM, CPU
This commit is contained in:
parent cb54502fae
commit 627b4179a6
251
app/main.py
@@ -1,67 +1,258 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.

Launch examples:
    python server.py              # auto-select a GPU, fall back automatically
    python server.py --device 1   # pin to GPU 1
    CUDA_VISIBLE_DEVICES=0,1 python server.py
"""

import argparse
import logging
import os
import sys
import time
from typing import List, Union

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # change to your local weight path if needed
SAFE_MIN_FREE_MB = int(
    os.getenv("SAFE_MIN_FREE_MB", "4096")
)  # minimum free VRAM (MB) required at startup

# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
    import pynvml

    pynvml.nvmlInit()
    _USE_NVML = True
except Exception:
    _USE_NVML = False
    logger.warning("pynvml unavailable; probing VRAM via torch.cuda.mem_get_info()")


def _gpu_mem_info(idx: int) -> tuple[int, int]:
    """Return (free_bytes, total_bytes) for GPU `idx`."""
    if _USE_NVML:
        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem.free, mem.total
    # torch fallback
    torch.cuda.set_device(idx)
    free, total = torch.cuda.mem_get_info(idx)
    return free, total


# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision(device: str) -> str:
    """Return 'fp16' or 'fp32'.

    BGEM3FlagModel only exposes a use_fp16 switch, and bf16 requires compute
    capability >= 8.0, so fp16 is used on any GPU with tensor cores
    (capability >= 7.0) and fp32 everywhere else.
    """
    if not device.startswith("cuda"):
        return "fp32"

    major, _ = torch.cuda.get_device_capability(device)
    if major >= 7:  # Volta / Turing / Ampere / Hopper
        return "fp16"
    return "fp32"


def load_model(device: str):
    """Instantiate model on `device`; return (model, precision)."""
    precision = _choose_precision(device)
    use_fp16 = precision == "fp16"
    logger.info("Loading BGEM3 on %s (%s)", device, precision)

    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)

    # Multi-GPU note: BGEM3FlagModel is a wrapper, not an nn.Module, so it cannot
    # be wrapped in torch.nn.DataParallel directly (the wrapper would also hide
    # .encode()). Just report the extra GPUs and serve from the selected device.
    if device.startswith("cuda") and torch.cuda.device_count() > 1:
        logger.info(
            "%d GPUs visible; serving on %s only",
            torch.cuda.device_count(),
            device,
        )
    return mdl, precision


# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load(min_free_mb: int = 4096):
    """
    1. Gather GPUs with free_mem ≥ min_free_mb
    2. Sort by free_mem desc, attempt to load on each in turn
    3. All fail → CPU
    Return (model, precision, device)
    """
    if not torch.cuda.is_available():
        logger.info("No GPU detected → CPU")
        return (*load_model("cpu"), "cpu")

    # Build candidate list
    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
    for idx in range(torch.cuda.device_count()):
        free, _ = _gpu_mem_info(idx)
        free_mb = free // 2**20
        candidates.append((free_mb, idx))

    candidates = [c for c in candidates if c[0] >= min_free_mb]
    if not candidates:
        logger.warning("All GPUs free_mem < %d MB → CPU", min_free_mb)
        return (*load_model("cpu"), "cpu")

    candidates.sort(reverse=True)  # highest free_mem first
    for free_mb, idx in candidates:
        dev = f"cuda:{idx}"
        try:
            logger.info("Trying %s (free=%d MB)", dev, free_mb)
            mdl, prec = load_model(dev)
            return mdl, prec, dev
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                logger.warning("%s OOM → next GPU", dev)
                torch.cuda.empty_cache()
                continue
            raise  # non-OOM error

    logger.warning("All GPUs failed → CPU fallback")
    return (*load_model("cpu"), "cpu")


# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
)
args, _ = parser.parse_known_args()

if args.device is not None:
    # Forced path
    if args.device.lower() == "cpu":
        DEVICE = "cpu"
    else:
        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
    model, PRECISION = load_model(DEVICE)
else:
    # Auto path with VRAM check
    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)

# The tokenizer is loaded separately for token accounting in responses
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"


fallback_done = False  # prevent an endless downgrade loop (single CPU fallback)


def _encode(texts: List[str]):
    """Encode with a single downgrade to CPU on OOM / CUDA failure."""
    global model, DEVICE, PRECISION, fallback_done

    try:
        return model.encode(texts, return_dense=True)

    except RuntimeError as err:
        is_oom = "out of memory" in str(err).lower()
        is_cuda_fail = "cuda error" in str(err).lower() or "device-side assert" in str(
            err
        ).lower()

        if (is_oom or is_cuda_fail) and not fallback_done:
            logger.error("GPU failure (%s). Falling back to CPU…", err)
            fallback_done = True
            torch.cuda.empty_cache()
            DEVICE = "cpu"
            model, PRECISION = load_model(DEVICE)
            return model.encode(texts, return_dense=True)

        raise  # second failure → propagate


@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalise the input to a list of strings
    texts = [request.input] if isinstance(request.input, str) else request.input

    # Token stats: the number of 1s in attention_mask is the real token count
    # (padding excluded); embeddings add no completion tokens.
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    )
    prompt_tokens = int(enc["attention_mask"].sum().item())

    # Dense vectors (with one-shot CPU fallback inside _encode)
    try:
        output = _encode(texts)
        embeddings = output["dense_vecs"]
    except Exception as e:
        logger.exception("Embedding failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
            }
            for i, emb in enumerate(embeddings)
        ],
        "model": request.model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,
        },
        "device": DEVICE,
        "precision": PRECISION,
    }


# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        log_level="info",
        workers=1,  # multi-process → use gunicorn/uvicorn externally
    )
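For reference, a minimal client call against the endpoint added above could look like the sketch below. It assumes the server from this commit is running locally on port 8000 and uses the third-party `requests` package, which is not part of this repo's requirements.

# Client sketch (assumption: server at http://localhost:8000 and `requests`
# installed; neither is guaranteed by this commit).
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"input": ["hello world", "你好"], "model": "text-embedding-bge-m3"},
    timeout=30,
)
resp.raise_for_status()
body = resp.json()

# The response mirrors the OpenAI embeddings shape plus "device"/"precision".
print(body["device"], body["precision"], body["usage"]["prompt_tokens"])
print(len(body["data"]), len(body["data"][0]["embedding"]))  # 2 vectors, 1024-dim for bge-m3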
@@ -0,0 +1,67 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Union
import torch
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# Auto-detect the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Enable fp16 on GPU for speed, otherwise use fp32
USE_FP16 = DEVICE == "cuda"

# Load the model and tokenizer
MODEL_PATH = "model/bge-m3"
model = BGEM3FlagModel(MODEL_PATH, use_fp16=USE_FP16, device=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

app = FastAPI()

class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"

@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalise the input to a list of strings
    texts = [request.input] if isinstance(request.input, str) else request.input

    try:
        # Count tokens with the tokenizer first
        encoding = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors="pt"
        )
        # The number of 1s in attention_mask is the real token count (padding excluded)
        mask = encoding["attention_mask"]
        # prompt_tokens = sum over all input tokens
        prompt_tokens = int(mask.sum().item())
        total_tokens = prompt_tokens  # embeddings add no extra tokens

        # Generate dense vectors
        output = model.encode(texts, return_dense=True)
        embeddings = output["dense_vecs"]

        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.tolist() if hasattr(emb, "tolist") else emb
                }
                for idx, emb in enumerate(embeddings)
            ],
            "model": request.model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens
            }
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@@ -4,3 +4,4 @@ torch
transformers
datasets
peft
pynvml
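pynvml is the new dependency behind the NVML-based VRAM probe in `_gpu_mem_info()`. A quick sanity check of the calls the server relies on might look like this sketch; it assumes an NVIDIA driver and at least one visible GPU (otherwise the server falls back to torch.cuda.mem_get_info()).

# Sanity-check sketch for the NVML probe (assumes an NVIDIA driver and GPU 0 exist).
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
print(f"GPU 0: {mem.free // 2**20} MB free of {mem.total // 2**20} MB")
pynvml.nvmlShutdown()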