This commit is contained in:
hailin 2025-08-29 22:45:34 +08:00
parent 35f5c85446
commit e8130d9a61
1 changed files with 38 additions and 13 deletions

View File

@ -11,7 +11,7 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset from torch.utils.data import IterableDataset, Dataset
from contextlib import nullcontext # from contextlib import nullcontext
from datasets import load_dataset from datasets import load_dataset
@ -24,7 +24,7 @@ from transformers import (
) )
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from transformers.deepspeed import HfDeepSpeedConfig
# ----------------- 进程工具 ----------------- # ----------------- 进程工具 -----------------
@ -379,14 +379,10 @@ def parse_args():
# ----------------- 主函数 ----------------- # ----------------- 主函数 -----------------
def main(): def main():
args = parse_args() args = parse_args()
set_seed(args.seed) set_seed(args.seed)
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
# -------- 调试打印工具(每个 rank 都打)--------
host = socket.gethostname() host = socket.gethostname()
def dbg(msg): def dbg(msg):
print( print(
@ -395,6 +391,34 @@ def main():
flush=True flush=True
) )
# 是否真的启用 DeepSpeed传了配置文件且文件存在
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
dschf = None
if use_ds:
try:
from transformers.integrations.deepspeed import HfDeepSpeedConfig
src = "transformers.integrations.deepspeed"
except Exception:
try:
# 备用:部分版本直接从 transformers 暴露
from transformers import HfDeepSpeedConfig
src = "transformers"
except Exception as e:
raise RuntimeError(
"当前 transformers 版本未提供 HfDeepSpeedConfig请升级/降级 transformers") from e
dschf = HfDeepSpeedConfig(args.deepspeed)
dbg(f"HfDeepSpeedConfig loaded from {src}")
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
# -------- 调试打印工具(每个 rank 都打)--------
# host = socket.gethostname()
# 版本 & 启动参数 & 关键环境变量 # 版本 & 启动参数 & 关键环境变量
import transformers as hf import transformers as hf
try: try:
@ -485,10 +509,10 @@ def main():
dtype = (torch.bfloat16 if use_bf16 else dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32)) (torch.float16 if torch.cuda.is_available() else torch.float32))
dschf = None # dschf = None
if args.deepspeed and os.path.isfile(args.deepspeed): # if args.deepspeed and os.path.isfile(args.deepspeed):
dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件 # dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件
dbg("HfDeepSpeedConfig loaded") # dbg("HfDeepSpeedConfig loaded")
# try: # try:
@ -554,7 +578,7 @@ def main():
pass pass
# ===== 数据鲁棒性检查(多机各自执行)===== # ===== 数据鲁棒性检查(多机各自执行)=====
host = socket.gethostname() # host = socket.gethostname()
files = sorted(glob.glob(args.data_glob)) files = sorted(glob.glob(args.data_glob))
if len(files) == 0: if len(files) == 0:
@ -756,7 +780,8 @@ def main():
logging_steps=args.log_interval, logging_steps=args.log_interval,
save_steps=args.save_steps, save_steps=args.save_steps,
save_total_limit=2, save_total_limit=2,
deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None), # deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False, dataloader_drop_last=False,
dataloader_num_workers=0, dataloader_num_workers=0,
per_device_eval_batch_size=args.per_device_eval_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size,