This commit is contained in:
parent
7706fcf842
commit
ef116d1bc8
|
|
@ -672,9 +672,15 @@ def main():
|
|||
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
|
||||
|
||||
# No shared disk: check whether each node's local output_dir already contains a checkpoint-* entry
|
||||
ckpt_exists = (os.path.isdir(args.output_dir)
|
||||
and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
|
||||
resume_flag = True if ckpt_exists else None
|
||||
# ckpt_exists = (os.path.isdir(args.output_dir)
|
||||
# and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
|
||||
# resume_flag = True if ckpt_exists else None
|
||||
|
||||
ckpt_local = 1 if (os.path.isdir(args.output_dir) and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir))) else 0
|
||||
ckpt_tensor = torch.tensor(ckpt_local, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.all_reduce(ckpt_tensor, op=dist.ReduceOp.MAX)
|
||||
resume_flag = True if ckpt_tensor.item() > 0 else None
|
||||
|
||||
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}")
|
||||
print_once("***** Starting training *****")
|
||||
|
|
|
|||
Loading…
Reference in New Issue