commit 574b65e2c0
parent b817640e8c
@@ -361,7 +361,6 @@ def main():
         logging_dir=logging_dir,
         do_train=True,
         do_eval=(eval_dataset is not None),
-        # evaluation_strategy / eval_strategy is passed in via **ta_kwargs
         eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
         per_device_train_batch_size=args.per_device_train_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
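The removed comment refers to a ta_kwargs dict that supplies the evaluation-strategy keyword under whichever name the installed transformers release accepts (`evaluation_strategy` was renamed to `eval_strategy` in newer versions). A minimal sketch of how such a dict could be built; `build_ta_kwargs` is a hypothetical helper name, not code from this repository.

# Hypothetical sketch, not from this commit: pick the evaluation-strategy
# keyword that the installed TrainingArguments actually accepts.
import inspect
from transformers import TrainingArguments

def build_ta_kwargs(eval_dataset) -> dict:
    strategy = "steps" if eval_dataset is not None else "no"
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
    return {key: strategy}

With a helper like this, the `TrainingArguments(..., **ta_kwargs)` call shown in the diff works on either side of the rename without version checks elsewhere.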
@@ -376,15 +375,17 @@ def main():
         save_total_limit=2,
         deepspeed=args.deepspeed,
         dataloader_drop_last=True,
-        dataloader_num_workers=2,
+        dataloader_num_workers=0,
+        dataloader_prefetch_factor=1,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
         gradient_checkpointing=args.gradient_checkpointing,
-        remove_unused_columns=False,  # need to keep our own fields
+        remove_unused_columns=False,
         torch_compile=False,
         save_on_each_node=False,
-        **ta_kwargs,  # ← compatibility kwargs
+        logging_first_step=True,
+        **ta_kwargs,
     )

     trainer = Trainer(
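The deleted inline comment ("need to keep our own fields") points at why remove_unused_columns stays False: Trainer drops dataset columns that do not match the model's forward() signature unless this flag is disabled. A minimal sketch of the kind of collator that depends on such extra fields; `WeightedCollator` and the `sample_weight` column are illustrative names, not taken from this repository.

# Hypothetical sketch, not from this commit: a collator that consumes an extra
# dataset field. With the default remove_unused_columns=True, Trainer would
# strip "sample_weight" before batching because the model's forward() does not
# accept it; remove_unused_columns=False keeps it available here.
import torch

class WeightedCollator:
    def __call__(self, features):
        # Assumes features were already padded to a common length upstream.
        batch = {
            "input_ids": torch.tensor([f["input_ids"] for f in features]),
            "attention_mask": torch.tensor([f["attention_mask"] for f in features]),
            "labels": torch.tensor([f["labels"] for f in features]),
        }
        batch["sample_weight"] = torch.tensor(
            [f["sample_weight"] for f in features], dtype=torch.float
        )
        return batch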