This commit is contained in:
hailin 2025-08-25 18:33:25 +08:00
parent 120a8e6980
commit 3a45712277
1 changed files with 19 additions and 3 deletions

View File

@ -3,6 +3,7 @@ import os
import glob
import socket
import argparse
import inspect
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
@ -217,7 +218,7 @@ def parse_args():
ap.add_argument("--eval_data_glob", type=str, default=None,
help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用")
ap.add_argument("--local_rank", type=int, default=-1,
help="for deepspeed/torchrun launcher; ignored by user code")
help="for deepspeed/torchrun launcher; ignored by user code")
return ap.parse_args()
@ -341,12 +342,26 @@ def main():
logging_dir = os.path.join(args.output_dir, "logs")
os.makedirs(logging_dir, exist_ok=True)
# ---- Compatibility shim: transformers 4.51 names the argument `eval_strategy`,
# ---- while older releases name it `evaluation_strategy`. Probe the constructor
# ---- signature and pass whichever keyword this installation accepts.
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
# Evaluate periodically only when an evaluation dataset was supplied.
strategy_value = "no" if eval_dataset is None else "steps"
# Order matters: the newer keyword wins when both are present.
for strategy_key in ("eval_strategy", "evaluation_strategy"):
    if strategy_key in sig:
        ta_kwargs[strategy_key] = strategy_value
        break
training_args = TrainingArguments(
output_dir=args.output_dir,
logging_dir=logging_dir,
do_train=True,
do_eval=(eval_dataset is not None),
evaluation_strategy=("steps" if eval_dataset is not None else "no"),
# evaluation_strategy / eval_strategy is supplied through **ta_kwargs
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
@ -368,7 +383,8 @@ def main():
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False, # 需要保留我们的字段
torch_compile=False,
save_on_each_node=False
save_on_each_node=False,
**ta_kwargs, # ← 兼容参数
)
trainer = Trainer(