diff --git a/train_sft_ds.py b/train_sft_ds.py index 99824c9..141fee3 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -11,7 +11,7 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional import torch import torch.distributed as dist from torch.utils.data import IterableDataset, Dataset -from contextlib import nullcontext +# from contextlib import nullcontext from datasets import load_dataset @@ -24,7 +24,7 @@ from transformers import ( ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import get_last_checkpoint -from transformers.deepspeed import HfDeepSpeedConfig + # ----------------- 进程工具 ----------------- @@ -379,14 +379,10 @@ def parse_args(): # ----------------- 主函数 ----------------- def main(): + args = parse_args() set_seed(args.seed) - if args.report_to == "wandb": - os.environ.setdefault("WANDB_PROJECT", args.wandb_project) - - - # -------- 调试打印工具(每个 rank 都打)-------- host = socket.gethostname() def dbg(msg): print( @@ -395,6 +391,34 @@ def main(): flush=True ) + # 是否真的启用 DeepSpeed(传了配置文件且文件存在) + use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed)) + + dschf = None + if use_ds: + try: + from transformers.integrations.deepspeed import HfDeepSpeedConfig + src = "transformers.integrations.deepspeed" + except Exception: + try: + # 备用:部分版本直接从 transformers 暴露 + from transformers import HfDeepSpeedConfig + src = "transformers" + except Exception as e: + raise RuntimeError( + "当前 transformers 版本未提供 HfDeepSpeedConfig,请升级/降级 transformers") from e + dschf = HfDeepSpeedConfig(args.deepspeed) + dbg(f"HfDeepSpeedConfig loaded from {src}") + + + if args.report_to == "wandb": + os.environ.setdefault("WANDB_PROJECT", args.wandb_project) + + + # -------- 调试打印工具(每个 rank 都打)-------- + # host = socket.gethostname() + + # 版本 & 启动参数 & 关键环境变量 import transformers as hf try: @@ -485,10 +509,10 @@ def main(): dtype = (torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)) - dschf = None - if args.deepspeed and os.path.isfile(args.deepspeed): - dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件 - dbg("HfDeepSpeedConfig loaded") + # dschf = None + # if args.deepspeed and os.path.isfile(args.deepspeed): + # dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件 + # dbg("HfDeepSpeedConfig loaded") # try: @@ -554,7 +578,7 @@ def main(): pass # ===== 数据鲁棒性检查(多机各自执行)===== - host = socket.gethostname() + # host = socket.gethostname() files = sorted(glob.glob(args.data_glob)) if len(files) == 0: @@ -756,7 +780,8 @@ def main(): logging_steps=args.log_interval, save_steps=args.save_steps, save_total_limit=2, - deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None), + # deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None), + deepspeed=(args.deepspeed if use_ds else None), dataloader_drop_last=False, dataloader_num_workers=0, per_device_eval_batch_size=args.per_device_eval_batch_size,