This commit is contained in:
parent
b817640e8c
commit
574b65e2c0
|
|
@ -361,7 +361,6 @@ def main():
|
||||||
logging_dir=logging_dir,
|
logging_dir=logging_dir,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
do_eval=(eval_dataset is not None),
|
do_eval=(eval_dataset is not None),
|
||||||
# evaluation_strategy / eval_strategy 通过 **ta_kwargs 传入
|
|
||||||
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
|
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
|
||||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
|
@ -376,15 +375,17 @@ def main():
|
||||||
save_total_limit=2,
|
save_total_limit=2,
|
||||||
deepspeed=args.deepspeed,
|
deepspeed=args.deepspeed,
|
||||||
dataloader_drop_last=True,
|
dataloader_drop_last=True,
|
||||||
dataloader_num_workers=2,
|
dataloader_num_workers=0,
|
||||||
|
dataloader_prefetch_factor=1,
|
||||||
report_to=([] if args.report_to == "none" else [args.report_to]),
|
report_to=([] if args.report_to == "none" else [args.report_to]),
|
||||||
bf16=args.bf16,
|
bf16=args.bf16,
|
||||||
fp16=(not args.bf16),
|
fp16=(not args.bf16),
|
||||||
gradient_checkpointing=args.gradient_checkpointing,
|
gradient_checkpointing=args.gradient_checkpointing,
|
||||||
remove_unused_columns=False, # 需要保留我们的字段
|
remove_unused_columns=False,
|
||||||
torch_compile=False,
|
torch_compile=False,
|
||||||
save_on_each_node=False,
|
save_on_each_node=False,
|
||||||
**ta_kwargs, # ← 兼容参数
|
logging_first_step=True,
|
||||||
|
**ta_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue