hailin 2025-08-25 19:57:10 +08:00
parent b817640e8c
commit 574b65e2c0
1 changed file with 5 additions and 4 deletions

@@ -361,7 +361,6 @@ def main():
         logging_dir=logging_dir,
         do_train=True,
         do_eval=(eval_dataset is not None),
-        # evaluation_strategy / eval_strategy is passed in via **ta_kwargs
         eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
         per_device_train_batch_size=args.per_device_train_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -376,15 +375,17 @@ def main():
         save_total_limit=2,
         deepspeed=args.deepspeed,
         dataloader_drop_last=True,
-        dataloader_num_workers=2,
+        dataloader_num_workers=0,
+        dataloader_prefetch_factor=1,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
         gradient_checkpointing=args.gradient_checkpointing,
-        remove_unused_columns=False,  # need to keep our custom fields
+        remove_unused_columns=False,
         torch_compile=False,
         save_on_each_node=False,
-        **ta_kwargs,  # ← compatibility kwargs
+        logging_first_step=True,
+        **ta_kwargs,
     )
     trainer = Trainer(
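
The comment deleted in the first hunk explains why no evaluation-strategy argument appears literally in this call: transformers renamed evaluation_strategy to eval_strategy, so the script injects whichever name the installed release accepts via **ta_kwargs. A minimal sketch of such a shim, assuming a hypothetical build_ta_kwargs helper that is not part of the repository shown in this diff:

import inspect

from transformers import TrainingArguments

def build_ta_kwargs(has_eval_dataset: bool) -> dict:
    # Hypothetical helper (assumption, not from this repo): pick the
    # evaluation-strategy kwarg name that the installed transformers
    # release actually accepts, since the field was renamed from
    # evaluation_strategy to eval_strategy.
    params = inspect.signature(TrainingArguments.__init__).parameters
    key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"
    return {key: "steps" if has_eval_dataset else "no"}

ta_kwargs = build_ta_kwargs(has_eval_dataset=True)  # e.g. {"eval_strategy": "steps"}

One caveat on the second hunk: torch.utils.data.DataLoader only honors prefetch_factor when num_workers > 0, and recent transformers releases validate that pairing in TrainingArguments, so dataloader_prefetch_factor=1 alongside dataloader_num_workers=0 may raise a ValueError at startup depending on the installed versions.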