This commit is contained in:
parent 12ec3aa1f4 · commit 5a633d4b1c

train_sft_ds.py (133 changed lines)
@@ -27,7 +27,7 @@ from transformers import (
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
-from torch.optim import AdamW as TorchAdamW
+# from torch.optim import AdamW as TorchAdamW
 from transformers import EarlyStoppingCallback

 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
@@ -73,23 +73,23 @@ if shutil.which("ninja") is None:
     os.environ["USE_NINJA"] = "0"
     print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)

-# 4) Verify the ninja/CPUAdam JIT immediately (if it fails here, the log tells you right away which host/rank has a broken environment)
-try:
-    from deepspeed.ops.op_builder import CPUAdamBuilder
-    CPUAdamBuilder().load()
-    print("[env] CPUAdamBuilder JIT OK", flush=True)
-except Exception as e:
-    # Fallback when the ninja executable is missing: disable ninja and build via setuptools (slower the first time, but guaranteed to pass)
-    if "Ninja is required to load C++ extensions" in str(e):
-        os.environ["USE_NINJA"] = "0"
-        print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
-        from deepspeed.ops.op_builder import CPUAdamBuilder
-        CPUAdamBuilder().load()
-        print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
-    else:
-        import socket
-        print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
-        raise
+# # 4) Verify the ninja/CPUAdam JIT immediately (if it fails here, the log tells you right away which host/rank has a broken environment)
+# try:
+#     from deepspeed.ops.op_builder import CPUAdamBuilder
+#     CPUAdamBuilder().load()
+#     print("[env] CPUAdamBuilder JIT OK", flush=True)
+# except Exception as e:
+#     # Fallback when the ninja executable is missing: disable ninja and build via setuptools (slower the first time, but guaranteed to pass)
+#     if "Ninja is required to load C++ extensions" in str(e):
+#         os.environ["USE_NINJA"] = "0"
+#         print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
+#         from deepspeed.ops.op_builder import CPUAdamBuilder
+#         CPUAdamBuilder().load()
+#         print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
+#     else:
+#         import socket
+#         print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
+#         raise


 # ----------------- process utilities -----------------
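Note: the pattern this hunk retires at module level (and reintroduces inside main(); see the hunk at -572 below) is "try the CPUAdam JIT build, and if the ninja binary is missing, rebuild via setuptools". As a standalone helper it looks roughly like this; a minimal sketch, assuming DeepSpeed is installed and that the quoted error string matches the one your DeepSpeed version emits:

import os

def prejit_cpu_adam() -> None:
    # Build DeepSpeed's CPUAdam extension eagerly so a broken build
    # environment fails fast, before training starts.
    from deepspeed.ops.op_builder import CPUAdamBuilder
    try:
        CPUAdamBuilder().load()
    except Exception as e:
        if "Ninja is required to load C++ extensions" not in str(e):
            raise
        # No ninja binary on PATH: fall back to the setuptools build
        # (slower on the first run, but it does not need ninja).
        os.environ["USE_NINJA"] = "0"
        CPUAdamBuilder().load()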
@@ -127,7 +127,31 @@ class DebugTrainer(Trainer):
         return super().training_step(model, inputs)

         # return super().training_step(model, inputs, num_items_in_batch)

+    def get_train_dataloader(self):
+        dl = super().get_train_dataloader()
+        if getattr(self.args, "max_steps", 0) and self.args.max_steps > 0:
+            RepeatingLoader = None
+            try:
+                from deepspeed.utils import RepeatingLoader
+            except Exception:
+                try:
+                    from deepspeed.runtime.dataloader import RepeatingLoader
+                except Exception:
+                    RepeatingLoader = None
+
+            if RepeatingLoader is not None:
+                return RepeatingLoader(dl)
+            else:
+                # pure-Python fallback
+                def _infinite(loader):
+                    while True:
+                        for batch in loader:
+                            yield batch
+                return _infinite(dl)
+        return dl
+

 # ----------------- logging callback -----------------
 class CsvLossLogger(TrainerCallback):
     def __init__(self, csv_path: str):
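Why this override matters: with max_steps larger than one pass over the data, an epoch-sized DataLoader would be exhausted before training finishes. The pure-Python fallback can be sanity-checked on its own; a standalone sketch with a toy loader (names and data here are illustrative, not from the script):

from itertools import islice
from torch.utils.data import DataLoader

loader = DataLoader(list(range(4)), batch_size=2)  # only 2 batches per epoch

def _infinite(loader):
    while True:
        for batch in loader:
            yield batch

# Drawing 5 batches crosses the epoch boundary without StopIteration,
# which is exactly what Trainer needs when max_steps > len(loader).
print([b.tolist() for b in islice(_infinite(loader), 5)])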
@@ -178,7 +202,7 @@ class CsvLossLogger(TrainerCallback):
                 f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
             )

-from typing import List, Tuple, Iterable, Iterator, Dict
+# from typing import List, Tuple, Iterable, Iterator, Dict

 # ----------------- supervise only the assistant content (token-id level, no offsets) -----------------
 class QwenChatSFTDataset(IterableDataset):
@@ -572,6 +596,26 @@ def main():
     dbg(f"HfDeepSpeedConfig loaded from {src}")

+    # ---- DeepSpeed JIT pre-check: only verify CPUAdam when DS is enabled (more robust)
+    if use_ds:
+        try:
+            from deepspeed.ops.op_builder import CPUAdamBuilder
+            try:
+                CPUAdamBuilder().load()
+                print("[env] CPUAdamBuilder JIT OK", flush=True)
+            except Exception as e:
+                if "Ninja is required to load C++ extensions" in str(e):
+                    os.environ["USE_NINJA"] = "0"
+                    print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
+                    CPUAdamBuilder().load()
+                    print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
+                else:
+                    print(f"[env] CPUAdamBuilder pre-JIT failed: {e}", flush=True)
+                    raise
+        except ImportError:
+            print("[env] DeepSpeed not installed; skip CPUAdam pre-JIT", flush=True)
+

     if args.report_to == "wandb":
         os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
@@ -579,8 +623,8 @@ def main():
         # pre-initialize W&B on rank 0 only
         is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1")
         if is_rank0:
-            import wandb
             try:
+                import wandb
                 # avoid hanging on a forced resume from a stale WANDB_RUN_ID left in the environment
                 os.environ.pop("WANDB_RUN_ID", None)
@@ -657,7 +701,15 @@ def main():
         dbg(f"dist query error: {e}")

     # 1) fill in the tokenizer's pad token first
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
+    # tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
+    except Exception as e:
+        print(f"[warn] fast tokenizer unavailable ({e}); falling back to slow tokenizer.", flush=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_name_or_path, use_fast=False, trust_remote_code=True
+        )

     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
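Reusing eos as pad is the usual fix for models that ship without a pad token; it is only safe because padded positions are excluded from the loss, which the assert later in this diff ("padded tokens must have label -100") enforces. A tiny self-contained illustration (the ids here are made up):

import torch

input_ids      = torch.tensor([[5, 6, 7, 0]])   # 0 stands in for the pad (= eos) id
attention_mask = torch.tensor([[1, 1, 1, 0]])
labels = input_ids.clone()
labels[attention_mask == 0] = -100              # ignored by the cross-entropy loss
assert bool((labels[attention_mask == 0] == -100).all())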
@@ -695,6 +747,10 @@ def main():
         return (major, minor) >= (8, 0)  # Ampere or newer

     use_bf16 = bool(args.bf16 and _bf16_supported())

+    if args.bf16 and not use_bf16:
+        print("[warn] bf16 not supported on this GPU; falling back to fp16/fp32.", flush=True)
+
     dtype = (torch.bfloat16 if use_bf16 else
              (torch.float16 if torch.cuda.is_available() else torch.float32))
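The body of _bf16_supported() sits outside this hunk; a plausible reconstruction based on the visible return line (an assumption, not the file's exact code) keys off the CUDA compute capability:

import torch

def _bf16_supported() -> bool:
    # bf16 tensor cores require compute capability >= 8.0 (Ampere or newer)
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 0)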
@@ -783,9 +839,36 @@ def main():
     assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"


+    def endless_examples(files, base_seed: int, buf: int = 50000):
+        """Repeatedly stream local JSONL files with shuffling, producing an infinite example stream."""
+        epoch = 0
+        while True:
+            s = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+            s = s.shuffle(buffer_size=buf, seed=base_seed + epoch)
+            if hasattr(s, "set_epoch"):
+                s.set_epoch(epoch)
+            for ex in s:
+                yield ex
+            epoch += 1
+
     # ====== main training stream (no manual sharding; leave that to Accelerate/Trainer) ======
-    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
-    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)\
+        .shuffle(buffer_size=50000, seed=args.seed)
+
+    # try datasets' built-in infinite stream first; fall back to our own infinite generator
+    try:
+        ds_stream2 = ds_stream2.repeat()          # ★ if available: the library's infinite stream
+        ex_iter = (ex for ex in ds_stream2)       # ★ always use ex_iter as the upstream
+    except AttributeError:
+        ex_iter = endless_examples(files, args.seed, buf=50000)  # ★ fallback: our own infinite stream
+
+    # crucial: consume ex_iter here, not a fresh iterator over ds_stream2
+    train_stream = QwenChatSFTDataset(ex_iter, tokenizer, seq_len=args.seq_len)
+
+    # Safety: an IterableDataset requires an explicit max_steps
+    if (args.max_steps is None) or (args.max_steps <= 0):
+        raise ValueError("Detected streaming/IterableDataset. Please set --max_steps > 0.")

     # ====== consistency probe (no sharding) ======
     ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
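The epoch-cycling trick in endless_examples (rebuild and reshuffle the stream each pass, bumping the seed) works for any re-creatable stream; a toy standalone sketch of the same pattern (the factory and data are illustrative):

import random
from itertools import islice

def endless(make_stream, base_seed: int):
    # Re-create the stream every epoch with a fresh seed, like endless_examples.
    epoch = 0
    while True:
        for ex in make_stream(seed=base_seed + epoch):
            yield ex
        epoch += 1

def make_stream(seed: int):
    data = ["a", "b", "c"]
    random.Random(seed).shuffle(data)  # different order each epoch
    return data

# Drawing 7 items spans two-plus epochs without ever stopping.
print(list(islice(endless(make_stream, 0), 7)))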
@@ -935,7 +1018,7 @@ def main():
         save_steps=args.save_steps,
         save_total_limit=None,
         deepspeed=(args.deepspeed if use_ds else None),
-        dataloader_drop_last=False,
+        dataloader_drop_last=True,
         dataloader_num_workers=0,
         label_smoothing_factor=0.0,
         per_device_eval_batch_size=args.per_device_eval_batch_size,