diff --git a/train_sft_ds.py b/train_sft_ds.py index 2f78eaa..a85fee5 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -4,8 +4,10 @@ import os os.environ.pop("PYTHONNOUSERSITE", None) os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") -if os.environ.get("RANK","0") != "0": - os.environ["WANDB_DISABLED"] = "true" +# Only rank 0 reports to wandb; `args` is not parsed yet this early, so other ranks are muted via WANDB_MODE. +if os.environ.get("RANK", "0") != "0": + print(f"[rank {os.environ.get('RANK')}] force WANDB_MODE=disabled", flush=True) + os.environ["WANDB_MODE"] = "disabled" os.environ.setdefault("WANDB_START_METHOD", "thread") os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")