#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BGEM3 inference server (FastAPI) with robust GPU / CPU management.

Launch examples:
    python server.py                        # auto-select a GPU / fall back to CPU
    python server.py --device 1             # pin the server to GPU index 1
    CUDA_VISIBLE_DEVICES=0,1 python server.py
"""
import argparse
import logging
import os
from threading import Lock
from typing import List, NoReturn, Union

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# -----------------------------------------------------------------------------#
# Config
# -----------------------------------------------------------------------------#
# Path to the BGE-M3 weights; override with the MODEL_PATH environment variable.
MODEL_PATH = os.getenv("MODEL_PATH", "model/bge-m3")
# Estimated VRAM (MB) needed to hold the model. The author's rough figures are
# ~2.4 GiB in fp16 and ~4.8 GiB in fp32; the default uses the fp32 estimate.
MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "4800"))
# Free VRAM (MB) that must still remain on a GPU after the model has loaded.
POST_LOAD_GAP_MB = 200
# Total free VRAM (MB) considered safe (5000 MB with the defaults above).
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB

# Upper bounds on request batch size and per-item text length, to defend
# against abnormally large payloads.
MAX_BATCH = int(os.getenv("MAX_BATCH", "1024"))
MAX_TEXT_LEN = int(os.getenv("MAX_TEXT_LEN", "200000"))
# batch_size passed to model.encode(); clamped to at least 1.
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "32")))

# Number of consecutive GPU failures after which a model reload (self-heal)
# is attempted.
GPU_MAX_CONSEC_FAILS = int(os.getenv("GPU_MAX_CONSEC_FAILS", "3"))

# Lazily created CPU fallback model (None until first needed).
CPU_MODEL_CACHE = None

# GPU failure counter; only used to decide when to reload, it does not affect
# the normal request path.
_GPU_FAIL_COUNT = 0
_reload_lock = Lock()
# Readiness flag reported by the /ready and /health endpoints.
_READY = False


class EmbeddingRequest(BaseModel):
    """Request body for POST /v1/embeddings."""

    # A single text or a list of texts to embed.
    input: Union[str, List[str]]
    # Model name; echoed back in the response.
    model: str = "text-embedding-bge-m3"


# -----------------------------------------------------------------------------#
# Logging
# -----------------------------------------------------------------------------#
logger = logging.getLogger("bge-m3-server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)


def _http_error(status: int, code: str, message: str) -> NoReturn:
    """Raise an HTTPException whose detail is ``{"code": ..., "message": ...}``.

    A uniform error body makes failures easy to recognise in logs and clients.
    This function never returns.

    Args:
        status: HTTP status code.
        code: short machine-readable error code.
        message: human-readable description.

    Raises:
        HTTPException: always.
    """
    raise HTTPException(status_code=status, detail={"code": code, "message": message})


# -----------------------------------------------------------------------------#
# GPU memory helpers
# (NVML → torch fallback)
# -----------------------------------------------------------------------------#
try:
    import pynvml

    pynvml.nvmlInit()
    _USE_NVML = True
except Exception:
    _USE_NVML = False
    logger.warning("pynvml 不可用,将使用 torch.cuda.mem_get_info() 探测显存")


def _gpu_mem_info(idx: int) -> tuple[int, int]:
    """Return ``(free_bytes, total_bytes)`` for GPU ``idx``.

    With NVML, ``idx`` is the NVML (global) device index. On the torch
    fallback, ``idx`` is an in-process CUDA ordinal, so it is subject to any
    ``CUDA_VISIBLE_DEVICES`` masking already applied to this process.
    """
    if _USE_NVML:
        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem.free, mem.total
    # mem_get_info takes the device directly; there is no need to change the
    # process-wide current device as a side effect.
    free, total = torch.cuda.mem_get_info(idx)
    return free, total


# -----------------------------------------------------------------------------#
# Precision & model loader
# -----------------------------------------------------------------------------#
def _choose_precision_by_idx(idx: int) -> str:
    """Return ``"fp16"`` or ``"fp32"`` for in-process CUDA device ``idx``.

    Compute capability 7.x and newer (Volta/Turing/Ampere/Hopper) gets fp16.
    Older devices, and any failure to query the device, get fp32.

    ``idx`` is an in-process CUDA ordinal (after ``CUDA_VISIBLE_DEVICES``
    masking), not a global/NVML index.
    """
    try:
        major, _minor = torch.cuda.get_device_capability(idx)
    except Exception:
        return "fp32"
    return "fp16" if major >= 7 else "fp32"


def load_model(device: str):
    """Load BGE-M3 on ``device`` and return ``(model, precision)``.

    Args:
        device: ``"cpu"`` or ``"cuda:<N>"``. For CUDA, ``<N>`` is only a label.
            Callers mask the process to the target card with
            ``CUDA_VISIBLE_DEVICES`` before the first CUDA call. That card is
            therefore always in-process ``cuda:0``, which avoids
            "invalid device ordinal" errors.

    Returns:
        ``(model, precision)``; precision is ``"fp16"`` or ``"fp32"``.

    Raises:
        ValueError: if ``device`` is neither ``"cpu"`` nor ``"cuda:<int>"``.
    """
    if device == "cpu":
        # Hiding GPUs only matters before CUDA is initialised. Afterwards (e.g.
        # a runtime CPU fallback) it has no effect in-process and would only
        # leak into child processes.
        if not torch.cuda.is_initialized():
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
        logger.info("Loading BGEM3 on cpu (fp32)")
        mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=False, device="cpu")
        return mdl, "fp32"

    prefix, _, ordinal = device.partition(":")
    if prefix != "cuda" or not ordinal.isdigit():
        raise ValueError(f"unsupported device {device!r}; expected 'cpu' or 'cuda:<index>'")

    # The masked card is in-process ordinal 0. Querying the global index here
    # would fail for any card other than global 0 and silently yield fp32.
    mapped = "cuda:0"
    precision = _choose_precision_by_idx(0)
    logger.info("Loading BGEM3 on %s (mapped=%s, %s)", device, mapped, precision)
    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=(precision == "fp16"), device=mapped)
    return mdl, precision


# -----------------------------------------------------------------------------#
# Auto-select device (startup)
# -----------------------------------------------------------------------------#
def auto_select_and_load() -> tuple:
    """Pick a device at startup, load the model and return
    ``(model, precision, device_label)``.

    Selection rules:
      * No NVML: blindly try the first visible GPU (masked to card 0 unless
        ``CUDA_VISIBLE_DEVICES`` is already non-empty); on failure use the CPU.
      * NVML: skip GPUs with free memory below ``MODEL_VRAM_MB`` and try the
        rest from most to least free memory. After loading, the card must
        still have at least ``POST_LOAD_GAP_MB`` free, otherwise it is rejected.
      * Nothing usable: CPU.

    ``CUDA_VISIBLE_DEVICES`` only takes effect before CUDA is initialised, and
    a load attempt initialises CUDA. If a rejected attempt has initialised
    CUDA, the process cannot be re-masked onto another card. In that case no
    further GPUs are tried and the CPU is used instead.

    ``device_label`` is ``"cpu"`` or ``"cuda:<index>"``: the NVML index when
    NVML is available, ``0`` otherwise. In-process, the model always lives on
    ``cuda:0``.
    """
    # 1) No NVML: free memory cannot be checked, so blindly try the first
    #    visible GPU (masking it first), and fall back to CPU on failure.
    if not _USE_NVML:
        if not os.environ.get("CUDA_VISIBLE_DEVICES"):
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        try:
            mdl, prec = load_model("cuda:0")  # the process only sees card 0
            return mdl, prec, "cuda:0"
        except Exception as e:
            logger.warning("No NVML or CUDA unusable (%s) → CPU fallback", e)
            mdl, prec = load_model("cpu")
            return mdl, prec, "cpu"

    # 2) NVML available: rank GPUs by free memory without touching torch.cuda.
    try:
        gpu_count = pynvml.nvmlDeviceGetCount()
    except Exception as e:
        logger.warning("NVML getCount failed (%s) → CPU", e)
        mdl, prec = load_model("cpu")
        return mdl, prec, "cpu"

    candidates = []
    for idx in range(gpu_count):
        try:
            free_mb = _gpu_mem_info(idx)[0] // 2**20  # NVML path
        except Exception as e:
            logger.warning("NVML query gpu %d failed: %s", idx, e)
            continue
        if free_mb >= MODEL_VRAM_MB:
            candidates.append((free_mb, idx))

    if not candidates:
        logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
        mdl, prec = load_model("cpu")
        return mdl, prec, "cpu"

    # NVML enumerates GPUs in PCI bus order, while CUDA defaults to "fastest
    # first". Align CUDA with NVML so that masking by NVML index selects the
    # card that was measured. An explicit user setting wins.
    os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

    # 3) Try candidates from most to least free memory. Each card is masked
    #    *before* anything touches torch.cuda.
    for attempt, (_free_mb, idx) in enumerate(sorted(candidates, reverse=True)):
        if attempt and torch.cuda.is_initialized():
            # A previous attempt already initialised CUDA on another card.
            # Re-masking would be ignored and cuda:0 would still be that card,
            # so stop instead of loading onto the wrong GPU under a wrong label.
            logger.warning("CUDA already initialised; cannot switch to GPU %d → CPU fallback", idx)
            break
        os.environ["CUDA_VISIBLE_DEVICES"] = str(idx)
        mdl = None
        try:
            mdl, prec = load_model("cuda:0")  # in-process ordinal of the masked card
            # Re-check the remaining free memory via NVML (global index).
            remain_mb = _gpu_mem_info(idx)[0] // 2**20
            if remain_mb < POST_LOAD_GAP_MB:
                raise RuntimeError(f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
            return mdl, prec, f"cuda:{idx}"  # external label uses the global index
        except Exception as e:
            logger.warning("GPU %d unusable (%s) → next", idx, e)
            # Drop the rejected model and release its cached blocks. empty_cache()
            # is a no-op if CUDA was never initialised, and only the masked card
            # is visible, so no other device is touched.
            del mdl
            torch.cuda.empty_cache()

    # 4) Nothing usable → CPU.
    logger.warning("No suitable GPU left → CPU fallback")
    mdl, prec = load_model("cpu")
    return mdl, prec, "cpu"


# -----------------------------------------------------------------------------#
# CLI
# -----------------------------------------------------------------------------#
def _load_forced_device(spec: str):
    """Load the model on an explicitly requested device.

    Args:
        spec: ``"cpu"`` (case-insensitive) or a GPU index such as ``"0"``.

    Returns:
        ``(model, precision, device_label)``. A GPU label uses the requested
        (global) index, while the model itself lives on in-process ``cuda:0``.

    Raises:
        ValueError: if ``spec`` is neither ``"cpu"`` nor an integer.
    """
    if spec.lower() == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        mdl, prec = load_model("cpu")
        return mdl, prec, "cpu"
    gpu_idx = int(spec)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_idx)  # mask before touching CUDA
    mdl, prec = load_model("cuda:0")  # the masked card is in-process ordinal 0
    return mdl, prec, f"cuda:{gpu_idx}"


FORCE_DEVICE = os.getenv("FORCE_DEVICE")  # "cpu" or "0" / "1" / ...
parser = argparse.ArgumentParser()
parser.add_argument("--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection")
args, _ = parser.parse_known_args()

# Precedence: FORCE_DEVICE env var > --device flag > automatic selection.
if FORCE_DEVICE is not None:
    model, PRECISION, DEVICE = _load_forced_device(FORCE_DEVICE)
elif args.device is not None:
    model, PRECISION, DEVICE = _load_forced_device(args.device)
else:
    model, PRECISION, DEVICE = auto_select_and_load()

# Global tokenizer, used only to count prompt tokens for the usage report.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)

# -----------------------------------------------------------------------------#
# FastAPI
# -----------------------------------------------------------------------------#
app = FastAPI()
logger.info("Using SAFE_MIN_FREE_MB = %d MB, BATCH_SIZE = %d", SAFE_MIN_FREE_MB, BATCH_SIZE)

# Guards lazy creation of the CPU fallback model, so concurrent requests that
# hit a GPU error do not each load their own copy.
_cpu_model_lock = Lock()


@app.on_event("startup")
def _warmup():
    """Encode one short sentence at startup and set ``_READY`` on success.

    A failure is logged and leaves ``_READY`` False, so /ready keeps
    answering 503.
    """
    global _READY
    try:
        with torch.inference_mode():
            _encode_with(model, ["warmup sentence"], chunk=max(1, min(BATCH_SIZE, 8)))
        _READY = True
        logger.info("Warm-up complete.")
    except Exception as e:
        _READY = False
        logger.warning("Warm-up failed: %s", e)


@app.get("/ready")
def ready():
    """Readiness probe: 503 until warm-up has succeeded, then ``{"status": "ready"}``."""
    if not _READY:
        # 503 makes orchestrators / HEALTHCHECK keep waiting.
        raise HTTPException(status_code=503, detail={"code": "not_ready", "message": "warming up"})
    return {"status": "ready"}


# ------------------- GPU self-heal: reload after repeated failures ------------------- #
def _maybe_reload_gpu_model():
    """Reload the GPU model once consecutive GPU failures reach the threshold.

    Does nothing when serving from the CPU or while the counter is below
    ``GPU_MAX_CONSEC_FAILS``. A failed reload is logged and swallowed, so
    requests keep using the CPU fallback. The counter is reset after every
    attempt.
    """
    global _GPU_FAIL_COUNT, model, PRECISION
    if not str(DEVICE).startswith("cuda"):
        return
    if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
        return
    with _reload_lock:
        # Re-check under the lock so concurrent callers do not reload twice.
        if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
            return
        logger.warning("GPU unhealthy (fail_count=%d) → reloading on %s", _GPU_FAIL_COUNT, DEVICE)
        try:
            # Release cached allocator blocks first. The old model is only
            # replaced if the new one loads successfully.
            torch.cuda.empty_cache()
            # DEVICE is the external (global) label; in-process, the masked card
            # is always cuda:0.
            model, PRECISION = load_model("cuda:0")
            logger.info("GPU model reloaded on %s", DEVICE)
        except Exception as e:
            logger.exception("GPU reload failed: %s", e)
        finally:
            # After a success the count starts fresh. After a failure, this backs
            # off for another GPU_MAX_CONSEC_FAILS failures instead of reloading
            # on every subsequent failing request.
            _GPU_FAIL_COUNT = 0


def _encode_chunked(model_obj, texts: List[str], chunk: int):
    """Encode ``texts`` in manual chunks without passing ``batch_size``.

    Returns ``{"dense_vecs": [...]}`` with the dense vectors concatenated in
    input order.
    """
    chunk = max(1, int(chunk))  # guard against zero / negative sizes
    dense_all = []
    for start in range(0, len(texts), chunk):
        out = model_obj.encode(texts[start:start + chunk], return_dense=True)  # no batch_size
        dense_all.extend(out["dense_vecs"])
    return {"dense_vecs": dense_all}


def _encode_with(model_obj, texts: List[str], chunk: int = BATCH_SIZE):
    """Dense-encode ``texts`` with ``model_obj``.

    First calls ``encode(..., batch_size=BATCH_SIZE)``. If that raises
    TypeError (treated as "encode() does not accept batch_size"), falls back
    to manual chunking with ``chunk`` texts per call.
    """
    try:
        return model_obj.encode(texts, return_dense=True, batch_size=BATCH_SIZE)
    except TypeError:
        return _encode_chunked(model_obj, texts, chunk)


def _encode(texts: List[str]):
    """Encode ``texts`` and return ``(output_dict, served_device)``.

    Normally uses the primary model. Suppose the primary model is on a GPU and
    raises a RuntimeError mentioning "out of memory" or "cuda error". Then the
    request is retried on a lazily loaded CPU model, the consecutive-failure
    counter is bumped, and a GPU reload may be attempted. Any successful
    primary encode resets the counter. Other errors, and any error while
    serving from the CPU, propagate.
    """
    global CPU_MODEL_CACHE, _GPU_FAIL_COUNT
    try:
        out = _encode_with(model, texts)
    except RuntimeError as e:
        emsg = str(e).lower()
        gpu_error = "out of memory" in emsg or "cuda error" in emsg
        if not (gpu_error and str(DEVICE).startswith("cuda")):
            raise
        logger.warning("GPU OOM/CUDA error → fallback to CPU: %s", str(e))
        torch.cuda.empty_cache()
        with _cpu_model_lock:
            if CPU_MODEL_CACHE is None:
                CPU_MODEL_CACHE, _ = load_model("cpu")
        _GPU_FAIL_COUNT += 1  # consecutive failures; reset on the next success
        _maybe_reload_gpu_model()  # self-heal once the threshold is reached
        return _encode_with(CPU_MODEL_CACHE, texts), "cpu"

    if _GPU_FAIL_COUNT:
        _GPU_FAIL_COUNT = 0  # success breaks the run of consecutive failures
    return out, DEVICE


# =============== Input normalisation (reject inputs that would crash later) =============== #
def _normalize_inputs(raw) -> List[str]:
    """Validate ``request.input`` and return it as a list of strings.

    * ``str``  → ``[str]`` (an empty string is allowed).
    * ``list`` → must be a non-empty list of strings. Items are neither
      stripped nor dropped.
    * Anything else is rejected.

    Raises:
        HTTPException: 400 for an empty list, a non-string item or a wrong
            type. 413 when the batch exceeds ``MAX_BATCH`` items, or an item
            exceeds ``MAX_TEXT_LEN`` characters.
    """
    if isinstance(raw, str):
        texts = [raw]
    elif isinstance(raw, list):
        if not raw:
            # An empty list would otherwise make the HF tokenizer raise IndexError.
            _http_error(400, "empty_list", "`input` is an empty list.")
        for i, item in enumerate(raw):
            if not isinstance(item, str):
                _http_error(400, "non_string_item", f"`input[{i}]` must be a string.")
        texts = raw
    else:
        _http_error(400, "invalid_type", "`input` must be a string or list of strings.")

    if len(texts) > MAX_BATCH:
        _http_error(413, "batch_too_large", f"batch size {len(texts)} > MAX_BATCH={MAX_BATCH}.")
    for i, text in enumerate(texts):
        if len(text) > MAX_TEXT_LEN:
            _http_error(413, "text_too_long", f"`input[{i}]` exceeds MAX_TEXT_LEN={MAX_TEXT_LEN}.")
    return texts


# ------------------------------- Main route --------------------------------------#
@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    """Embed ``request.input`` and return one dense embedding per input text.

    The response also reports the device that actually served the request and
    the precision in use.
    """
    # 1) Validate inputs up front (empty strings remain allowed).
    texts = _normalize_inputs(request.input)

    # 2) Token accounting is best effort: failures other than IndexError
    #    report 0 instead of failing the request.
    try:
        enc = tokenizer(texts, truncation=True, max_length=8192)
        # The number of unpadded ids equals the attention-mask sum of a padded
        # batch, without materialising a padded tensor.
        prompt_tokens = sum(len(ids) for ids in enc["input_ids"])
    except IndexError:
        # Belt and braces: empty lists are already rejected above.
        _http_error(400, "empty_after_tokenize", "`input` contains no encodable texts.")
    except Exception as e:
        logger.warning("token counting failed: %s", e)
        prompt_tokens = 0

    # 3) Encode (GPU → CPU fallback; failure counting and reload happen inside).
    try:
        output, served_device = _encode(texts)
        embeddings = output["dense_vecs"]
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Embedding failed")
        _http_error(500, "embedding_failed", str(e))

    # 4) Build the response (field layout unchanged).
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
            }
            for i, emb in enumerate(embeddings)
        ],
        "model": request.model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,
        },
        "device": served_device,  # the device that actually served this request
        "precision": "fp32" if served_device == "cpu" else PRECISION,
    }


# --------------------------- Health check (optional) --------------------------------#
@app.get("/health")
def healthz():
    """Report the serving device, precision, failure counter and readiness."""
    return {
        "status": "ok",
        "device": DEVICE,
        "precision": PRECISION,
        "gpu_fail_count": _GPU_FAIL_COUNT,
        "batch_size": BATCH_SIZE,
        "ready": _READY,
    }


# -----------------------------------------------------------------------------#
# Entry-point for `python server.py`
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "server:app" import string. The string
    # form makes uvicorn import this file a second time as module "server",
    # which re-runs device selection and loads a second copy of the model.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        log_level="info",
        workers=1,  # for multiple processes, run gunicorn/uvicorn externally
    )