From b4c29cde486ad57ab6365e7a46e5be7f190dffe9 Mon Sep 17 00:00:00 2001 From: hailin Date: Sat, 30 Aug 2025 15:46:41 +0800 Subject: [PATCH] . --- train_sft_ds.py | 47 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/train_sft_ds.py b/train_sft_ds.py index 055ce42..ee0f055 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -26,6 +26,29 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import get_last_checkpoint +# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ==== +import os, sys, site, shutil + +home = os.path.expanduser("~") +want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"] +cur = os.environ.get("PATH", "").split(":") +new = [d for d in want if d and d not in cur] + cur +os.environ["PATH"] = ":".join(new) + +# 可见性打印,方便你在日志里确认 tn06 是否拿到了 +print(f"[env] PATH={os.environ['PATH']}", flush=True) +print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True) + + +os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8") +ld = os.environ.get("LD_LIBRARY_PATH", "") +cuda_lib = "/usr/local/cuda-11.8/lib64" +if cuda_lib not in ld.split(":"): + os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib + +# 可视化确认 +import torch +print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True) # ==== ensure python can see user site & set torch extensions dir ==== import os, sys, site @@ -45,22 +68,30 @@ except Exception: # 3) 统一 JIT 缓存目录(可选,但更稳;日志里你现在用的是 ~/.cache/torch_extensions) os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext") -os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext") os.environ.setdefault("MAX_JOBS", "8") - +import shutil +if shutil.which("ninja") is None: + os.environ["USE_NINJA"] = "0" + print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True) + # 4) 立即验证 ninja 与 CPUAdam 的 JIT(若这里失败,日志会第一时间告诉你是哪台/哪 rank 环境不对) try: - import ninja - print(f"[env] ninja {getattr(ninja,'__version__','?')} @ {getattr(ninja,'__file__','?')}", flush=True) from deepspeed.ops.op_builder import CPUAdamBuilder CPUAdamBuilder().load() print("[env] CPUAdamBuilder JIT OK", flush=True) except Exception as e: - import socket - print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True) - raise - + # ninja 可执行找不到时走兜底:禁用 ninja,用 setuptools 构建(首次会慢一点,但必过) + if "Ninja is required to load C++ extensions" in str(e): + os.environ["USE_NINJA"] = "0" + print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True) + from deepspeed.ops.op_builder import CPUAdamBuilder + CPUAdamBuilder().load() + print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True) + else: + import socket + print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True) + raise # ----------------- 进程工具 -----------------