This commit is contained in:
parent
35f5c85446
commit
e8130d9a61
|
|
@ -11,7 +11,7 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.utils.data import IterableDataset, Dataset
|
from torch.utils.data import IterableDataset, Dataset
|
||||||
from contextlib import nullcontext
|
# from contextlib import nullcontext
|
||||||
|
|
||||||
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
@ -24,7 +24,7 @@ from transformers import (
|
||||||
)
|
)
|
||||||
from transformers.trainer_callback import TrainerCallback
|
from transformers.trainer_callback import TrainerCallback
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
from transformers.deepspeed import HfDeepSpeedConfig
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------- 进程工具 -----------------
|
# ----------------- 进程工具 -----------------
|
||||||
|
|
@ -379,14 +379,10 @@ def parse_args():
|
||||||
|
|
||||||
# ----------------- 主函数 -----------------
|
# ----------------- 主函数 -----------------
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
if args.report_to == "wandb":
|
|
||||||
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
|
|
||||||
|
|
||||||
|
|
||||||
# -------- 调试打印工具(每个 rank 都打)--------
|
|
||||||
host = socket.gethostname()
|
host = socket.gethostname()
|
||||||
def dbg(msg):
|
def dbg(msg):
|
||||||
print(
|
print(
|
||||||
|
|
@ -395,6 +391,34 @@ def main():
|
||||||
flush=True
|
flush=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 是否真的启用 DeepSpeed(传了配置文件且文件存在)
|
||||||
|
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
|
||||||
|
|
||||||
|
dschf = None
|
||||||
|
if use_ds:
|
||||||
|
try:
|
||||||
|
from transformers.integrations.deepspeed import HfDeepSpeedConfig
|
||||||
|
src = "transformers.integrations.deepspeed"
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
# 备用:部分版本直接从 transformers 暴露
|
||||||
|
from transformers import HfDeepSpeedConfig
|
||||||
|
src = "transformers"
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"当前 transformers 版本未提供 HfDeepSpeedConfig,请升级/降级 transformers") from e
|
||||||
|
dschf = HfDeepSpeedConfig(args.deepspeed)
|
||||||
|
dbg(f"HfDeepSpeedConfig loaded from {src}")
|
||||||
|
|
||||||
|
|
||||||
|
if args.report_to == "wandb":
|
||||||
|
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
|
||||||
|
|
||||||
|
|
||||||
|
# -------- 调试打印工具(每个 rank 都打)--------
|
||||||
|
# host = socket.gethostname()
|
||||||
|
|
||||||
|
|
||||||
# 版本 & 启动参数 & 关键环境变量
|
# 版本 & 启动参数 & 关键环境变量
|
||||||
import transformers as hf
|
import transformers as hf
|
||||||
try:
|
try:
|
||||||
|
|
@ -485,10 +509,10 @@ def main():
|
||||||
dtype = (torch.bfloat16 if use_bf16 else
|
dtype = (torch.bfloat16 if use_bf16 else
|
||||||
(torch.float16 if torch.cuda.is_available() else torch.float32))
|
(torch.float16 if torch.cuda.is_available() else torch.float32))
|
||||||
|
|
||||||
dschf = None
|
# dschf = None
|
||||||
if args.deepspeed and os.path.isfile(args.deepspeed):
|
# if args.deepspeed and os.path.isfile(args.deepspeed):
|
||||||
dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件
|
# dschf = HfDeepSpeedConfig(args.deepspeed) # ← 关键:提前启用插件
|
||||||
dbg("HfDeepSpeedConfig loaded")
|
# dbg("HfDeepSpeedConfig loaded")
|
||||||
|
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
|
|
@ -554,7 +578,7 @@ def main():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# ===== 数据鲁棒性检查(多机各自执行)=====
|
# ===== 数据鲁棒性检查(多机各自执行)=====
|
||||||
host = socket.gethostname()
|
# host = socket.gethostname()
|
||||||
|
|
||||||
files = sorted(glob.glob(args.data_glob))
|
files = sorted(glob.glob(args.data_glob))
|
||||||
if len(files) == 0:
|
if len(files) == 0:
|
||||||
|
|
@ -756,7 +780,8 @@ def main():
|
||||||
logging_steps=args.log_interval,
|
logging_steps=args.log_interval,
|
||||||
save_steps=args.save_steps,
|
save_steps=args.save_steps,
|
||||||
save_total_limit=2,
|
save_total_limit=2,
|
||||||
deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
|
# deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
|
||||||
|
deepspeed=(args.deepspeed if use_ds else None),
|
||||||
dataloader_drop_last=False,
|
dataloader_drop_last=False,
|
||||||
dataloader_num_workers=0,
|
dataloader_num_workers=0,
|
||||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue