#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.

Launch examples:
    python server.py                 # auto-select a GPU, fall back to CPU
    python server.py --device 1      # pin to GPU index 1
    CUDA_VISIBLE_DEVICES=0,1 python server.py
"""
import argparse
import logging
import os
from threading import Lock
from typing import List, NoReturn, Union

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
MODEL_PATH = "model/bge-m3"  # point this at your local weights
# Minimum free VRAM (MB) a GPU must report before a load is attempted on it
# (default budgets roughly 8 GiB).
MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000"))
# Free VRAM (MB) that must still remain after loading for a GPU to be accepted.
POST_LOAD_GAP_MB = 192
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # 8192 MB with the defaults

# Per-request limits guarding against oversized payloads:
# maximum number of texts, and maximum characters per text.
MAX_BATCH = int(os.getenv("MAX_BATCH", "1024"))
MAX_TEXT_LEN = int(os.getenv("MAX_TEXT_LEN", "200000"))
# batch_size passed to model.encode(); clamped to at least 1.
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "32")))

# Number of GPU failures after which a model reload is attempted (self-healing).
GPU_MAX_CONSEC_FAILS = int(os.getenv("GPU_MAX_CONSEC_FAILS", "3"))

# Lazily-created CPU fallback model.
CPU_MODEL_CACHE = None

# GPU failure counter; used only to decide when to trigger a reload.
_GPU_FAIL_COUNT = 0
# Serialises model reloads.
_reload_lock = Lock()
# Flipped to True once warm-up succeeds.
_READY = False


# Request body for the embeddings endpoint: a single string or a list of strings.
class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"


# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)


def _http_error(status: int, code: str, message: str) -> NoReturn:
    """Raise an HTTPException whose detail is ``{"code": ..., "message": ...}``.

    A uniform error body makes failures easy to recognise in logs and clients.
    Annotated ``NoReturn`` so type checkers know code after a call is
    unreachable (e.g. variables assigned only on the success path).

    Raises:
        HTTPException: always, with ``status_code=status``.
    """
    raise HTTPException(status_code=status, detail={"code": code, "message": message})
# -----------------------------------------------------------------------------#
# GPU memory helpers (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
    import pynvml

    pynvml.nvmlInit()
    _USE_NVML = True
except Exception:
    _USE_NVML = False
    logger.warning("pynvml 不可用,将使用 torch.cuda.mem_get_info() 探测显存")

# CUDA_VISIBLE_DEVICES as the process was launched.  Captured at import
# because load_model("cpu") rewrites the variable later on.
_CUDA_VISIBLE_AT_START = os.environ.get("CUDA_VISIBLE_DEVICES")


def _nvml_index(idx: int):
    """Translate torch ordinal *idx* into an NVML device index, or return None.

    NVML numbers every physical GPU, whereas torch ordinals are positions within
    CUDA_VISIBLE_DEVICES.  With the variable unset the two coincide; with a
    numeric list (e.g. "2,3") entry *idx* gives the NVML index.  Any other form
    (UUIDs, MIG ids) or an out-of-range *idx* yields None, so callers fall back
    to torch's own query.

    NOTE(review): CUDA's default CUDA_DEVICE_ORDER=FASTEST_FIRST can order mixed
    GPU models differently from NVML's PCI order; set CUDA_DEVICE_ORDER=PCI_BUS_ID
    on heterogeneous machines.
    """
    if not _USE_NVML:
        return None
    if _CUDA_VISIBLE_AT_START is None:
        return idx
    try:
        return int(_CUDA_VISIBLE_AT_START.split(",")[idx])
    except (ValueError, IndexError):
        return None


def _gpu_mem_info(idx: int) -> tuple[int, int]:
    """Return ``(free_bytes, total_bytes)`` for the GPU with torch ordinal *idx*."""
    nvml_idx = _nvml_index(idx)
    if nvml_idx is not None:
        handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem.free, mem.total
    # Pass the index explicitly instead of torch.cuda.set_device(idx), which
    # would also switch the process-wide "current device" as a side effect.
    free, total = torch.cuda.mem_get_info(idx)
    return free, total


# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision_by_idx(idx: int) -> str:
    """Return "fp16" for compute capability >= 7 (Volta and newer), else "fp32".

    If the capability cannot be queried, assume "fp16" when CUDA is available.
    """
    try:
        major, _ = torch.cuda.get_device_capability(idx)
    except Exception:
        return "fp16" if torch.cuda.is_available() else "fp32"
    return "fp16" if major >= 7 else "fp32"


def load_model(device: str):
    """Load BGE-M3 on *device* and return ``(model, precision)``.

    Args:
        device: "cpu" or "cuda:<idx>", where idx is a torch ordinal, i.e. a
            position among the GPUs visible to this process (the same
            numbering as torch.cuda.device_count()).

    The model is placed directly on "cuda:<idx>".  Re-masking the process with
    CUDA_VISIBLE_DEVICES and loading on "cuda:0" does not work here: CUDA reads
    that variable only when it initialises, and it is already initialised by
    now (e.g. by get_device_capability below), so "cuda:0" would always be the
    first visible GPU whatever idx was requested.
    """
    if device == "cpu":
        # Hide GPUs from this process.  This only has an effect if CUDA has not
        # been initialised yet (e.g. an explicit CPU choice at startup).
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        logger.info("Loading BGEM3 on cpu (fp32)")
        mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=False, device="cpu")
        return mdl, "fp32"

    idx = int(device.split(":")[1])
    precision = _choose_precision_by_idx(idx)
    logger.info("Loading BGEM3 on %s (%s)", device, precision)
    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=(precision == "fp16"), device=f"cuda:{idx}")
    return mdl, precision


# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load() -> tuple:
    """Choose a device at startup, load the model there, and return it.

    1. Skip GPUs with less than MODEL_VRAM_MB free (or whose memory cannot be probed).
    2. Try the remaining GPUs in descending order of free memory.
    3. After loading, reject a GPU left with less than POST_LOAD_GAP_MB free.
    4. Fall back to CPU when no GPU qualifies.

    Returns:
        ``(model, precision, device)`` with device "cpu" or "cuda:<idx>".
    """
    import gc

    if not torch.cuda.is_available():
        logger.info("No GPU detected → CPU")
        mdl, prec = load_model("cpu")
        return mdl, prec, "cpu"

    candidates = []  # (free_MB, torch ordinal)
    for idx in range(torch.cuda.device_count()):
        try:
            free_mb = _gpu_mem_info(idx)[0] // 2**20
        except Exception as e:  # one broken probe should not abort startup
            logger.warning("cuda:%d memory probe failed (%s) → skip", idx, e)
            continue
        if free_mb >= MODEL_VRAM_MB:
            candidates.append((free_mb, idx))

    if not candidates:
        logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
        mdl, prec = load_model("cpu")
        return mdl, prec, "cpu"

    mdl = None
    for free_mb, idx in sorted(candidates, reverse=True):
        dev = f"cuda:{idx}"
        try:
            logger.info("Trying %s (free=%d MB)", dev, free_mb)
            mdl, prec = load_model(dev)
            # The model lives on `idx` itself, so check that same card.
            remain_mb = _gpu_mem_info(idx)[0] // 2**20
            if remain_mb < POST_LOAD_GAP_MB:
                raise RuntimeError(f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
            return mdl, prec, dev
        except RuntimeError as e:
            logger.warning("%s unusable (%s) → next", dev, e)
        # Only reached on failure.  Drop the rejected model before emptying the
        # allocator cache; while still referenced, its weights stay reserved on
        # a GPU this process will not use.  Done after the except clause so the
        # exception's traceback no longer pins any frames either.
        mdl = None
        gc.collect()
        torch.cuda.empty_cache()

    logger.warning("No suitable GPU left → CPU fallback")
    mdl, prec = load_model("cpu")
    return mdl, prec, "cpu"


# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
FORCE_DEVICE = os.getenv("FORCE_DEVICE")  # "cpu" or a GPU index such as "0" / "1"
parser = argparse.ArgumentParser()
parser.add_argument("--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection")
args, _ = parser.parse_known_args()


def _resolve_device(spec: str) -> str:
    """Translate a user device spec ("cpu" or a GPU index) into a device string.

    An index resolves to "cpu" when CUDA is unavailable.

    Raises:
        ValueError: *spec* is neither "cpu" nor an integer, or the index is not
            one of the GPUs visible to this process.
    """
    if spec.strip().lower() == "cpu":
        return "cpu"
    idx = int(spec)
    if not torch.cuda.is_available():
        return "cpu"
    count = torch.cuda.device_count()
    if not 0 <= idx < count:
        raise ValueError(f"GPU index {idx} out of range: {count} visible device(s)")
    return f"cuda:{idx}"


# FORCE_DEVICE (env) takes precedence over --device; neither → auto-select.
_device_spec = FORCE_DEVICE if FORCE_DEVICE is not None else args.device
if _device_spec is not None:
    DEVICE = _resolve_device(_device_spec)
    model, PRECISION = load_model(DEVICE)
else:
    model, PRECISION, DEVICE = auto_select_and_load()

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
logger.info("Using SAFE_MIN_FREE_MB = %d MB, BATCH_SIZE = %d", SAFE_MIN_FREE_MB, BATCH_SIZE)


def _encode_chunked(model_obj, texts: List[str], chunk: int):
    """Encode *texts* in slices of *chunk* without passing ``batch_size``.

    Used with encode() implementations that reject the batch_size keyword.
    Returns ``{"dense_vecs": [...]}``, one vector per input text.
    """
    step = max(1, int(chunk))  # guard against 0 / negative chunk sizes
    dense = []
    for start in range(0, len(texts), step):
        out = model_obj.encode(texts[start:start + step], return_dense=True)
        dense.extend(out["dense_vecs"])
    return {"dense_vecs": dense}


def _encode_batched(model_obj, texts: List[str], chunk: int = BATCH_SIZE):
    """Call ``encode(..., batch_size=chunk)``; on TypeError (keyword not
    supported by this FlagEmbedding version) fall back to manual chunking."""
    try:
        return model_obj.encode(texts, return_dense=True, batch_size=chunk)
    except TypeError:
        return _encode_chunked(model_obj, texts, chunk)


# NOTE(review): on_event is deprecated in recent FastAPI in favour of lifespan
# handlers; it still works.
@app.on_event("startup")
def _warmup():
    """Run one tiny encode at startup; /ready reports 503 until this succeeds."""
    global _READY
    try:
        _encode_batched(model, ["warmup sentence"])
    except Exception as e:
        _READY = False
        logger.warning("Warm-up failed: %s", e)
    else:
        _READY = True
        logger.info("Warm-up complete.")


@app.get("/ready")
def ready():
    """Readiness probe: 503 until warm-up has succeeded, then 200."""
    if not _READY:
        # 503 makes orchestrators / HEALTHCHECK keep waiting.
        raise HTTPException(status_code=503, detail={"code": "not_ready", "message": "warming up"})
    return {"status": "ready"}


# ------------- GPU self-healing: reload after consecutive failures ------------- #
def _maybe_reload_gpu_model():
    """Reload the GPU model once the failure count reaches GPU_MAX_CONSEC_FAILS.

    Best-effort: if the reload fails, the current model is kept and requests
    keep being served by the CPU fallback in `_encode`.

    NOTE(review): the old model stays referenced (by `model` and any in-flight
    requests) while the new one loads, so a reload briefly needs room for two
    copies on the GPU.
    """
    global _GPU_FAIL_COUNT, model, PRECISION
    if not str(DEVICE).startswith("cuda") or _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
        return
    with _reload_lock:
        # Re-check under the lock so concurrent failures trigger a single reload.
        if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
            return
        try:
            logger.warning("GPU unhealthy (fail_count=%d) → reloading on %s", _GPU_FAIL_COUNT, DEVICE)
            model, PRECISION = load_model(DEVICE)
            _GPU_FAIL_COUNT = 0
            logger.info("GPU model reloaded on %s", DEVICE)
        except Exception as e:
            logger.exception("GPU reload failed: %s", e)
            # Keep serving through the CPU fallback path.


def _encode(texts: List[str]):
    """Encode *texts*; return ``(output_dict, served_device)``.

    Uses the primary model.  When that model is on a GPU and fails with an
    OOM / CUDA error, the request is retried on a lazily-created CPU model, the
    failure counter is bumped, and a reload may be attempted.  A successful GPU
    encode resets the counter, so only *consecutive* failures count toward
    GPU_MAX_CONSEC_FAILS.  Errors on a CPU-served model are re-raised; falling
    back to a second CPU copy would not help.
    """
    global CPU_MODEL_CACHE, _GPU_FAIL_COUNT
    on_gpu = str(DEVICE).startswith("cuda")
    gpu_error = None
    try:
        out = _encode_batched(model, texts)
    except RuntimeError as e:
        emsg = str(e).lower()
        recoverable = "out of memory" in emsg or "cuda error" in emsg
        if not (on_gpu and recoverable):
            raise
        gpu_error = str(e)

    if gpu_error is None:
        if on_gpu:
            _GPU_FAIL_COUNT = 0  # a success ends the run of consecutive failures
        return out, DEVICE

    # Recover outside the except clause: the live exception's traceback keeps
    # the failed forward pass's tensors alive, so empty_cache() inside it could
    # not release that memory.
    logger.warning("GPU OOM/CUDA error → fallback to CPU: %s", gpu_error)
    torch.cuda.empty_cache()
    if CPU_MODEL_CACHE is None:
        CPU_MODEL_CACHE, _ = load_model("cpu")
    _GPU_FAIL_COUNT += 1
    _maybe_reload_gpu_model()  # attempts a reload once the threshold is reached
    return _encode_batched(CPU_MODEL_CACHE, texts), "cpu"


# ========== Input normalisation (reject stack-crashing inputs at the door) ========== #
def _normalize_inputs(raw) -> List[str]:
    """Validate ``request.input`` and return it as a list of strings.

    - str: wrapped as ``[str]``; "" is allowed.
    - list: must be non-empty and contain only str; items are neither stripped
      nor dropped.
    - anything else: 400.
    - more than MAX_BATCH items, or an item longer than MAX_TEXT_LEN
      characters: 413.

    Raises:
        HTTPException: via `_http_error`, with the status codes listed above.
    """
    if isinstance(raw, str):
        texts = [raw]
    elif isinstance(raw, list):
        if not raw:
            # An empty list used to make the HF tokenizer raise IndexError.
            _http_error(400, "empty_list", "`input` is an empty list.")
        for i, x in enumerate(raw):
            if not isinstance(x, str):
                _http_error(400, "non_string_item", f"`input[{i}]` must be a string.")
        texts = raw
    else:
        _http_error(400, "invalid_type", "`input` must be a string or list of strings.")

    if len(texts) > MAX_BATCH:
        _http_error(413, "batch_too_large", f"batch size {len(texts)} > MAX_BATCH={MAX_BATCH}.")
    for i, t in enumerate(texts):
        if len(t) > MAX_TEXT_LEN:
            _http_error(413, "text_too_long", f"`input[{i}]` exceeds MAX_TEXT_LEN={MAX_TEXT_LEN}.")
    return texts


# ------------------------------- Main route --------------------------------------#
@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    """Embeddings endpoint.

    Returns an OpenAI-style list of dense embeddings, plus the device that
    actually served the request and its precision.
    """
    # 1) Validate inputs up front ("" remains allowed).
    texts = _normalize_inputs(request.input)

    # 2) Token accounting is best-effort: a failure only reports 0 tokens.
    try:
        # No padding: summing per-text (truncated) lengths gives the same count
        # as the padded attention_mask sum, without allocating a
        # len(texts) x longest-text tensor.
        enc = tokenizer(texts, truncation=True, max_length=8192)
        prompt_tokens = sum(len(ids) for ids in enc["input_ids"])
    except IndexError:
        # Belt and braces; empty inputs are already rejected above.
        _http_error(400, "empty_after_tokenize", "`input` contains no encodable texts.")
    except Exception as e:
        logger.warning("token counting failed: %s", e)
        prompt_tokens = 0

    # 3) Encode (GPU with CPU fallback; drives the self-healing counter).
    try:
        output, served_device = _encode(texts)
        embeddings = output["dense_vecs"]
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Embedding failed")
        _http_error(500, "embedding_failed", str(e))

    # 4) Response shape is unchanged.
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
            }
            for i, emb in enumerate(embeddings)
        ],
        "model": request.model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,
        },
        "device": served_device,  # the device that actually served this request
        "precision": "fp32" if served_device == "cpu" else PRECISION,
    }


# --------------------------- Health check (optional) --------------------------------#
@app.get("/health")
def healthz():
    """Liveness / diagnostics: always 200 with the current serving state."""
    return {
        "status": "ok",
        "device": DEVICE,
        "precision": PRECISION,
        "gpu_fail_count": _GPU_FAIL_COUNT,
        "batch_size": BATCH_SIZE,
        "ready": _READY,
    }
# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built `app` object, not the import string "server:app".
    # With the string, uvicorn imports this file a second time as module
    # "server", re-running every module-level statement (including the model
    # load) while the copy held by __main__ stays in memory.  The string form
    # would also break if the file were not named server.py.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        log_level="info",
        workers=1,  # single process; for multiple processes run uvicorn/gunicorn externally
    )