diff --git a/app/main.py b/app/main.py
index c8d17d8..7c399cc 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,67 +1,258 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+BGEM3 inference server (FastAPI) with robust GPU / CPU management.
+
+Launch examples:
+    python main.py                              # auto-select GPU / auto-fallback
+    python main.py --device 1                   # pin to GPU 1
+    CUDA_VISIBLE_DEVICES=0,1 python main.py
+"""
+
+import argparse
+import logging
+import os
+import sys
+import time
+from typing import List, Union
+
+import torch
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from typing import List, Union
-import torch
 from transformers import AutoTokenizer
 from FlagEmbedding import BGEM3FlagModel
 
-# Auto-detect device
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-# Use fp16 on GPU for speed, otherwise fp32
-USE_FP16 = DEVICE == "cuda"
+# -----------------------------------------------------------------------------#
+# Config
+# -----------------------------------------------------------------------------#
+MODEL_PATH = "model/bge-m3"  # change to your model weights path as needed
+SAFE_MIN_FREE_MB = int(
+    os.getenv("SAFE_MIN_FREE_MB", "4096")
+)  # minimum free VRAM (MB) required at startup
+
+# -----------------------------------------------------------------------------#
+# Logging
+# -----------------------------------------------------------------------------#
+logger = logging.getLogger("bge-m3-server")
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+)
+
+# -----------------------------------------------------------------------------#
+# GPU memory helpers (NVML → torch fallback)
+# -----------------------------------------------------------------------------#
+try:
+    import pynvml
+
+    pynvml.nvmlInit()
+    _USE_NVML = True
+except Exception:
+    _USE_NVML = False
+    logger.warning("pynvml unavailable, probing VRAM via torch.cuda.mem_get_info()")
+
+
+def _gpu_mem_info(idx: int) -> tuple[int, int]:
+    """Return (free_bytes, total_bytes) for GPU `idx`."""
+    if _USE_NVML:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
+        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        return mem.free, mem.total
+    # torch fallback
+    torch.cuda.set_device(idx)
+    free, total = torch.cuda.mem_get_info(idx)
+    return free, total
+
+
+# -----------------------------------------------------------------------------#
+# Precision & model loader
+# -----------------------------------------------------------------------------#
+def _choose_precision(device: str) -> str:
+    """Return 'fp16' | 'bf16' | 'fp32' for `device`."""
+    if not device.startswith("cuda"):
+        return "fp32"
+
+    major, _ = torch.cuda.get_device_capability(device)
+    if major >= 8:  # Ampere / Hopper: bf16 (and fp16) supported
+        return "bf16"
+    if major >= 7:  # Volta / Turing: fp16 only, no bf16
+        return "fp16"
+    return "fp32"
+
+
+def load_model(device: str):
+    """Instantiate the model on `device`; return (model, precision)."""
+    precision = _choose_precision(device)
+    # BGEM3FlagModel only exposes an fp16 switch, so bf16-capable GPUs are
+    # also served through the fp16 path here.
+    use_fp16 = precision in ("fp16", "bf16")
+    logger.info("Loading BGEM3 on %s (%s)", device, precision)
+
+    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
+
+    # Simple DataParallel for multi-GPU inference.
+    # NOTE: BGEM3FlagModel itself is not an nn.Module, so wrap the underlying
+    # encoder module rather than the wrapper object (which has no forward()).
+    if device.startswith("cuda") and torch.cuda.device_count() > 1 and hasattr(mdl, "model"):
+        logger.info(
+            "Wrapping model with torch.nn.DataParallel (%d GPUs)",
+            torch.cuda.device_count(),
+        )
+        mdl.model = torch.nn.DataParallel(mdl.model)
+    return mdl, precision
+
+
+# -----------------------------------------------------------------------------#
+# Auto-select device (startup)
+# -----------------------------------------------------------------------------#
+def auto_select_and_load(min_free_mb: int = 4096):
+    """
+    1. Gather GPUs with free_mem ≥ min_free_mb
+    2. Sort by free_mem desc, attempt to load
+    3. All fail → CPU
+    Return (model, precision, device)
+    """
+    if not torch.cuda.is_available():
+        logger.info("No GPU detected → CPU")
+        return (*load_model("cpu"), "cpu")
+
+    # Build candidate list
+    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
+    for idx in range(torch.cuda.device_count()):
+        free, _ = _gpu_mem_info(idx)
+        free_mb = free // 2**20
+        candidates.append((free_mb, idx))
+
+    candidates = [c for c in candidates if c[0] >= min_free_mb]
+    if not candidates:
+        logger.warning("All GPUs free_mem < %d MB → CPU", min_free_mb)
+        return (*load_model("cpu"), "cpu")
+
+    candidates.sort(reverse=True)  # high free_mem first
+    for free_mb, idx in candidates:
+        dev = f"cuda:{idx}"
+        try:
+            logger.info("Trying %s (free=%d MB)", dev, free_mb)
+            mdl, prec = load_model(dev)
+            return mdl, prec, dev
+        except RuntimeError as e:
+            if "out of memory" in str(e).lower():
+                logger.warning("%s OOM → next GPU", dev)
+                torch.cuda.empty_cache()
+                continue
+            raise  # non-OOM error
+
+    logger.warning("All GPUs failed → CPU fallback")
+    return (*load_model("cpu"), "cpu")
+
+
+# -----------------------------------------------------------------------------#
+# CLI
+# -----------------------------------------------------------------------------#
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
+)
+args, _ = parser.parse_known_args()
+
+if args.device is not None:
+    # Forced path
+    if args.device.lower() == "cpu":
+        DEVICE = "cpu"
+    else:
+        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
+    model, PRECISION = load_model(DEVICE)
+else:
+    # Auto path with VRAM check
+    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)
 
-# Load model and tokenizer
-MODEL_PATH = "model/bge-m3"
-model = BGEM3FlagModel(MODEL_PATH, use_fp16=USE_FP16, device=DEVICE)
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 
+# -----------------------------------------------------------------------------#
+# FastAPI
+# -----------------------------------------------------------------------------#
 app = FastAPI()
 
+
 class EmbeddingRequest(BaseModel):
     input: Union[str, List[str]]
     model: str = "text-embedding-bge-m3"
 
-@app.post("/v1/embeddings")
-def create_embedding(request: EmbeddingRequest):
-    # Normalize to a list
-    texts = [request.input] if isinstance(request.input, str) else request.input
 
+fallback_done = False  # prevent endless downgrade loop
+
+
+def _encode(texts: List[str]):
+    """Encode with a single downgrade to CPU on OOM / CUDA failure."""
+    global model, DEVICE, PRECISION, fallback_done
     try:
-        # Count tokens with the tokenizer first
-        encoding = tokenizer(
-            texts,
-            padding=True,
-            truncation=True,
-            max_length=8192,
-            return_tensors="pt"
-        )
-        # The number of 1s in attention_mask is the real token count (padding excluded)
-        mask = encoding["attention_mask"]
-        # prompt_tokens = sum of all input tokens
-        prompt_tokens = int(mask.sum().item())
-        total_tokens = prompt_tokens  # embeddings add no generated tokens
+        return model.encode(texts, return_dense=True)
 
-        # Generate dense vectors
-        output = model.encode(texts, return_dense=True)
+    except RuntimeError as err:
+        is_oom = "out of memory" in str(err).lower()
+        is_cuda_fail = "cuda error" in str(err).lower() or "device-side assert" in str(
+            err
+        ).lower()
+
+        if (is_oom or is_cuda_fail) and not fallback_done:
+            logger.error("GPU failure (%s). Falling back to CPU…", err)
+            fallback_done = True
+            torch.cuda.empty_cache()
+            DEVICE = "cpu"
+            model, PRECISION = load_model(DEVICE)
+            return model.encode(texts, return_dense=True)
+
+        raise  # second failure → propagate
+
+
+@app.post("/v1/embeddings")
+def create_embedding(request: EmbeddingRequest):
+    texts = [request.input] if isinstance(request.input, str) else request.input
+
+    # Token stats
+    enc = tokenizer(
+        texts,
+        padding=True,
+        truncation=True,
+        max_length=8192,
+        return_tensors="pt",
+    )
+    prompt_tokens = int(enc["attention_mask"].sum().item())
+
+    try:
+        output = _encode(texts)
         embeddings = output["dense_vecs"]
-
-        return {
-            "object": "list",
-            "data": [
-                {
-                    "object": "embedding",
-                    "index": idx,
-                    "embedding": emb.tolist() if hasattr(emb, "tolist") else emb
-                }
-                for idx, emb in enumerate(embeddings)
-            ],
-            "model": request.model,
-            "usage": {
-                "prompt_tokens": prompt_tokens,
-                "total_tokens": total_tokens
-            }
-        }
-
     except Exception as e:
+        logger.exception("Embedding failed")
         raise HTTPException(status_code=500, detail=str(e))
 
+    return {
+        "object": "list",
+        "data": [
+            {
+                "object": "embedding",
+                "index": i,
+                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
+            }
+            for i, emb in enumerate(embeddings)
+        ],
+        "model": request.model,
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "total_tokens": prompt_tokens,
+        },
+        "device": DEVICE,
+        "precision": PRECISION,
+    }
+
+
+# -----------------------------------------------------------------------------#
+# Entry-point for `python main.py`
+# -----------------------------------------------------------------------------#
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(
+        "main:app",
+        host="0.0.0.0",
+        port=int(os.getenv("PORT", 8000)),
+        log_level="info",
+        workers=1,  # for multiple workers, run gunicorn/uvicorn externally
+    )
diff --git a/app/main.py.old b/app/main.py.old
new file mode 100644
index 0000000..c8d17d8
--- /dev/null
+++ b/app/main.py.old
@@ -0,0 +1,67 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import List, Union
+import torch
+from transformers import AutoTokenizer
+from FlagEmbedding import BGEM3FlagModel
+
+# Auto-detect device
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+# Use fp16 on GPU for speed, otherwise fp32
+USE_FP16 = DEVICE == "cuda"
+
+# Load model and tokenizer
+MODEL_PATH = "model/bge-m3"
+model = BGEM3FlagModel(MODEL_PATH, use_fp16=USE_FP16, device=DEVICE)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+
+app = FastAPI()
+
+class EmbeddingRequest(BaseModel):
+    input: Union[str, List[str]]
+    model: str = "text-embedding-bge-m3"
+
+@app.post("/v1/embeddings")
+def create_embedding(request: EmbeddingRequest):
+    # Normalize to a list
+    texts = [request.input] if isinstance(request.input, str) else request.input
+
+    try:
+        # Count tokens with the tokenizer first
+        encoding = tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            max_length=8192,
+            return_tensors="pt"
+        )
+        # The number of 1s in attention_mask is the real token count (padding excluded)
+        mask = encoding["attention_mask"]
+        # prompt_tokens = sum of all input tokens
+        prompt_tokens = int(mask.sum().item())
+        total_tokens = prompt_tokens  # embeddings add no generated tokens
+
+        # Generate dense vectors
+        output = model.encode(texts, return_dense=True)
+        embeddings = output["dense_vecs"]
+
+        return {
+            "object": "list",
+            "data": [
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": emb.tolist() if hasattr(emb, "tolist") else emb
+                }
+                for idx, emb in enumerate(embeddings)
+            ],
+            "model": request.model,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "total_tokens": total_tokens
+            }
+        }
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
diff --git a/requirements.txt b/requirements.txt
index ad250fc..73d86f1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ torch
 transformers
 datasets
 peft
+pynvml
\ No newline at end of file