This commit is contained in:
hailin 2025-08-25 18:33:25 +08:00
parent 120a8e6980
commit 3a45712277
1 changed file with 19 additions and 3 deletions

View File

@@ -3,6 +3,7 @@ import os
import glob import glob
import socket import socket
import argparse import argparse
import inspect
from typing import Dict, List, Iterable, Iterator, Tuple, Optional from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch import torch
@@ -217,7 +218,7 @@ def parse_args():
ap.add_argument("--eval_data_glob", type=str, default=None, ap.add_argument("--eval_data_glob", type=str, default=None,
help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用") help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用")
ap.add_argument("--local_rank", type=int, default=-1, ap.add_argument("--local_rank", type=int, default=-1,
help="for deepspeed/torchrun launcher; ignored by user code") help="for deepspeed/torchrun launcher; ignored by user code")
return ap.parse_args() return ap.parse_args()
@@ -341,12 +342,26 @@ def main():
logging_dir = os.path.join(args.output_dir, "logs") logging_dir = os.path.join(args.output_dir, "logs")
os.makedirs(logging_dir, exist_ok=True) os.makedirs(logging_dir, exist_ok=True)
# ---- Compatibility: transformers >=4.51 uses `eval_strategy`; older versions use `evaluation_strategy` ----
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
if eval_dataset is not None:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "steps"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "steps"
else:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "no"
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir=args.output_dir, output_dir=args.output_dir,
logging_dir=logging_dir, logging_dir=logging_dir,
do_train=True, do_train=True,
do_eval=(eval_dataset is not None), do_eval=(eval_dataset is not None),
evaluation_strategy=("steps" if eval_dataset is not None else "no"), # evaluation_strategy / eval_strategy 通过 **ta_kwargs 传入
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None, eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
per_device_train_batch_size=args.per_device_train_batch_size, per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -368,7 +383,8 @@ def main():
gradient_checkpointing=args.gradient_checkpointing, gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False, # 需要保留我们的字段 remove_unused_columns=False, # 需要保留我们的字段
torch_compile=False, torch_compile=False,
save_on_each_node=False save_on_each_node=False,
**ta_kwargs, # ← 兼容参数
) )
trainer = Trainer( trainer = Trainer(