diff --git a/train_sft_ds.py b/train_sft_ds.py index c1a4500..8ec18b9 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -634,7 +634,9 @@ def main(): save_steps=args.save_steps, save_total_limit=2, deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None), - dataloader_drop_last=False, # 关键:别丢尾,避免空 batch + # dataloader_drop_last=False, # previous setting: kept the tail batch to avoid empty batches + dataloader_drop_last=True, # drops the incomplete final batch (the earlier keep-the-tail note no longer applies) + dispatch_batches=False, dataloader_num_workers=0, dataloader_prefetch_factor=None, dataloader_pin_memory=False,