This commit is contained in:
hailin 2025-09-24 19:57:13 +08:00
parent 12ec3aa1f4
commit 5a633d4b1c
1 changed files with 108 additions and 25 deletions

View File

@ -27,7 +27,7 @@ from transformers import (
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
from torch.optim import AdamW as TorchAdamW
# from torch.optim import AdamW as TorchAdamW
from transformers import EarlyStoppingCallback
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
@ -73,23 +73,23 @@ if shutil.which("ninja") is None:
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
# 4) Immediately verify the ninja + CPUAdam JIT build. If this fails, the log
#    tells us right away which host / which rank has a broken environment.
try:
    from deepspeed.ops.op_builder import CPUAdamBuilder
    CPUAdamBuilder().load()
    print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
    # Fallback when the ninja executable cannot be found: disable ninja and
    # build with setuptools instead (slower on first build, but reliable).
    if "Ninja is required to load C++ extensions" in str(e):
        os.environ["USE_NINJA"] = "0"
        print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
        from deepspeed.ops.op_builder import CPUAdamBuilder
        CPUAdamBuilder().load()
        print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
    else:
        # Unknown failure: tag the message with host + rank so the bad node
        # can be identified in multi-node logs, then re-raise.
        import socket
        print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
        raise
# # 4) 立即验证 ninja 与 CPUAdam 的 JIT若这里失败日志会第一时间告诉你是哪台/哪 rank 环境不对)
# try:
# from deepspeed.ops.op_builder import CPUAdamBuilder
# CPUAdamBuilder().load()
# print("[env] CPUAdamBuilder JIT OK", flush=True)
# except Exception as e:
# # ninja 可执行找不到时走兜底:禁用 ninja用 setuptools 构建(首次会慢一点,但必过)
# if "Ninja is required to load C++ extensions" in str(e):
# os.environ["USE_NINJA"] = "0"
# print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
# from deepspeed.ops.op_builder import CPUAdamBuilder
# CPUAdamBuilder().load()
# print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
# else:
# import socket
# print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
# raise
# ----------------- process utilities -----------------
@ -128,6 +128,30 @@ class DebugTrainer(Trainer):
# return super().training_step(model, inputs, num_items_in_batch)
def get_train_dataloader(self):
    """Return the training dataloader, made endless when ``max_steps`` bounds
    the run (a finite/streaming dataset could otherwise be exhausted before
    the step budget is reached)."""
    loader = super().get_train_dataloader()
    max_steps = getattr(self.args, "max_steps", 0)
    if not max_steps or max_steps <= 0:
        # No step budget: hand back the plain dataloader untouched.
        return loader

    # Prefer DeepSpeed's RepeatingLoader. Its import path has moved between
    # releases, so probe both known locations before giving up.
    wrapper_cls = None
    try:
        from deepspeed.utils import RepeatingLoader as wrapper_cls
    except Exception:
        try:
            from deepspeed.runtime.dataloader import RepeatingLoader as wrapper_cls
        except Exception:
            wrapper_cls = None

    if wrapper_cls is not None:
        return wrapper_cls(loader)

    # Pure-Python fallback: cycle the dataloader forever.
    def _cycle(src):
        while True:
            for item in src:
                yield item

    return _cycle(loader)
# ----------------- logging callback -----------------
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
@ -178,7 +202,7 @@ class CsvLossLogger(TrainerCallback):
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
)
from typing import List, Tuple, Iterable, Iterator, Dict
# from typing import List, Tuple, Iterable, Iterator, Dict
# ----------------- supervise assistant content only (token-id level, no offsets needed) -----------------
class QwenChatSFTDataset(IterableDataset):
@ -572,6 +596,26 @@ def main():
dbg(f"HfDeepSpeedConfig loaded from {src}")
# ---- DeepSpeed JIT 预检测:仅在启用 DS 时检查 CPUAdam更稳
if use_ds:
try:
from deepspeed.ops.op_builder import CPUAdamBuilder
try:
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
if "Ninja is required to load C++ extensions" in str(e):
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
else:
print(f"[env] CPUAdamBuilder pre-JIT failed: {e}", flush=True)
raise
except ImportError:
print("[env] DeepSpeed not installed; skip CPUAdam pre-JIT", flush=True)
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
@ -579,8 +623,8 @@ def main():
# 仅在 rank0 预初始化 W&B
is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1")
if is_rank0:
import wandb
try:
import wandb
# 避免外部遗留的 RUN_ID 强制续跑导致卡住
os.environ.pop("WANDB_RUN_ID", None)
@ -657,7 +701,15 @@ def main():
dbg(f"dist query error: {e}")
# 1) 先补 tokenizer 的 pad
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
try:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
except Exception as e:
print(f"[warn] fast tokenizer unavailable ({e}); falling back to slow tokenizer.", flush=True)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path, use_fast=False, trust_remote_code=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
@ -695,6 +747,10 @@ def main():
return (major, minor) >= (8, 0) # Ampere 及以上
use_bf16 = bool(args.bf16 and _bf16_supported())
if args.bf16 and not use_bf16:
print("[warn] bf16 not supported on this GPU; falling back to fp16/fp32.", flush=True)
dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32))
@ -783,9 +839,36 @@ def main():
assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"
def endless_examples(files, base_seed: int, buf: int = 50000):
    """Repeatedly stream and shuffle local JSONL files, yielding an endless
    example stream (one full pass over ``files`` per "epoch")."""
    epoch = 0
    while True:
        s = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
        # Offset the seed by the epoch counter so each pass is shuffled
        # differently instead of replaying the identical order.
        s = s.shuffle(buffer_size=buf, seed=base_seed + epoch)
        # Newer `datasets` versions expose set_epoch; call it when available
        # so the streaming shuffle also reseeds internally per epoch.
        if hasattr(s, "set_epoch"):
            s.set_epoch(epoch)
        for ex in s:
            yield ex
        epoch += 1
# ====== 正式训练流(不做任何手动分片,交给 Accelerate/Trainer======
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)\
.shuffle(buffer_size=50000, seed=args.seed)
# 先尝试 datasets 的无限流;没有就用我们自己的无限生成器
try:
ds_stream2 = ds_stream2.repeat() # ★ 若可用:官方无限流
ex_iter = (ex for ex in ds_stream2) # ★ 统一用 ex_iter 作为上游
except AttributeError:
ex_iter = endless_examples(files, args.seed, buf=50000) # ★ 兜底:自制无限流
# 关键:这里一定要用 ex_iter而不是重新从 ds_stream2 取一次
train_stream = QwenChatSFTDataset(ex_iter, tokenizer, seq_len=args.seq_len)
# Safety: IterableDataset 需要明确的 max_steps
if (args.max_steps is None) or (args.max_steps <= 0):
raise ValueError("Detected streaming/IterableDataset. Please set --max_steps > 0.")
# ====== 一致性探针(不分片)======
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
@ -935,7 +1018,7 @@ def main():
save_steps=args.save_steps,
save_total_limit=None,
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False,
dataloader_drop_last=True,
dataloader_num_workers=0,
label_smoothing_factor=0.0,
per_device_eval_batch_size=args.per_device_eval_batch_size,