This commit is contained in:
parent
c473527297
commit
b4c29cde48
|
|
@ -26,6 +26,29 @@ from transformers.trainer_callback import TrainerCallback
|
|||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
|
||||
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
|
||||
import os, sys, site, shutil
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
|
||||
cur = os.environ.get("PATH", "").split(":")
|
||||
new = [d for d in want if d and d not in cur] + cur
|
||||
os.environ["PATH"] = ":".join(new)
|
||||
|
||||
# 可见性打印,方便你在日志里确认 tn06 是否拿到了
|
||||
print(f"[env] PATH={os.environ['PATH']}", flush=True)
|
||||
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)
|
||||
|
||||
|
||||
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
|
||||
ld = os.environ.get("LD_LIBRARY_PATH", "")
|
||||
cuda_lib = "/usr/local/cuda-11.8/lib64"
|
||||
if cuda_lib not in ld.split(":"):
|
||||
os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
|
||||
|
||||
# 可视化确认
|
||||
import torch
|
||||
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)
|
||||
|
||||
# ==== ensure python can see user site & set torch extensions dir ====
|
||||
import os, sys, site
|
||||
|
|
@ -45,24 +68,32 @@ except Exception:
|
|||
|
||||
# 3) 统一 JIT 缓存目录(可选,但更稳;日志里你现在用的是 ~/.cache/torch_extensions)
|
||||
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
|
||||
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
|
||||
os.environ.setdefault("MAX_JOBS", "8")
|
||||
|
||||
import shutil
|
||||
if shutil.which("ninja") is None:
|
||||
os.environ["USE_NINJA"] = "0"
|
||||
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
|
||||
|
||||
# 4) 立即验证 ninja 与 CPUAdam 的 JIT(若这里失败,日志会第一时间告诉你是哪台/哪 rank 环境不对)
|
||||
try:
|
||||
import ninja
|
||||
print(f"[env] ninja {getattr(ninja,'__version__','?')} @ {getattr(ninja,'__file__','?')}", flush=True)
|
||||
from deepspeed.ops.op_builder import CPUAdamBuilder
|
||||
CPUAdamBuilder().load()
|
||||
print("[env] CPUAdamBuilder JIT OK", flush=True)
|
||||
except Exception as e:
|
||||
# ninja 可执行找不到时走兜底:禁用 ninja,用 setuptools 构建(首次会慢一点,但必过)
|
||||
if "Ninja is required to load C++ extensions" in str(e):
|
||||
os.environ["USE_NINJA"] = "0"
|
||||
print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
|
||||
from deepspeed.ops.op_builder import CPUAdamBuilder
|
||||
CPUAdamBuilder().load()
|
||||
print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
|
||||
else:
|
||||
import socket
|
||||
print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
# ----------------- 进程工具 -----------------
|
||||
def is_main_process():
|
||||
return int(os.environ.get("RANK", "0")) == 0
|
||||
|
|
|
|||
Loading…
Reference in New Issue