diff --git a/train_sft_ds.py b/train_sft_ds.py
index a85fee5..91b3c9b 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -4,10 +4,7 @@ import os
 os.environ.pop("PYTHONNOUSERSITE", None)
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 
-# ✅ 只有 rank0 用 wandb,其它 rank 不上报
-if os.environ.get("RANK", "0") != "0" and args.report_to == "wandb":
-    print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
-    args.report_to = "none"
+
 os.environ.setdefault("WANDB_START_METHOD", "thread")
 os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
 
@@ -451,6 +448,12 @@ def parse_args():
 
 def main():
     args = parse_args()
+
+    # ✅ 只有 rank0 用 wandb,其它 rank 不上报
+    if os.environ.get("RANK", "0") != "0" and args.report_to == "wandb":
+        print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
+        args.report_to = "none"
+
     set_seed(args.seed)
 
     host = socket.gethostname()