diff --git a/train_sft_ds.py b/train_sft_ds.py index 76bf3fc..2f87781 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -315,6 +315,7 @@ def parse_args(): ap.add_argument("--local_rank", type=int, default=-1, help="for deepspeed/torchrun launcher; ignored by user code") ap.add_argument("--per_device_eval_batch_size", type=int, default=1) + ap.add_argument("--deepspeed", type=str, default=None) return ap.parse_args() @@ -590,7 +591,7 @@ def main(): logging_steps=args.log_interval, save_steps=args.save_steps, save_total_limit=2, - deepspeed=args.deepspeed, + deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None), dataloader_drop_last=False, # 关键:别丢尾,避免空 batch dataloader_num_workers=0, dataloader_prefetch_factor=None,