diff --git a/train_sft_ds.py b/train_sft_ds.py index 8ec18b9..c1a4500 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -634,9 +634,7 @@ def main(): save_steps=args.save_steps, save_total_limit=2, deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None), - # dataloader_drop_last=False, # 关键:别丢尾,避免空 batch - dataloader_drop_last=True, # 关键:别丢尾,避免空 batch - dispatch_batches=False, + dataloader_drop_last=False, # 关键:别丢尾,避免空 batch dataloader_num_workers=0, dataloader_prefetch_factor=None, dataloader_pin_memory=False,