From 5a633d4b1cc7c87031444e1454325ac42434a037 Mon Sep 17 00:00:00 2001 From: hailin Date: Wed, 24 Sep 2025 19:57:13 +0800 Subject: [PATCH] . --- train_sft_ds.py | 133 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 108 insertions(+), 25 deletions(-) diff --git a/train_sft_ds.py b/train_sft_ds.py index da7fd8c..45f944d 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -27,7 +27,7 @@ from transformers import ( ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import get_last_checkpoint -from torch.optim import AdamW as TorchAdamW +# from torch.optim import AdamW as TorchAdamW from transformers import EarlyStoppingCallback # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ==== @@ -73,23 +73,23 @@ if shutil.which("ninja") is None: os.environ["USE_NINJA"] = "0" print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True) -# 4) 立即验证 ninja 与 CPUAdam 的 JIT(若这里失败,日志会第一时间告诉你是哪台/哪 rank 环境不对) -try: - from deepspeed.ops.op_builder import CPUAdamBuilder - CPUAdamBuilder().load() - print("[env] CPUAdamBuilder JIT OK", flush=True) -except Exception as e: - # ninja 可执行找不到时走兜底:禁用 ninja,用 setuptools 构建(首次会慢一点,但必过) - if "Ninja is required to load C++ extensions" in str(e): - os.environ["USE_NINJA"] = "0" - print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True) - from deepspeed.ops.op_builder import CPUAdamBuilder - CPUAdamBuilder().load() - print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True) - else: - import socket - print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True) - raise +# # 4) 立即验证 ninja 与 CPUAdam 的 JIT(若这里失败,日志会第一时间告诉你是哪台/哪 rank 环境不对) +# try: +# from deepspeed.ops.op_builder import CPUAdamBuilder +# CPUAdamBuilder().load() +# print("[env] CPUAdamBuilder JIT OK", flush=True) +# except Exception as e: +# # ninja 可执行找不到时走兜底:禁用 ninja,用 setuptools 构建(首次会慢一点,但必过) +# if "Ninja is required to load C++ extensions" in str(e): +# os.environ["USE_NINJA"] = "0" +# print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True) +# from deepspeed.ops.op_builder import CPUAdamBuilder +# CPUAdamBuilder().load() +# print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True) +# else: +# import socket +# print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True) +# raise # ----------------- 进程工具 ----------------- @@ -127,7 +127,31 @@ class DebugTrainer(Trainer): return super().training_step(model, inputs) # return super().training_step(model, inputs, num_items_in_batch) - + + def get_train_dataloader(self): + dl = super().get_train_dataloader() + if getattr(self.args, "max_steps", 0) and self.args.max_steps > 0: + RepeatingLoader = None + try: + from deepspeed.utils import RepeatingLoader + except Exception: + try: + from deepspeed.runtime.dataloader import RepeatingLoader + except Exception: + RepeatingLoader = None + + if RepeatingLoader is not None: + return RepeatingLoader(dl) + else: + # 纯 Python 兜底 + def _infinite(loader): + while True: + for batch in loader: + yield batch + return _infinite(dl) + return dl + + # ----------------- 日志回调 ----------------- class CsvLossLogger(TrainerCallback): def __init__(self, csv_path: str): @@ -178,7 +202,7 @@ class CsvLossLogger(TrainerCallback): f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n" ) -from typing import List, Tuple, Iterable, Iterator, Dict +# from typing import List, Tuple, Iterable, Iterator, Dict # ----------------- 仅监督 assistant 内容(token-id 级,不用 offsets) ----------------- class QwenChatSFTDataset(IterableDataset): @@ -572,6 +596,26 @@ def main(): dbg(f"HfDeepSpeedConfig loaded from {src}") + # ---- DeepSpeed JIT 预检测:仅在启用 DS 时检查 CPUAdam(更稳) + if use_ds: + try: + from deepspeed.ops.op_builder import CPUAdamBuilder + try: + CPUAdamBuilder().load() + print("[env] CPUAdamBuilder JIT OK", flush=True) + except Exception as e: + if "Ninja is required to load C++ extensions" in str(e): + os.environ["USE_NINJA"] = "0" + print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True) + CPUAdamBuilder().load() + print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True) + else: + print(f"[env] CPUAdamBuilder pre-JIT failed: {e}", flush=True) + raise + except ImportError: + print("[env] DeepSpeed not installed; skip CPUAdam pre-JIT", flush=True) + + if args.report_to == "wandb": os.environ.setdefault("WANDB_PROJECT", args.wandb_project) @@ -579,8 +623,8 @@ def main(): # 仅在 rank0 预初始化 W&B is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1") if is_rank0: - import wandb try: + import wandb # 避免外部遗留的 RUN_ID 强制续跑导致卡住 os.environ.pop("WANDB_RUN_ID", None) @@ -657,7 +701,15 @@ def main(): dbg(f"dist query error: {e}") # 1) 先补 tokenizer 的 pad - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) + # tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) + try: + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) + except Exception as e: + print(f"[warn] fast tokenizer unavailable ({e}); falling back to slow tokenizer.", flush=True) + tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, use_fast=False, trust_remote_code=True + ) + if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -695,6 +747,10 @@ def main(): return (major, minor) >= (8, 0) # Ampere 及以上 use_bf16 = bool(args.bf16 and _bf16_supported()) + + if args.bf16 and not use_bf16: + print("[warn] bf16 not supported on this GPU; falling back to fp16/fp32.", flush=True) + dtype = (torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)) @@ -783,9 +839,36 @@ def main(): assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100" + def endless_examples(files, base_seed: int, buf: int = 50000): + """从本地 JSONL 反复流式读取并打乱,形成无限数据流。""" + epoch = 0 + while True: + s = load_dataset("json", data_files={"train": files}, split="train", streaming=True) + s = s.shuffle(buffer_size=buf, seed=base_seed + epoch) + if hasattr(s, "set_epoch"): + s.set_epoch(epoch) + for ex in s: + yield ex + epoch += 1 + + # ====== 正式训练流(不做任何手动分片,交给 Accelerate/Trainer)====== - ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed) - train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) + ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)\ + .shuffle(buffer_size=50000, seed=args.seed) + + # 先尝试 datasets 的无限流;没有就用我们自己的无限生成器 + try: + ds_stream2 = ds_stream2.repeat() # ★ 若可用:官方无限流 + ex_iter = (ex for ex in ds_stream2) # ★ 统一用 ex_iter 作为上游 + except AttributeError: + ex_iter = endless_examples(files, args.seed, buf=50000) # ★ 兜底:自制无限流 + + # 关键:这里一定要用 ex_iter,而不是重新从 ds_stream2 取一次 + train_stream = QwenChatSFTDataset(ex_iter, tokenizer, seq_len=args.seq_len) + + # Safety: IterableDataset 需要明确的 max_steps + if (args.max_steps is None) or (args.max_steps <= 0): + raise ValueError("Detected streaming/IterableDataset. Please set --max_steps > 0.") # ====== 一致性探针(不分片)====== ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) @@ -935,7 +1018,7 @@ def main(): save_steps=args.save_steps, save_total_limit=None, deepspeed=(args.deepspeed if use_ds else None), - dataloader_drop_last=False, + dataloader_drop_last=True, dataloader_num_workers=0, label_smoothing_factor=0.0, per_device_eval_batch_size=args.per_device_eval_batch_size,