diff --git a/Dockerfile b/Dockerfile
index 672d237..7bf8843 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,18 +1,21 @@
 FROM python:3.10-slim
 
 # Install system dependencies
-RUN apt-get update && apt-get install -y gcc libglib2.0-0 && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    gcc libglib2.0-0 curl && rm -rf /var/lib/apt/lists/*
+
 # Set the working directory
 WORKDIR /app
 
 # Install Python dependencies
 COPY requirements.txt .
-RUN pip install --upgrade pip && pip install -r requirements.txt
+RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt
 
 # Install the local FlagEmbedding source tree
 COPY FlagEmbedding /opt/FlagEmbedding
-RUN pip install --no-deps --upgrade -e /opt/FlagEmbedding
+RUN pip install --no-cache-dir --no-deps /opt/FlagEmbedding
+
 # Copy application code and model weights
 COPY app /app/app
@@ -26,9 +29,16 @@ EXPOSE 8001
 
 # New: PyTorch CUDA allocator split-size setting to reduce fragmentation (optional but recommended)
 ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32
+ENV TOKENIZERS_PARALLELISM=false
+ENV HF_HUB_DISABLE_TELEMETRY=1 TRANSFORMERS_NO_ADVISORY_WARNINGS=1
+
+# Health check: leave time for startup and warm-up (adjust start-period to your model size)
+HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
+    CMD curl -fsS http://127.0.0.1:8001/ready >/dev/null || exit 1
 
 # Start: Gunicorn + 1 worker; each worker runs as its own process
-CMD ["gunicorn", "app.main:app", \
-     "-k", "uvicorn.workers.UvicornWorker", \
-     "-w", "1", \
-     "-b", "0.0.0.0:8001"]
\ No newline at end of file
+CMD ["gunicorn","app.main:app", \
+     "-k","uvicorn.workers.UvicornWorker", \
+     "-w","1","-b","0.0.0.0:8001", \
+     "--timeout","120","--graceful-timeout","30", \
+     "--max-requests","1000","--max-requests-jitter","200"]
\ No newline at end of file
diff --git a/app/main.py b/app/main.py
index 2f48691..9b2f0b0 100644
--- a/app/main.py
+++ b/app/main.py
@@ -12,10 +12,8 @@ Launch examples:
 import argparse
 import logging
 import os
-import sys
-import time
 from typing import List, Union
-import multiprocessing as mp
+from threading import Lock
 
 import torch
 from fastapi import FastAPI, HTTPException
@@ -23,16 +21,35 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer
 from FlagEmbedding import BGEM3FlagModel
 
-mp.set_start_method("spawn", force=True)
-
 # -----------------------------------------------------------------------------#
 # Config
 # -----------------------------------------------------------------------------#
 MODEL_PATH = "model/bge-m3"  # change this to your weight path as needed
 MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000"))  # bge-m3-large fp32 ≈ 8 GiB
-POST_LOAD_GAP_MB = 192
+POST_LOAD_GAP_MB = 192
 SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # == 8192 MB
 
+# Caps on request batch size and per-item text length (defends against oversized payloads)
+MAX_BATCH = int(os.getenv("MAX_BATCH", "1024"))
+MAX_TEXT_LEN = int(os.getenv("MAX_TEXT_LEN", "200000"))
+BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "32")))
+
+# How many consecutive GPU failures before a reload is attempted (self-healing)
+GPU_MAX_CONSEC_FAILS = int(os.getenv("GPU_MAX_CONSEC_FAILS", "3"))
+
+# CPU fallback model cache
+CPU_MODEL_CACHE = None
+
+# Tracks consecutive GPU failures; only used to trigger reloads, never affects the normal path
+_GPU_FAIL_COUNT = 0
+_reload_lock = Lock()
+
+_READY = False
+
+class EmbeddingRequest(BaseModel):
+    input: Union[str, List[str]]
+    model: str = "text-embedding-bge-m3"
+
 # -----------------------------------------------------------------------------#
 # Logging
 # -----------------------------------------------------------------------------#
@@ -42,6 +59,10 @@ logging.basicConfig(
     format="%(asctime)s [%(levelname)s] %(message)s",
 )
 
+def _http_error(status: int, code: str, message: str):
+    """Uniform error body: easier for logs and clients to recognize."""
+    raise HTTPException(status_code=status, detail={"code": code, "message": message})
+
 # -----------------------------------------------------------------------------#
 # GPU memory helpers (NVML → torch fallback)
 # -----------------------------------------------------------------------------#
@@ -54,56 +75,55 @@ except Exception:
     _USE_NVML = False
     logger.warning("pynvml unavailable, falling back to torch.cuda.mem_get_info() for VRAM probing")
 
-
 def _gpu_mem_info(idx: int) -> tuple[int, int]:
-    """Return (free_bytes, total_bytes) for GPU `idx`."""
+    """Return (free_bytes, total_bytes) for GPU `idx` (global index)."""
     if _USE_NVML:
         handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
         mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
         return mem.free, mem.total
-    # torch fallback
+    # torch fallback (note: if devices are already masked, only in-process indices are reachable)
     torch.cuda.set_device(idx)
     free, total = torch.cuda.mem_get_info(idx)
     return free, total
 
-
 # -----------------------------------------------------------------------------#
 # Precision & model loader
 # -----------------------------------------------------------------------------#
-def _choose_precision(device: str) -> str:
-    """Return 'fp16'|'bf16'|'fp32'."""
-    if not device.startswith("cuda"):
+def _choose_precision_by_idx(idx: int) -> str:
+    """Return 'fp16' or 'fp32' by compute capability."""
+    try:
+        major, _ = torch.cuda.get_device_capability(idx)  # idx here is the global index
+        if major >= 8:   # Ampere/Hopper
+            return "fp16"
+        if major >= 7:   # Volta/Turing
+            return "fp16"
         return "fp32"
-
-    major, _ = torch.cuda.get_device_capability(device)
-    if major >= 8:  # Ampere/Hopper
-        return "fp16"
-    if major >= 7:
-        return "bf16"
-    return "fp32"
-
+    except Exception:
+        return "fp16" if torch.cuda.is_available() else "fp32"
 
 def load_model(device: str):
-    precision = _choose_precision(device)
-    use_fp16 = precision == "fp16"
-    logger.info("Loading BGEM3 on %s (%s)", device, precision)
-
+    """
+    device: "cpu" or "cuda:{global_idx}"
+    After masking, the model is always loaded as "cuda:0" inside this process,
+    which avoids "invalid device ordinal" errors.
+    """
     if device == "cpu":
-        # CPU path: detach CUDA entirely
         os.environ["CUDA_VISIBLE_DEVICES"] = ""
-        torch.cuda.is_available = lambda: False
-        torch.cuda.device_count = lambda: 0
-    else:
-        # GPU path: expose only the card selected for this worker
-        # e.g. device == "cuda:3" → let the current process see GPU 3 only
-        idx = device.split(":")[1]
-        os.environ["CUDA_VISIBLE_DEVICES"] = idx
-        # device_count now returns 1, so BGEM3 creates only one child process on this card
-        torch.cuda.device_count = lambda: 1
+        precision = "fp32"
+        mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=False, device="cpu")
+        logger.info("Loading BGEM3 on cpu (fp32)")
+        return mdl, precision
 
-    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
+    # Resolve the global index and decide the precision before masking
+    idx = int(device.split(":")[1])
+    precision = _choose_precision_by_idx(idx)
+    use_fp16 = (precision == "fp16")
 
-    # No DataParallel wrapper any more; a single card per worker is enough
+    # Expose only this card; inside the process it maps to cuda:0
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(idx)
+    mapped = "cuda:0"
+
+    logger.info("Loading BGEM3 on %s (mapped=%s, %s)", device, mapped, precision)
+    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=mapped)
     return mdl, precision
 
 # -----------------------------------------------------------------------------#
@@ -118,18 +138,20 @@ def auto_select_and_load() -> tuple:
     """
     if not torch.cuda.is_available():
         logger.info("No GPU detected → CPU")
-        return (*load_model("cpu"), "cpu")
+        mdl, prec = load_model("cpu")
+        return mdl, prec, "cpu"
 
     # Collect candidate cards (free_MB, idx)
     candidates = []
     for idx in range(torch.cuda.device_count()):
         free_mb = _gpu_mem_info(idx)[0] // 2**20
-        if free_mb >= MODEL_VRAM_MB:          # at least enough room for the weights
+        if free_mb >= MODEL_VRAM_MB:
            candidates.append((free_mb, idx))
 
     if not candidates:
         logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
-        return (*load_model("cpu"), "cpu")
+        mdl, prec = load_model("cpu")
+        return mdl, prec, "cpu"
 
     # Highest free VRAM first
     for free_mb, idx in sorted(candidates, reverse=True):
@@ -138,120 +160,206 @@ def auto_select_and_load() -> tuple:
             logger.info("Trying %s (free=%d MB)", dev, free_mb)
             mdl, prec = load_model(dev)
 
-            remain_mb = _gpu_mem_info(idx)[0] // 2**20
-            if remain_mb < POST_LOAD_GAP_MB:
-                raise RuntimeError(
-                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+            # Post-load headroom check: use the global idx with NVML; without NVML use in-process device 0
+            if _USE_NVML:
+                remain_mb = _gpu_mem_info(idx)[0] // 2**20
+            else:
+                remain_mb = _gpu_mem_info(0)[0] // 2**20
 
-            return mdl, prec, dev                       # success
+            if remain_mb < POST_LOAD_GAP_MB:
+                raise RuntimeError(f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+            return mdl, prec, dev
         except RuntimeError as e:
             logger.warning("%s unusable (%s) → next", dev, e)
             torch.cuda.empty_cache()
 
     logger.warning("No suitable GPU left → CPU fallback")
-    return (*load_model("cpu"), "cpu")
+    mdl, prec = load_model("cpu")
+    return mdl, prec, "cpu"
 
 # -----------------------------------------------------------------------------#
 # CLI
 # -----------------------------------------------------------------------------#
+FORCE_DEVICE = os.getenv("FORCE_DEVICE")  # "cpu" or "0" / "1" / ...
+
 parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
-)
+parser.add_argument("--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection")
 args, _ = parser.parse_known_args()
 
-if args.device is not None:
-    # Forced path
+if FORCE_DEVICE is not None:
+    if FORCE_DEVICE.lower() == "cpu":
+        DEVICE = "cpu"
+    else:
+        DEVICE = f"cuda:{int(FORCE_DEVICE)}" if torch.cuda.is_available() else "cpu"
+    model, PRECISION = load_model(DEVICE)
+elif args.device is not None:
     if args.device.lower() == "cpu":
         DEVICE = "cpu"
     else:
         DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
     model, PRECISION = load_model(DEVICE)
 else:
-    # Auto path with VRAM check
     model, PRECISION, DEVICE = auto_select_and_load()
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+
 # -----------------------------------------------------------------------------#
 # FastAPI
 # -----------------------------------------------------------------------------#
 app = FastAPI()
-logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
+logger.info("Using SAFE_MIN_FREE_MB = %d MB, BATCH_SIZE = %d", SAFE_MIN_FREE_MB, BATCH_SIZE)
 
-
-def _warm_worker(t, q):
-    try:
-        _ = model.encode(t, return_dense=True)
-        q.put("ok")
-    except Exception as e:
-        q.put(str(e))
-
-# ② -------- FastAPI startup warm-up --------
 @app.on_event("startup")
-def warm_up():
-    logger.info("Warm-up on %s", DEVICE)
+def _warmup():
+    global _READY
     try:
-        _ = model.encode([
-            "This is a warmup sentence used to initialize CUDA kernels and avoid latency spikes."
-        ], return_dense=True)
+        # Try warming up with batch_size; fall back if it is not supported
+        try:
+            model.encode(["warmup sentence"], return_dense=True, batch_size=BATCH_SIZE)
+        except TypeError:
+            _ = _encode_chunked(model, ["warmup sentence"], max(1, min(BATCH_SIZE, 8)))
+        _READY = True
         logger.info("Warm-up complete.")
     except Exception as e:
+        _READY = False
         logger.warning("Warm-up failed: %s", e)
 
+@app.get("/ready")
+def ready():
+    if not _READY:
+        # 503 keeps the orchestrator / HEALTHCHECK waiting
+        raise HTTPException(status_code=503, detail={"code": "not_ready", "message": "warming up"})
+    return {"status": "ready"}
 
-class EmbeddingRequest(BaseModel):
-    input: Union[str, List[str]]
-    model: str = "text-embedding-bge-m3"
-
-
-# ③ -------- worker call inside _encode() --------
-def _worker(t, q):
-    try:
-        # out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
-        out = model.encode(t, return_dense=True)
-        q.put(("ok", out))
-    except Exception as e:
-        q.put(("err", str(e)))
+# ---------------------- GPU self-healing: reload the model after consecutive failures ---------------------- #
+def _maybe_reload_gpu_model():
+    """
+    Attempt a reload once consecutive GPU failures reach the threshold; a failed reload
+    does not affect request handling (the CPU fallback keeps serving).
+    """
+    global _GPU_FAIL_COUNT, model, PRECISION, DEVICE
+    if not str(DEVICE).startswith("cuda"):
+        return
+    if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
+        return
+    with _reload_lock:
+        # Double-check to avoid concurrent duplicate reloads
+        if _GPU_FAIL_COUNT < GPU_MAX_CONSEC_FAILS:
+            return
+        try:
+            logger.warning("GPU unhealthy (fail_count=%d) → reloading on %s", _GPU_FAIL_COUNT, DEVICE)
+            model, PRECISION = load_model(DEVICE)
+            _GPU_FAIL_COUNT = 0
+            logger.info("GPU model reloaded on %s", DEVICE)
+        except Exception as e:
+            logger.exception("GPU reload failed: %s", e)
+            # Keep the CPU fallback path working
+
+def _encode_chunked(model_obj, texts: List[str], chunk: int):
+    """Manual chunked encoding that does not rely on a batch_size parameter; output matches BGEM3."""
+    chunk = max(1, int(chunk))  # guard
+    dense_all = []
+    for i in range(0, len(texts), chunk):
+        part = texts[i:i+chunk]
+        out = model_obj.encode(part, return_dense=True)  # no batch_size passed
+        dense_all.extend(out["dense_vecs"])
+    return {"dense_vecs": dense_all}
 
 def _encode(texts: List[str]):
+    """
+    Returns (output_dict, served_device).
+    Normally served by the GPU; on OOM / CUDA errors it automatically falls back to the CPU
+    and bumps the self-healing counter.
+    """
+    served = DEVICE
     try:
-        return model.encode(texts, return_dense=True)
+        try:
+            return model.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
+        except TypeError:
+            # encode() does not accept a batch_size parameter → chunk manually
+            return _encode_chunked(model, texts, BATCH_SIZE), served
     except RuntimeError as e:
-        if "out of memory" in str(e).lower() or "cuda error" in str(e).lower():
-            logger.warning("GPU OOM → fallback to CPU: %s", str(e))
+        emsg = str(e).lower()
+        if "out of memory" in emsg or "cuda error" in emsg:
+            logger.warning("GPU OOM/CUDA error → fallback to CPU: %s", str(e))
             torch.cuda.empty_cache()
-            global CPU_MODEL_CACHE
+            global CPU_MODEL_CACHE, _GPU_FAIL_COUNT
            if CPU_MODEL_CACHE is None:
                 CPU_MODEL_CACHE, _ = load_model("cpu")
-            return CPU_MODEL_CACHE.encode(texts, return_dense=True)
+            _GPU_FAIL_COUNT += 1          # record the consecutive failure
+            _maybe_reload_gpu_model()     # attempt self-healing once the threshold is reached
+            served = "cpu"
+            # On CPU, also try batch_size first, then fall back to chunking
+            try:
+                return CPU_MODEL_CACHE.encode(texts, return_dense=True, batch_size=BATCH_SIZE), served
+            except TypeError:
+                return _encode_chunked(CPU_MODEL_CACHE, texts, BATCH_SIZE), served
         raise
 
+# =============== Input normalization (stop crash-inducing inputs at the door) =============== #
+def _normalize_inputs(raw) -> List[str]:
+    """
+    Behavioural contract:
+    - string         : keep the old behaviour and wrap as [str] ("" is allowed)
+    - list           : must be list[str]; no strip, blanks are not dropped (old behaviour kept)
+    - empty list     : straight 400 (previously the HF tokenizer raised IndexError)
+    - non-string item: 400
+    - oversized batch/text: 413 (defensive limit)
+    """
+    if isinstance(raw, str):
+        texts = [raw]
+    elif isinstance(raw, list):
+        if len(raw) == 0:
+            _http_error(400, "empty_list", "`input` is an empty list.")
+        for i, x in enumerate(raw):
+            if not isinstance(x, str):
+                _http_error(400, "non_string_item", f"`input[{i}]` must be a string.")
+        texts = raw
+    else:
+        _http_error(400, "invalid_type", "`input` must be a string or list of strings.")
+    if len(texts) > MAX_BATCH:
+        _http_error(413, "batch_too_large", f"batch size {len(texts)} > MAX_BATCH={MAX_BATCH}.")
+    for i, t in enumerate(texts):
+        if t is not None and len(t) > MAX_TEXT_LEN:
+            _http_error(413, "text_too_long", f"`input[{i}]` exceeds MAX_TEXT_LEN={MAX_TEXT_LEN}.")
+    return texts
 
+# ------------------------------- Main route ----------------------------------#
 @app.post("/v1/embeddings")
 def create_embedding(request: EmbeddingRequest):
-    texts = [request.input] if isinstance(request.input, str) else request.input
-
-    # Token stats
-    enc = tokenizer(
-        texts,
-        padding=True,
-        truncation=True,
-        max_length=8192,
-        return_tensors="pt",
-    )
-    prompt_tokens = int(enc["attention_mask"].sum().item())
+    # 1) Input normalization (reject crash-inducing inputs at the entry point; "" stays allowed as before)
+    texts = _normalize_inputs(request.input)
 
+    # 2) Token stats: failures must not break the main flow (count as 0)
     try:
-        output = _encode(texts)
+        enc = tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            max_length=8192,
+            return_tensors="pt",
+        )
+        prompt_tokens = int(enc["attention_mask"].sum().item())
+    except IndexError:
+        # Double safety net (should no longer happen in theory)
+        _http_error(400, "empty_after_tokenize", "`input` contains no encodable texts.")
+    except Exception as e:
+        logger.warning("token counting failed: %s", e)
+        prompt_tokens = 0
+
+    # 3) Encode (GPU → CPU fallback; self-healing counter and reload)
+    try:
+        output, served_device = _encode(texts)
         embeddings = output["dense_vecs"]
+    except HTTPException:
+        raise
     except Exception as e:
         logger.exception("Embedding failed")
-        raise HTTPException(status_code=500, detail=str(e))
+        _http_error(500, "embedding_failed", str(e))
 
+    # 4) Return as before (structure/fields unchanged)
     return {
         "object": "list",
         "data": [
@@ -267,10 +375,21 @@ def create_embedding(request: EmbeddingRequest):
             "prompt_tokens": prompt_tokens,
             "total_tokens": prompt_tokens,
         },
-        "device": DEVICE,
-        "precision": PRECISION,
+        "device": served_device,  # report the device that actually served the request
+        "precision": "fp32" if served_device == "cpu" else PRECISION,
     }
 
+# --------------------------- Health check (optional) -------------------------#
+@app.get("/health")
+def healthz():
+    return {
+        "status": "ok",
+        "device": DEVICE,
+        "precision": PRECISION,
+        "gpu_fail_count": _GPU_FAIL_COUNT,
+        "batch_size": BATCH_SIZE,
+        "ready": _READY,
+    }
 
 # -----------------------------------------------------------------------------#
 # Entry-point for `python server.py`
diff --git a/app/main.py.ok b/app/main.py.ok
new file mode 100644
index 0000000..2f48691
--- /dev/null
+++ b/app/main.py.ok
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+BGEM3 inference server (FastAPI) with robust GPU / CPU management.
+
+Launch examples:
+    python server.py                 # auto-select a GPU / auto-degrade
+    python server.py --device 1      # pin to GPU 1
+    CUDA_VISIBLE_DEVICES=0,1 python server.py
+"""
+
+import argparse
+import logging
+import os
+import sys
+import time
+from typing import List, Union
+import multiprocessing as mp
+
+import torch
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from transformers import AutoTokenizer
+from FlagEmbedding import BGEM3FlagModel
+
+mp.set_start_method("spawn", force=True)
+
+# -----------------------------------------------------------------------------#
+# Config
+# -----------------------------------------------------------------------------#
+MODEL_PATH = "model/bge-m3"  # change this to your weight path as needed
+MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "8000"))  # bge-m3-large fp32 ≈ 8 GiB
+POST_LOAD_GAP_MB = 192
+SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # == 8192 MB
+
+# -----------------------------------------------------------------------------#
+# Logging
+# -----------------------------------------------------------------------------#
+logger = logging.getLogger("bge-m3-server")
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+)
+
+# -----------------------------------------------------------------------------#
+# GPU memory helpers (NVML → torch fallback)
+# -----------------------------------------------------------------------------#
+try:
+    import pynvml
+
+    pynvml.nvmlInit()
+    _USE_NVML = True
+except Exception:
+    _USE_NVML = False
+    logger.warning("pynvml unavailable, falling back to torch.cuda.mem_get_info() for VRAM probing")
+
+
+def _gpu_mem_info(idx: int) -> tuple[int, int]:
+    """Return (free_bytes, total_bytes) for GPU `idx`."""
+    if _USE_NVML:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
+        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        return mem.free, mem.total
+    # torch fallback
+    torch.cuda.set_device(idx)
+    free, total = torch.cuda.mem_get_info(idx)
+    return free, total
+
+
+# -----------------------------------------------------------------------------#
+# Precision & model loader
+# -----------------------------------------------------------------------------#
+def _choose_precision(device: str) -> str:
+    """Return 'fp16'|'bf16'|'fp32'."""
+    if not device.startswith("cuda"):
+        return "fp32"
+
+    major, _ = torch.cuda.get_device_capability(device)
+    if major >= 8:  # Ampere/Hopper
+        return "fp16"
+    if major >= 7:
+        return "bf16"
+    return "fp32"
+
+
+def load_model(device: str):
+    precision = _choose_precision(device)
+    use_fp16 = precision == "fp16"
+    logger.info("Loading BGEM3 on %s (%s)", device, precision)
+
+    if device == "cpu":
+        # CPU path: detach CUDA entirely
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
+        torch.cuda.is_available = lambda: False
+        torch.cuda.device_count = lambda: 0
+    else:
+        # GPU path: expose only the card selected for this worker
+        # e.g. device == "cuda:3" → let the current process see GPU 3 only
+        idx = device.split(":")[1]
+        os.environ["CUDA_VISIBLE_DEVICES"] = idx
+        # device_count now returns 1, so BGEM3 creates only one child process on this card
+        torch.cuda.device_count = lambda: 1
+
+    mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
+
+    # No DataParallel wrapper any more; a single card per worker is enough
+    return mdl, precision
+
+# -----------------------------------------------------------------------------#
+# Auto-select device (startup)
+# -----------------------------------------------------------------------------#
+def auto_select_and_load() -> tuple:
+    """
+    1. Filter out GPUs whose free VRAM < MODEL_VRAM_MB
+    2. Try loading in descending order of free VRAM
+    3. Check again after loading: if the remainder < POST_LOAD_GAP_MB → treat as failure
+    4. If no GPU qualifies → CPU
+    """
+    if not torch.cuda.is_available():
+        logger.info("No GPU detected → CPU")
+        return (*load_model("cpu"), "cpu")
+
+    # Collect candidate cards (free_MB, idx)
+    candidates = []
+    for idx in range(torch.cuda.device_count()):
+        free_mb = _gpu_mem_info(idx)[0] // 2**20
+        if free_mb >= MODEL_VRAM_MB:          # at least enough room for the weights
+            candidates.append((free_mb, idx))
+
+    if not candidates:
+        logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
+        return (*load_model("cpu"), "cpu")
+
+    # Highest free VRAM first
+    for free_mb, idx in sorted(candidates, reverse=True):
+        dev = f"cuda:{idx}"
+        try:
+            logger.info("Trying %s (free=%d MB)", dev, free_mb)
+            mdl, prec = load_model(dev)
+
+            remain_mb = _gpu_mem_info(idx)[0] // 2**20
+            if remain_mb < POST_LOAD_GAP_MB:
+                raise RuntimeError(
+                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+
+            return mdl, prec, dev                       # success
+        except RuntimeError as e:
+            logger.warning("%s unusable (%s) → next", dev, e)
+            torch.cuda.empty_cache()
+
+    logger.warning("No suitable GPU left → CPU fallback")
+    return (*load_model("cpu"), "cpu")
+
+# -----------------------------------------------------------------------------#
+# CLI
+# -----------------------------------------------------------------------------#
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--device", help="GPU index (e.g. 0) or 'cpu'; overrides auto-selection"
+)
+args, _ = parser.parse_known_args()
+
+if args.device is not None:
+    # Forced path
+    if args.device.lower() == "cpu":
+        DEVICE = "cpu"
+    else:
+        DEVICE = f"cuda:{int(args.device)}" if torch.cuda.is_available() else "cpu"
+    model, PRECISION = load_model(DEVICE)
+else:
+    # Auto path with VRAM check
+    model, PRECISION, DEVICE = auto_select_and_load()
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+
+# -----------------------------------------------------------------------------#
+# FastAPI
+# -----------------------------------------------------------------------------#
+app = FastAPI()
+logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
+
+
+def _warm_worker(t, q):
+    try:
+        _ = model.encode(t, return_dense=True)
+        q.put("ok")
+    except Exception as e:
+        q.put(str(e))
+
+# ② -------- FastAPI startup warm-up --------
+@app.on_event("startup")
+def warm_up():
+    logger.info("Warm-up on %s", DEVICE)
+    try:
+        _ = model.encode([
+            "This is a warmup sentence used to initialize CUDA kernels and avoid latency spikes."
+        ], return_dense=True)
+        logger.info("Warm-up complete.")
+    except Exception as e:
+        logger.warning("Warm-up failed: %s", e)
+
+
+class EmbeddingRequest(BaseModel):
+    input: Union[str, List[str]]
+    model: str = "text-embedding-bge-m3"
+
+
+# ③ -------- worker call inside _encode() --------
+def _worker(t, q):
+    try:
+        # out = model.encode(t, return_dense=True)  # safe on both GPU and CPU
+        out = model.encode(t, return_dense=True)
+        q.put(("ok", out))
+    except Exception as e:
+        q.put(("err", str(e)))
+
+
+def _encode(texts: List[str]):
+    try:
+        return model.encode(texts, return_dense=True)
+    except RuntimeError as e:
+        if "out of memory" in str(e).lower() or "cuda error" in str(e).lower():
+            logger.warning("GPU OOM → fallback to CPU: %s", str(e))
+            torch.cuda.empty_cache()
+            global CPU_MODEL_CACHE
+            if CPU_MODEL_CACHE is None:
+                CPU_MODEL_CACHE, _ = load_model("cpu")
+            return CPU_MODEL_CACHE.encode(texts, return_dense=True)
+        raise
+
+
+@app.post("/v1/embeddings")
+def create_embedding(request: EmbeddingRequest):
+    texts = [request.input] if isinstance(request.input, str) else request.input
+
+    # Token stats
+    enc = tokenizer(
+        texts,
+        padding=True,
+        truncation=True,
+        max_length=8192,
+        return_tensors="pt",
+    )
+    prompt_tokens = int(enc["attention_mask"].sum().item())
+
+    try:
+        output = _encode(texts)
+        embeddings = output["dense_vecs"]
+    except Exception as e:
+        logger.exception("Embedding failed")
+        raise HTTPException(status_code=500, detail=str(e))
+
+    return {
+        "object": "list",
+        "data": [
+            {
+                "object": "embedding",
+                "index": i,
+                "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
+            }
+            for i, emb in enumerate(embeddings)
+        ],
+        "model": request.model,
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "total_tokens": prompt_tokens,
+        },
+        "device": DEVICE,
+        "precision": PRECISION,
+    }
+
+
+# -----------------------------------------------------------------------------#
+# Entry-point for `python server.py`
+# -----------------------------------------------------------------------------#
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(
+        "server:app",
+        host="0.0.0.0",
+        port=int(os.getenv("PORT", 8000)),
+        log_level="info",
+        workers=1,  # multi-process → use gunicorn/uvicorn externally
+    )
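
For reference, a minimal smoke test of the patched service is sketched below. It is not part of the diff: it assumes the image built from the Dockerfile above is running with port 8001 published on localhost, that the `requests` package is installed on the client side, and it exercises the `/ready` and `/v1/embeddings` routes that this change adds or keeps.

#!/usr/bin/env python3
# smoke_test.py — hypothetical client sketch, not included in the patch
import time

import requests

BASE = "http://localhost:8001"  # assumed host/port mapping

# /ready answers 503 until the startup warm-up has finished, mirroring the
# Dockerfile HEALTHCHECK that curls the same endpoint.
for _ in range(30):
    try:
        if requests.get(f"{BASE}/ready", timeout=5).status_code == 200:
            break
    except requests.ConnectionError:
        pass
    time.sleep(5)

payload = {"input": ["hello world", "你好,世界"], "model": "text-embedding-bge-m3"}
resp = requests.post(f"{BASE}/v1/embeddings", json=payload, timeout=120)
resp.raise_for_status()
body = resp.json()

# "device" reports the device that actually served the request
# (the GPU, or "cpu" after an OOM fallback).
print("served on:", body["device"], body["precision"])
print("vectors:", len(body["data"]), "dims:", len(body["data"][0]["embedding"]))
print("prompt_tokens:", body["usage"]["prompt_tokens"])

A request that the old code crashed on, such as {"input": []}, should now come back as a clean HTTP 400 with a {"code": "empty_list", ...} detail body.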