This commit is contained in:
hailin 2025-08-29 22:45:34 +08:00
parent 35f5c85446
commit e8130d9a61
1 changed files with 38 additions and 13 deletions

View File

@ -11,7 +11,7 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from contextlib import nullcontext
# from contextlib import nullcontext
from datasets import load_dataset
@ -24,7 +24,7 @@ from transformers import (
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
from transformers.deepspeed import HfDeepSpeedConfig
# ----------------- 进程工具 -----------------
@ -379,14 +379,10 @@ def parse_args():
# ----------------- 主函数 -----------------
def main():
args = parse_args()
set_seed(args.seed)
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
# -------- 调试打印工具(每个 rank 都打)--------
host = socket.gethostname()
def dbg(msg):
print(
@ -395,6 +391,34 @@ def main():
flush=True
)
# 是否真的启用 DeepSpeed传了配置文件且文件存在
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
dschf = None
if use_ds:
try:
from transformers.integrations.deepspeed import HfDeepSpeedConfig
src = "transformers.integrations.deepspeed"
except Exception:
try:
# 备用:部分版本直接从 transformers 暴露
from transformers import HfDeepSpeedConfig
src = "transformers"
except Exception as e:
raise RuntimeError(
"当前 transformers 版本未提供 HfDeepSpeedConfig请升级/降级 transformers") from e
dschf = HfDeepSpeedConfig(args.deepspeed)
dbg(f"HfDeepSpeedConfig loaded from {src}")
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
# -------- 调试打印工具(每个 rank 都打)--------
# host = socket.gethostname()
# 版本 & 启动参数 & 关键环境变量
import transformers as hf
try:
@ -485,10 +509,10 @@ def main():
dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32))
dschf = None
if args.deepspeed and os.path.isfile(args.deepspeed):
dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件
dbg("HfDeepSpeedConfig loaded")
# dschf = None
# if args.deepspeed and os.path.isfile(args.deepspeed):
# dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件
# dbg("HfDeepSpeedConfig loaded")
# try:
@ -554,7 +578,7 @@ def main():
pass
# ===== 数据鲁棒性检查(多机各自执行)=====
host = socket.gethostname()
# host = socket.gethostname()
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
@ -756,7 +780,8 @@ def main():
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
# deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False,
dataloader_num_workers=0,
per_device_eval_batch_size=args.per_device_eval_batch_size,