From 3a45712277fe1a5cc31c2ebdfdcd18c6bf021774 Mon Sep 17 00:00:00 2001
From: hailin
Date: Mon, 25 Aug 2025 18:33:25 +0800
Subject: [PATCH] .

---
 train_sft_ds.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/train_sft_ds.py b/train_sft_ds.py
index 10024fa..7147ea3 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -3,6 +3,7 @@ import os
 import glob
 import socket
 import argparse
+import inspect
 from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 
 import torch
@@ -217,7 +218,7 @@ def parse_args():
     ap.add_argument("--eval_data_glob", type=str, default=None,
                     help="(optional) glob for test/validation jsonl files; takes precedence if provided")
     ap.add_argument("--local_rank", type=int, default=-1,
-                    help="for deepspeed/torchrun launcher; ignored by user code")
+                    help="for deepspeed/torchrun launcher; ignored by user code")
 
     return ap.parse_args()
@@ -341,12 +342,26 @@ def main():
     logging_dir = os.path.join(args.output_dir, "logs")
     os.makedirs(logging_dir, exist_ok=True)
 
+    # ---- Compatibility: 4.51 (eval_strategy) vs. older versions (evaluation_strategy) ----
+    ta_kwargs = {}
+    sig = inspect.signature(TrainingArguments.__init__).parameters
+    if eval_dataset is not None:
+        if "eval_strategy" in sig:
+            ta_kwargs["eval_strategy"] = "steps"
+        elif "evaluation_strategy" in sig:
+            ta_kwargs["evaluation_strategy"] = "steps"
+    else:
+        if "eval_strategy" in sig:
+            ta_kwargs["eval_strategy"] = "no"
+        elif "evaluation_strategy" in sig:
+            ta_kwargs["evaluation_strategy"] = "no"
+
     training_args = TrainingArguments(
         output_dir=args.output_dir,
         logging_dir=logging_dir,
         do_train=True,
         do_eval=(eval_dataset is not None),
-        evaluation_strategy=("steps" if eval_dataset is not None else "no"),
+        # evaluation_strategy / eval_strategy is passed in via **ta_kwargs
         eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
         per_device_train_batch_size=args.per_device_train_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -368,7 +383,8 @@
         gradient_checkpointing=args.gradient_checkpointing,
         remove_unused_columns=False,  # our custom fields must be kept
         torch_compile=False,
-        save_on_each_node=False
+        save_on_each_node=False,
+        **ta_kwargs,  # ← compatibility kwargs
     )
 
     trainer = Trainer(
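
Note: the signature probe this patch adds can be factored into a small standalone
helper for reuse across training scripts. The code below is a minimal sketch of
the same technique, not part of the patch; the helper name compat_eval_kwargs is
made up here, and it assumes a transformers version on either side of the
evaluation_strategy -> eval_strategy rename.

    import inspect

    from transformers import TrainingArguments

    def compat_eval_kwargs(has_eval_dataset: bool) -> dict:
        """Return the eval-strategy kwarg under whichever name this
        transformers version accepts (eval_strategy vs. evaluation_strategy)."""
        params = inspect.signature(TrainingArguments.__init__).parameters
        value = "steps" if has_eval_dataset else "no"
        for name in ("eval_strategy", "evaluation_strategy"):
            if name in params:
                return {name: value}
        return {}  # neither name found; fall back to the class default

    # usage: TrainingArguments(output_dir="out", **compat_eval_kwargs(True))

Probing TrainingArguments.__init__ up front keeps the script working on both old
and new transformers releases, whereas wrapping the constructor call in
try/except would only surface the mismatch at construction time.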