parent 120a8e6980
commit 3a45712277
@@ -3,6 +3,7 @@ import os
 import glob
 import socket
 import argparse
+import inspect
 from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 
 import torch
@@ -341,12 +342,26 @@ def main():
     logging_dir = os.path.join(args.output_dir, "logs")
     os.makedirs(logging_dir, exist_ok=True)
 
+    # ---- Compatibility: 4.51 (eval_strategy) vs. older releases (evaluation_strategy) ----
+    ta_kwargs = {}
+    sig = inspect.signature(TrainingArguments.__init__).parameters
+    if eval_dataset is not None:
+        if "eval_strategy" in sig:
+            ta_kwargs["eval_strategy"] = "steps"
+        elif "evaluation_strategy" in sig:
+            ta_kwargs["evaluation_strategy"] = "steps"
+    else:
+        if "eval_strategy" in sig:
+            ta_kwargs["eval_strategy"] = "no"
+        elif "evaluation_strategy" in sig:
+            ta_kwargs["evaluation_strategy"] = "no"
+
     training_args = TrainingArguments(
         output_dir=args.output_dir,
         logging_dir=logging_dir,
         do_train=True,
         do_eval=(eval_dataset is not None),
-        evaluation_strategy=("steps" if eval_dataset is not None else "no"),
+        # evaluation_strategy / eval_strategy is passed in via **ta_kwargs
         eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
         per_device_train_batch_size=args.per_device_train_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -368,7 +383,8 @@ def main():
         gradient_checkpointing=args.gradient_checkpointing,
         remove_unused_columns=False,  # we need to keep our custom columns
         torch_compile=False,
-        save_on_each_node=False
+        save_on_each_node=False,
+        **ta_kwargs,  # ← compatibility kwargs
     )
 
     trainer = Trainer(
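
For reference, the shim this commit adds can be lifted into a standalone helper. The following is a minimal sketch of the same signature-inspection pattern, assuming only that transformers is installed; the helper name eval_strategy_kwargs is hypothetical and not part of this commit:

    # Sketch of the commit's version shim as a reusable helper.
    # Hypothetical helper name; not part of this repository.
    import inspect

    from transformers import TrainingArguments


    def eval_strategy_kwargs(do_eval: bool) -> dict:
        """Return the evaluation-strategy kwarg under its current name.

        Newer transformers releases (e.g. 4.51, per the commit comment)
        accept `eval_strategy`, while older ones use `evaluation_strategy`;
        inspecting the TrainingArguments signature picks whichever exists.
        """
        params = inspect.signature(TrainingArguments.__init__).parameters
        value = "steps" if do_eval else "no"
        if "eval_strategy" in params:
            return {"eval_strategy": value}
        if "evaluation_strategy" in params:
            return {"evaluation_strategy": value}
        return {}  # neither name present: leave the library default


    # Usage mirrors the diff: splice the dict into TrainingArguments.
    training_args = TrainingArguments(output_dir="out", **eval_strategy_kwargs(do_eval=True))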