This commit is contained in:
parent
35f5c85446
commit
e8130d9a61
|
|
@ -11,7 +11,7 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset, Dataset
|
||||
from contextlib import nullcontext
|
||||
# from contextlib import nullcontext
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
|
|
@ -24,7 +24,7 @@ from transformers import (
|
|||
)
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.deepspeed import HfDeepSpeedConfig
|
||||
|
||||
|
||||
|
||||
# ----------------- 进程工具 -----------------
|
||||
|
|
@ -379,14 +379,10 @@ def parse_args():
|
|||
|
||||
# ----------------- 主函数 -----------------
|
||||
def main():
|
||||
|
||||
args = parse_args()
|
||||
set_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
|
||||
|
||||
|
||||
# -------- 调试打印工具(每个 rank 都打)--------
|
||||
host = socket.gethostname()
|
||||
def dbg(msg):
|
||||
print(
|
||||
|
|
@ -395,6 +391,34 @@ def main():
|
|||
flush=True
|
||||
)
|
||||
|
||||
# 是否真的启用 DeepSpeed(传了配置文件且文件存在)
|
||||
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
|
||||
|
||||
dschf = None
|
||||
if use_ds:
|
||||
try:
|
||||
from transformers.integrations.deepspeed import HfDeepSpeedConfig
|
||||
src = "transformers.integrations.deepspeed"
|
||||
except Exception:
|
||||
try:
|
||||
# 备用:部分版本直接从 transformers 暴露
|
||||
from transformers import HfDeepSpeedConfig
|
||||
src = "transformers"
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"当前 transformers 版本未提供 HfDeepSpeedConfig,请升级/降级 transformers") from e
|
||||
dschf = HfDeepSpeedConfig(args.deepspeed)
|
||||
dbg(f"HfDeepSpeedConfig loaded from {src}")
|
||||
|
||||
|
||||
if args.report_to == "wandb":
|
||||
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
|
||||
|
||||
|
||||
# -------- 调试打印工具(每个 rank 都打)--------
|
||||
# host = socket.gethostname()
|
||||
|
||||
|
||||
# 版本 & 启动参数 & 关键环境变量
|
||||
import transformers as hf
|
||||
try:
|
||||
|
|
@ -485,10 +509,10 @@ def main():
|
|||
dtype = (torch.bfloat16 if use_bf16 else
|
||||
(torch.float16 if torch.cuda.is_available() else torch.float32))
|
||||
|
||||
dschf = None
|
||||
if args.deepspeed and os.path.isfile(args.deepspeed):
|
||||
dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件
|
||||
dbg("HfDeepSpeedConfig loaded")
|
||||
# dschf = None
|
||||
# if args.deepspeed and os.path.isfile(args.deepspeed):
|
||||
# dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件
|
||||
# dbg("HfDeepSpeedConfig loaded")
|
||||
|
||||
|
||||
# try:
|
||||
|
|
@ -554,7 +578,7 @@ def main():
|
|||
pass
|
||||
|
||||
# ===== 数据鲁棒性检查(多机各自执行)=====
|
||||
host = socket.gethostname()
|
||||
# host = socket.gethostname()
|
||||
|
||||
files = sorted(glob.glob(args.data_glob))
|
||||
if len(files) == 0:
|
||||
|
|
@ -756,7 +780,8 @@ def main():
|
|||
logging_steps=args.log_interval,
|
||||
save_steps=args.save_steps,
|
||||
save_total_limit=2,
|
||||
deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
|
||||
# deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
|
||||
deepspeed=(args.deepspeed if use_ds else None),
|
||||
dataloader_drop_last=False,
|
||||
dataloader_num_workers=0,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
|
|
|
|||
Loading…
Reference in New Issue