This commit is contained in:
parent
120a8e6980
commit
3a45712277
|
|
@ -3,6 +3,7 @@ import os
|
|||
import glob
|
||||
import socket
|
||||
import argparse
|
||||
import inspect
|
||||
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
|
||||
|
||||
import torch
|
||||
|
|
@ -341,12 +342,26 @@ def main():
|
|||
logging_dir = os.path.join(args.output_dir, "logs")
|
||||
os.makedirs(logging_dir, exist_ok=True)
|
||||
|
||||
# ---- Support both 4.51 (eval_strategy) and older versions (evaluation_strategy) of TrainingArguments ----
|
||||
ta_kwargs = {}
|
||||
sig = inspect.signature(TrainingArguments.__init__).parameters
|
||||
if eval_dataset is not None:
|
||||
if "eval_strategy" in sig:
|
||||
ta_kwargs["eval_strategy"] = "steps"
|
||||
elif "evaluation_strategy" in sig:
|
||||
ta_kwargs["evaluation_strategy"] = "steps"
|
||||
else:
|
||||
if "eval_strategy" in sig:
|
||||
ta_kwargs["eval_strategy"] = "no"
|
||||
elif "evaluation_strategy" in sig:
|
||||
ta_kwargs["evaluation_strategy"] = "no"
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=args.output_dir,
|
||||
logging_dir=logging_dir,
|
||||
do_train=True,
|
||||
do_eval=(eval_dataset is not None),
|
||||
evaluation_strategy=("steps" if eval_dataset is not None else "no"),
|
||||
# evaluation_strategy / eval_strategy is passed in via **ta_kwargs
|
||||
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
|
|
@ -368,7 +383,8 @@ def main():
|
|||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
remove_unused_columns=False, # our own fields must be kept
|
||||
torch_compile=False,
|
||||
save_on_each_node=False
|
||||
save_on_each_node=False,
|
||||
**ta_kwargs, # ← compatibility arguments
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
|
|
|
|||
Loading…
Reference in New Issue