diff --git a/train_sft_ds.py b/train_sft_ds.py index 4962edc..c1a4500 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -672,9 +672,15 @@ def main(): trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv"))) # 无共享盘:各节点本地 output_dir 下是否已有 checkpoint-* - ckpt_exists = (os.path.isdir(args.output_dir) - and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir))) - resume_flag = True if ckpt_exists else None + # ckpt_exists = (os.path.isdir(args.output_dir) + # and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir))) + # resume_flag = True if ckpt_exists else None + + ckpt_local = 1 if (os.path.isdir(args.output_dir) and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir))) else 0 + ckpt_tensor = torch.tensor(ckpt_local, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ckpt_tensor, op=dist.ReduceOp.MAX) + resume_flag = True if ckpt_tensor.item() > 0 else None print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}") print_once("***** Starting training *****")