diff --git a/train_sft_ds.py b/train_sft_ds.py index ee0f055..5038b91 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -847,6 +847,7 @@ def main(): logging_steps=args.log_interval, save_steps=args.save_steps, save_total_limit=2, + optim="adamw_torch", # deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None), deepspeed=(args.deepspeed if use_ds else None), dataloader_drop_last=False,