parent 04146a7f28
commit 53285eecad

app/main.py | 44
@@ -27,12 +27,9 @@ from FlagEmbedding import BGEM3FlagModel
 # Config
 # -----------------------------------------------------------------------------#
 MODEL_PATH = "model/bge-m3"  # change this to your weights path as needed
-SAFE_MIN_FREE_MB = int(
-    os.getenv("SAFE_MIN_FREE_MB", "16384")
-)  # minimum free VRAM required at startup, MB
-
 MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))  # bge-m3-large fp16 ≈ 16 GiB
 POST_LOAD_GAP_MB = 512
+SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB  # == 16896 MB
 
 # -----------------------------------------------------------------------------#
 # Logging
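Net effect of this hunk: SAFE_MIN_FREE_MB is no longer read from its own environment variable; it is now derived from MODEL_VRAM_MB. A standalone sketch of how the new constants resolve (not part of app/main.py itself):

import os

# Mirrors the constants above; setting SAFE_MIN_FREE_MB in the environment
# has no effect after this commit.
os.environ["MODEL_VRAM_MB"] = "24576"  # e.g. a larger checkpoint
MODEL_VRAM_MB = int(os.getenv("MODEL_VRAM_MB", "16384"))
POST_LOAD_GAP_MB = 512
SAFE_MIN_FREE_MB = MODEL_VRAM_MB + POST_LOAD_GAP_MB

print(SAFE_MIN_FREE_MB)  # 25088 here; 16896 with the default 16384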
@@ -110,7 +107,7 @@ def load_model(device: str):
 # -----------------------------------------------------------------------------#
 # Auto-select device (startup)
 # -----------------------------------------------------------------------------#
-def auto_select_and_load(min_free_mb: int = 4096):
+def auto_select_and_load() -> tuple:
     """
     1. Filter out GPUs whose free VRAM is < MODEL_VRAM_MB
     2. Try loading on each, in descending order of free VRAM
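The bare `tuple` annotation leaves the shape implicit; judging from the return statements in this function and the call site in the CLI hunk further down, the contract is a three-element tuple:

# (model, precision, device); names taken from the call site below,
# where device is "cpu" or "cuda:<idx>".
model, PRECISION, DEVICE = auto_select_and_load()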
@@ -121,47 +118,37 @@ def auto_select_and_load(min_free_mb: int = 4096):
         logger.info("No GPU detected → CPU")
         return (*load_model("cpu"), "cpu")
 
-    # Collect each card's free VRAM
-    candidates: list[tuple[int, int]] = []  # (free_MB, idx)
+    # Collect candidate cards (free_MB, idx)
+    candidates = []
     for idx in range(torch.cuda.device_count()):
-        free, _ = _gpu_mem_info(idx)
-        free_mb = free // 2**20
-        candidates.append((free_mb, idx))
-
-    # Coarse pre-filter by min_free_mb first (usually 16384 MB)
-    candidates = [c for c in candidates if c[0] >= MODEL_VRAM_MB]
+        free_mb = _gpu_mem_info(idx)[0] // 2**20
+        if free_mb >= MODEL_VRAM_MB:  # must at least fit the weights
+            candidates.append((free_mb, idx))
 
     if not candidates:
         logger.warning("All GPUs free_mem < %d MB → CPU", MODEL_VRAM_MB)
         return (*load_model("cpu"), "cpu")
 
-    # Try GPUs from most to least free VRAM
-    candidates.sort(reverse=True)
-    for free_mb, idx in candidates:
+    # Free VRAM, high to low
+    for free_mb, idx in sorted(candidates, reverse=True):
         dev = f"cuda:{idx}"
         try:
             logger.info("Trying %s (free=%d MB)", dev, free_mb)
             mdl, prec = load_model(dev)
 
-            # Measure again after loading; ensure ≥ POST_LOAD_GAP_MB remains
-            new_free, _ = _gpu_mem_info(idx)
-            new_free_mb = new_free // 2**20
-            if new_free_mb < POST_LOAD_GAP_MB:
-                raise RuntimeError(
-                    f"post-load free {new_free_mb} MB < {POST_LOAD_GAP_MB} MB"
-                )
-
-            return mdl, prec, dev  # GPU claimed successfully
+            remain_mb = _gpu_mem_info(idx)[0] // 2**20
+            if remain_mb < POST_LOAD_GAP_MB:
+                raise RuntimeError(
+                    f"post-load free {remain_mb} MB < {POST_LOAD_GAP_MB} MB")
+            return mdl, prec, dev  # success
         except RuntimeError as e:
             logger.warning("%s unusable (%s) → next", dev, e)
             torch.cuda.empty_cache()
-            continue
 
-    # No GPU met the requirements
     logger.warning("No suitable GPU left → CPU fallback")
     return (*load_model("cpu"), "cpu")
 
 
 # -----------------------------------------------------------------------------#
 # CLI
 # -----------------------------------------------------------------------------#
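_gpu_mem_info(idx) is defined outside this diff. Judging from the `[0] // 2**20` bytes-to-MB conversion above, it plausibly wraps torch.cuda.mem_get_info, which returns (free_bytes, total_bytes); a minimal sketch under that assumption:

import torch

def _gpu_mem_info(idx: int) -> tuple[int, int]:
    # (free_bytes, total_bytes) for GPU `idx`, queried from the CUDA driver
    return torch.cuda.mem_get_info(torch.device(f"cuda:{idx}"))

free_mb = _gpu_mem_info(0)[0] // 2**20  # same conversion as in the hunk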
@@ -180,7 +167,7 @@ if args.device is not None:
     model, PRECISION = load_model(DEVICE)
 else:
     # Auto path with VRAM check
-    model, PRECISION, DEVICE = auto_select_and_load(SAFE_MIN_FREE_MB)
+    model, PRECISION, DEVICE = auto_select_and_load()
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 
@@ -195,7 +182,7 @@ logger.info("Using SAFE_MIN_FREE_MB = %d MB", SAFE_MIN_FREE_MB)
 async def warm_up():
     logger.info("Warm-up on %s", DEVICE)
     try:
-        _ = model.encode(["warmup"], return_dense=True)  # GPU builds a pool, CPU stays single-process
+        _ = model.encode(["warmup"], return_dense=True, num_processes=1)
     except Exception as e:
         logger.warning("Warm-up failed: %s — will back off on the first request instead", e)
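For context, the changed warm-up call in standalone form. num_processes=1 is copied from the commit and appears intended to keep FlagEmbedding's encode from spawning its multi-process pool; use_fp16 below is an assumption, not taken from this diff:

from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel("model/bge-m3", use_fp16=True)  # use_fp16: assumption
out = model.encode(["warmup"], return_dense=True, num_processes=1)
print(out["dense_vecs"].shape)  # (1, 1024): bge-m3's dense embedding size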
@@ -216,7 +203,8 @@ def _encode(texts: List[str]):
     # ③ -------- worker invocation inside _encode() --------
     def _worker(t, q):
         try:
-            out = model.encode(t, return_dense=True)  # safe on GPU or CPU
+            # out = model.encode(t, return_dense=True)  # safe on GPU or CPU
+            out = model.encode(t, return_dense=True, num_processes=1)
             q.put(("ok", out))
         except Exception as e:
             q.put(("err", str(e)))
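The process/queue plumbing around _worker sits outside this hunk. A minimal sketch of a plausible _encode() driver, assuming it runs the worker in a subprocess and reads a (status, payload) tuple off the queue; the timeout value is invented for illustration:

import multiprocessing as mp
from typing import List

def _encode(texts: List[str]):
    q = mp.Queue()
    # _worker from the hunk above would be defined (nested) right here
    p = mp.Process(target=_worker, args=(texts, q), daemon=True)
    p.start()
    status, payload = q.get(timeout=300)  # illustrative timeout
    p.join()
    if status == "err":
        raise RuntimeError(payload)
    return payload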