#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.

Launch examples:
    python server.py                          # auto-select GPU / auto-fallback
    python server.py --device 1               # pin to GPU 1
    CUDA_VISIBLE_DEVICES=0,1 python server.py
"""
import argparse
import logging
import os
import sys
import time
from typing import List, Union
import multiprocessing as mp

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # change to your checkpoint path as needed
SAFE_MIN_FREE_MB = int(
    os.getenv("SAFE_MIN_FREE_MB", "16384")
)  # minimum free VRAM (MB) required at startup

# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
    import pynvml

    pynvml.nvmlInit()
    _USE_NVML = True
except Exception:
    _USE_NVML = False
    logger.warning(
        "pynvml unavailable; falling back to torch.cuda.mem_get_info() for VRAM probing"
    )


def _gpu_mem_info(idx: int) -> tuple[int, int]:
    """Return (free_bytes, total_bytes) for GPU `idx`."""
    if _USE_NVML:
        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem.free, mem.total
    # torch fallback
    torch.cuda.set_device(idx)
    free, total = torch.cuda.mem_get_info(idx)
    return free, total


# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision(device: str) -> str:
    """Return 'fp16' | 'fp32'.

    Note: bf16 would require compute capability >= 8, and BGEM3FlagModel only
    exposes a use_fp16 switch, so fp16 is used on any card with tensor cores.
    """
    if not device.startswith("cuda"):
        return "fp32"
    major, _ = torch.cuda.get_device_capability(device)
    if major >= 7:  # Volta / Turing / Ampere / Hopper
        return "fp16"
    return "fp32"


def load_model(device: str):
    """Instantiate the model on `device`; return (model, precision)."""
    precision = _choose_precision(device)
    use_fp16 = precision == "fp16"
    logger.info("Loading BGEM3 on %s (%s)", device, precision)
    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
    if device == "cpu":
        # Hide the GPUs so that torch / BGEM3 in subsequently spawned processes
        # cannot see CUDA (torch in this process is already initialized).
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # Simple DataParallel for multi-GPU inference
    if device.startswith("cuda") and torch.cuda.device_count() > 1:
        logger.info(
            "Wrapping model with torch.nn.DataParallel (%d GPUs)",
            torch.cuda.device_count(),
        )
        mdl = torch.nn.DataParallel(mdl)
    return mdl, precision
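
# Optional, illustrative helper (an addition for debugging; it is not called
# anywhere in this file): logs free/total VRAM of every visible GPU via the
# same _gpu_mem_info() probe used by the auto-selection below. Handy when
# tuning SAFE_MIN_FREE_MB.
def _log_gpu_inventory() -> None:
    """Log free/total memory for each visible GPU (no-op without CUDA)."""
    if not torch.cuda.is_available():
        logger.info("GPU inventory: no CUDA devices visible")
        return
    for idx in range(torch.cuda.device_count()):
        free, total = _gpu_mem_info(idx)
        logger.info(
            "GPU %d: %d MB free / %d MB total", idx, free // 2**20, total // 2**20
        )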

# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load(min_free_mb: int = 4096):
    """
    1. Gather GPUs with free_mem ≥ min_free_mb
    2. Sort by free_mem desc, attempt to load
    3. If all fail → CPU

    Return (model, precision, device)
    """
    if not torch.cuda.is_available():
        logger.info("No GPU detected → CPU")
        return (*load_model("cpu"), "cpu")

    # Build candidate list
    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
    for idx in range(torch.cuda.device_count()):
        free, _ = _gpu_mem_info(idx)
        free_mb = free // 2**20
        candidates.append((free_mb, idx))

    candidates = [c for c in candidates if c[0] >= min_free_mb]
    if not candidates:
        logger.warning("All GPUs free_mem < %d MB → CPU", min_free_mb)
        return (*load_model("cpu"), "cpu")

    candidates.sort(reverse=True)  # highest free_mem first
    for free_mb, idx in candidates:
        dev = f"cuda:{idx}"
        try:
            logger.info("Trying %s (free=%d MB)", dev, free_mb)
            mdl, prec = load_model(dev)
            return mdl, prec, dev
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                logger.warning("%s OOM → next GPU", dev)
                torch.cuda.empty_cache()
                continue
            raise  # non-OOM error

    logger.warning("All GPUs failed → CPU fallback")
    return (*load_model("cpu"), "cpu")


# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
)
args, _ = parser.parse_known_args()

if args.device is not None:  # Forced path
    if args.device.lower() == "cpu":
        DEVICE = "cpu"
    else:
        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
    model, PRECISION = load_model(DEVICE)
else:  # Auto path with VRAM check
    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)


@app.on_event("startup")
async def warm_up_mp_pool():
    try:
        if DEVICE.startswith("cuda"):
            logger.info("Warm-up (GPU) → pre-spawning the multi-process pool")
            _ = model.encode(["warmup"], return_dense=True)
        else:
            logger.info("Warm-up (CPU) → single-process initialization")
            if hasattr(model, "devices"):
                model.devices = ["cpu"]  # fully hide the GPU
                model.device = "cpu"
            _ = model.encode(["warmup"], return_dense=True)  # num_processes removed
    except Exception as e:
        logger.warning("Warm-up failed: %s; will back off again on the first request", e)


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"
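
# Illustrative request shapes only (hypothetical values): `input` accepts either
# a single string or a list of strings; create_embedding() below normalizes a
# lone string into a one-element list before encoding.
#   EmbeddingRequest(input="hello world")
#   EmbeddingRequest(input=["first sentence", "second sentence"])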

def _encode(texts: List[str]):
    """
    Per-request flow:
    1. Run inference in a worker subprocess; on success, return the result.
    2. If the subprocess hits OOM / a CUDA error, fall back to CPU for this
       request only. Global state is never mutated, so other concurrent
       requests are unaffected.
    """

    def _worker(t, q):
        try:
            # Same call on GPU and CPU; num_processes is deliberately not passed.
            out = model.encode(t, return_dense=True)
            q.put(("ok", out))
        except Exception as e:
            q.put(("err", str(e)))

    q = mp.Queue()
    p = mp.Process(target=_worker, args=(texts, q))
    p.start()
    p.join(timeout=60)
    if p.is_alive():  # worker hung past the timeout → clean it up
        p.terminate()
        p.join()

    if not q.empty():
        status, payload = q.get()
        if status == "ok":
            return payload
        if "out of memory" in payload.lower() or "cuda error" in payload.lower():
            logger.warning("GPU OOM → falling back to CPU for this request: %s", payload)
            torch.cuda.empty_cache()
            cpu_model, _ = load_model("cpu")
            return cpu_model.encode(texts, return_dense=True)
        raise RuntimeError(payload)
    raise RuntimeError("Worker subprocess exited without returning a result")


# fallback_done = False  # prevent endless downgrade loop
# def _encode(texts: List[str]):
#     """Encode with a single downgrade to CPU on OOM / CUDA failure."""
#     global model, DEVICE, PRECISION, fallback_done
#     try:
#         return model.encode(texts, return_dense=True)
#     except RuntimeError as err:
#         is_oom = "out of memory" in str(err).lower()
#         is_cuda_fail = "cuda error" in str(err).lower() or "device-side assert" in str(
#             err
#         ).lower()
#         if (is_oom or is_cuda_fail) and not fallback_done:
#             logger.error("GPU failure (%s). Falling back to CPU…", err)
#             fallback_done = True
#             torch.cuda.empty_cache()
#             DEVICE = "cpu"
#             model, PRECISION = load_model(DEVICE)
#             return model.encode(texts, return_dense=True)
#         raise  # second failure → propagate


@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    texts = [request.input] if isinstance(request.input, str) else request.input

    # Token stats
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    )
    prompt_tokens = int(enc["attention_mask"].sum().item())

    try:
        output = _encode(texts)
        embeddings = output["dense_vecs"]
    except Exception as e:
        logger.exception("Embedding failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
            }
            for i, emb in enumerate(embeddings)
        ],
        "model": request.model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,
        },
        "device": DEVICE,
        "precision": PRECISION,
    }


# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        log_level="info",
        workers=1,  # for multi-process serving, run gunicorn/uvicorn externally
    )
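
# Illustrative request against a local instance (example values only; adjust
# host/port to your deployment):
#
#   curl -s http://localhost:8000/v1/embeddings \
#        -H "Content-Type: application/json" \
#        -d '{"input": ["hello world"], "model": "text-embedding-bge-m3"}'
#
# The response follows the OpenAI embeddings schema ("object", "data", "usage")
# and additionally reports "device" and "precision", so callers can see whether
# the vectors were computed on GPU or CPU and at which precision.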