hailin 2025-08-05 15:44:09 +08:00
parent 30b865b0ff
commit c119771511
1 changed file with 13 additions and 7 deletions

@@ -86,16 +86,22 @@ def load_model(device: str):
     use_fp16 = precision == "fp16"
     logger.info("Loading BGEM3 on %s (%s)", device, precision)
-    if device == "cpu":  # ← CPU-only path
-        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # completely detach CUDA
+    if device == "cpu":
+        # CPU path: completely detach CUDA
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
         torch.cuda.is_available = lambda: False
         torch.cuda.device_count = lambda: 0
     else:
+        # GPU path: expose only the card selected by this worker
+        # e.g. device == "cuda:3" → let the current process see only GPU 3
+        idx = device.split(":")[1]
+        os.environ["CUDA_VISIBLE_DEVICES"] = idx
+        # device_count now returns 1; BGEM3 will create only 1 subprocess on this card
+        torch.cuda.device_count = lambda: 1
     mdl = BGEM3FlagModel(MODEL_PATH, use_fp16=use_fp16, device=device)
+    # # wrap with DataParallel only when there are multiple GPUs
+    # if device.startswith("cuda") and torch.cuda.device_count() > 1:
+    #     mdl = torch.nn.DataParallel(mdl)
+    # no longer wrapped in DataParallel: each worker runs on a single card
     return mdl, precision
 # -----------------------------------------------------------------------------#
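
For context, below is a minimal, self-contained sketch of the per-worker pinning idea this commit implements. It is an illustration under stated assumptions, not code from the repo: the device string is assumed to arrive via a hypothetical WORKER_DEVICE environment variable, and the model load is left out. The key constraint is that CUDA_VISIBLE_DEVICES is read when the CUDA context is first initialized, so the mask must be in place before the first torch.cuda call.

import os


def pin_device(device: str) -> None:
    """Restrict this process to a single device. Must run before torch
    initializes CUDA: CUDA_VISIBLE_DEVICES is only read at CUDA-context
    creation, so changing it afterwards has no effect."""
    if device == "cpu":
        # Hide every GPU from this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    else:
        # "cuda:3" -> expose only physical GPU 3; inside this process
        # that card is then addressed as "cuda:0".
        os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[1]


if __name__ == "__main__":
    # Hypothetical wiring: each worker receives its device via WORKER_DEVICE.
    device = os.environ.get("WORKER_DEVICE", "cpu")
    pin_device(device)

    import torch  # imported only after the mask is in place

    # Prints 0 on the CPU path and 1 on the GPU path.
    print("visible CUDA devices:", torch.cuda.device_count())

Monkeypatching torch.cuda.is_available and torch.cuda.device_count, as the commit does on top of setting the environment variable, covers code paths where torch may have initialized CUDA before load_model runs; in that situation the environment variable alone no longer changes what the process sees.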