This commit is contained in:
parent
30b865b0ff
commit
c119771511
18
app/main.py
18
app/main.py
|
|
@ -86,16 +86,22 @@ def load_model(device: str):
|
|||
use_fp16 = precision == "fp16"
|
||||
logger.info("Loading BGEM3 on %s (%s)", device, precision)
|
||||
|
||||
if device == "cpu": # ← 仅 CPU 路径
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "" # 彻底摘掉 CUDA
|
||||
if device == "cpu":
|
||||
# CPU 路径:彻底摘掉 CUDA
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
torch.cuda.is_available = lambda: False
|
||||
torch.cuda.device_count = lambda: 0
|
||||
else:
|
||||
# GPU 路径:只暴露本 worker 选中的那张卡
|
||||
# 例:device == "cuda:3" → 只让当前进程看到 GPU 3
|
||||
idx = device.split(":")[1]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = idx
|
||||
# device_count 现在返回 1,BGEM3 只会在这张卡上建 1 个子进程
|
||||
torch.cuda.device_count = lambda: 1
|
||||
|
||||
mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
|
||||
|
||||
# # 多 GPU 时才包 DataParallel
|
||||
# if device.startswith("cuda") and torch.cuda.device_count() > 1:
|
||||
# mdl = torch.nn.DataParallel(mdl)
|
||||
|
||||
# 不再包 DataParallel;每个 worker 单卡即可
|
||||
return mdl, precision
|
||||
|
||||
# -----------------------------------------------------------------------------#
|
||||
|
|
|
|||
Loading…
Reference in New Issue