jd_train/train_sft_ds.py

import os
os.environ.pop("PYTHONNOUSERSITE", None)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com")
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")
import glob
import socket
import argparse
import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
from torch.optim import AdamW as TorchAdamW
from transformers import EarlyStoppingCallback
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
import site, shutil
home = os.path.expanduser("~")
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
cur = os.environ.get("PATH", "").split(":")
new = [d for d in want if d and d not in cur] + cur
os.environ["PATH"] = ":".join(new)
# Visibility print so you can confirm in the logs whether tn06 picked this up
print(f"[env] PATH={os.environ['PATH']}", flush=True)
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
ld = os.environ.get("LD_LIBRARY_PATH", "")
cuda_lib = "/usr/local/cuda-11.8/lib64"
if cuda_lib not in ld.split(":"):
os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
# Visual confirmation
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)
# 1) Make sure user site packages are not masked (ninja is installed under ~/.local)
os.environ.pop("DS_BUILD_OPS", None)
os.environ.pop("DS_SKIP_CUDA_BUILD", None)
# 2) Insert the user site directory into sys.path (e.g. /home/test/.local/lib/python3.10/site-packages)
try:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.insert(0, user_site)
except Exception:
pass
# 3) Unify the JIT cache directory (optional, but more stable; the logs currently use ~/.cache/torch_extensions)
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "12")
if shutil.which("ninja") is None:
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
# 4) Verify the ninja and CPUAdam JIT immediately (if this fails, the log tells you right away which host/rank has a broken environment)
try:
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
    # Fallback when the ninja executable cannot be found: disable ninja and build with setuptools (slower the first time, but it always succeeds)
if "Ninja is required to load C++ extensions" in str(e):
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
else:
import socket
print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
raise
# ----------------- Process utilities -----------------
def is_main_process():
return int(os.environ.get("RANK", "0")) == 0
def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
class DebugTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
if not hasattr(self, "_dbg_printed"):
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
ids = inputs["input_ids"]
msk = inputs["attention_mask"]
labs = inputs["labels"]
print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
f"supervised={(labs!=-100).sum().item()}",
flush=True)
print(
f"[step0][host={host} RANK={rank}] "
f"input_ids.shape={tuple(ids.shape)} "
f"attention_mask.shape={tuple(msk.shape)} "
f"labels.shape={tuple(labs.shape)} "
f"num_items_in_batch={num_items_in_batch}",
flush=True
)
self._dbg_printed = True
try:
return super().training_step(model, inputs, num_items_in_batch=num_items_in_batch)
except TypeError:
return super().training_step(model, inputs)
# return super().training_step(model, inputs, num_items_in_batch)
# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
self.csv_path = csv_path
if is_main_process():
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
def on_train_begin(self, args, state, control, **kwargs):
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(f"[{host} rank={rank}] total_steps={tot}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
        # ---- Console print: every rank reports the current step / total steps ----
cur = int(getattr(state, "global_step", 0) or 0)
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
        # Once tot becomes available, announce the total step count again (printed only once)
if tot and not hasattr(self, "_tot_announced"):
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
self._tot_announced = True
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(
f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
flush=True
)
        # ---- Only the main process writes the CSV (avoids concurrent writes) ----
if not is_main_process():
return
with open(self.csv_path, "a", encoding="utf-8") as f:
f.write(
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
)
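        # Example CSV row produced by the write above (illustrative values only):
        #   120,1.8432,1.9e-05,3.2e+15   -> step, loss, lr, total_flos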
from typing import List, Tuple, Iterable, Iterator, Dict
# ----------------- Supervise only assistant content (token-id level, no offsets) -----------------
class QwenChatSFTDataset(IterableDataset):
"""
- 通过 chat_template 得到 token ids
- 以 special token id 定位 assistant 片段(<|im_start|>assistant\n ... <|im_end|>
- 只监督 assistant 内容本体;默认把 <think>…</think>(含标签)整体屏蔽
- 超长时保最后一个 assistant 片段完整,左侧补齐到 seq_len
"""
def __init__(self,
ex_iter: Iterable[dict],
tokenizer: AutoTokenizer,
seq_len: int = 4096,
mask_think_and_tags: bool = True):
self.ex_iter = ex_iter
self.tok = tokenizer
self.seq_len = seq_len
self.mask_think_and_tags = mask_think_and_tags
        # Token sequences for the key markers
self.id_START = self.tok.convert_tokens_to_ids("<|im_start|>")
self.id_END = self.tok.convert_tokens_to_ids("<|im_end|>")
# self.ids_ASSISTANT_NL = self.tok.encode("assistant\n", add_special_tokens=False)
        # Support both common spellings: 'assistant\n' and 'assistant'
self.ids_ASSISTANT_CANDIDATES = [
self.tok.encode("assistant\n", add_special_tokens=False),
self.tok.encode("assistant", add_special_tokens=False),
]
        # Filter out empty candidates (degenerate tokenizer configurations)
self.ids_ASSISTANT_CANDIDATES = [c for c in self.ids_ASSISTANT_CANDIDATES if len(c) > 0]
if not self.ids_ASSISTANT_CANDIDATES:
raise RuntimeError("[fatal] no valid 'assistant' role token sequence found; check chat template/tokenizer.")
self.ids_THINK_OPEN = self.tok.encode("<think>", add_special_tokens=False)
self.ids_THINK_CLOSE = self.tok.encode("</think>", add_special_tokens=False)
        # Safety net: fail fast if the model has not registered these special token ids
for name, val in {
"id_START": self.id_START, "id_END": self.id_END
}.items():
if val is None or val == self.tok.unk_token_id:
raise RuntimeError(f"[fatal] tokenizer missing special token id for {name}")
@staticmethod
def _find_subseq(hay: list, needle: list, start: int) -> int:
n = len(needle)
if n == 0: return start
for i in range(start, len(hay) - n + 1):
if hay[i:i+n] == needle:
return i
return -1
def _find_role_after_start(self, ids, j_start: int) -> Optional[Tuple[int, int]]:
"""
从 j_start 开始,尝试匹配任一 'assistant' 角色 token 序列。
返回 (pos, length);匹配失败返回 None。
"""
for cand in self.ids_ASSISTANT_CANDIDATES:
pos = self._find_subseq(ids, cand, j_start)
if pos == j_start:
return (pos, len(cand))
return None
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # Debug switches
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
seen = 0
host = socket.gethostname()
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
for ex in self.ex_iter:
msgs = ex.get("messages")
if not msgs or not isinstance(msgs, list):
continue
tools = ex.get("tools", None)
            # Let the chat template tokenize straight to ids (avoids offset pitfalls)
try:
ids = self.tok.apply_chat_template(
msgs, tools=tools, add_generation_prompt=False,
tokenize=True, return_tensors=None
)
                # Compatibility with older versions that return a dict
if isinstance(ids, dict):
ids = ids["input_ids"]
except TypeError:
                # Last-resort fallback: render to a string first, then tokenize manually
rendered: str = self.tok.apply_chat_template(
msgs, add_generation_prompt=False, tokenize=False
)
ids = self.tok(rendered, add_special_tokens=False)["input_ids"]
if not ids:
continue
            # Build the supervision mask (0/1)
mask = [0] * len(ids)
i = 0
while i < len(ids):
                # Find the next <|im_start|>
try:
a = ids.index(self.id_START, i)
except ValueError:
break
                # Must be the assistant role (handles both 'assistant\n' and 'assistant')
j = a + 1
role_match = self._find_role_after_start(ids, j)
if role_match is None:
i = a + 1
continue
_, role_len = role_match
                content_lo = j + role_len  # skip the role token sequence
                # Find the matching <|im_end|>
try:
b = ids.index(self.id_END, content_lo)
except ValueError:
                    # Unclosed segment: give up on it
i = a + 1
continue
                content_hi = b  # exclusive of END
                # First mark the whole content range as 1 (supervised)
for t in range(content_lo, content_hi):
mask[t] = 1
                # Optional: mask the whole <think>…</think> block (tags included)
if self.mask_think_and_tags:
p = content_lo
while True:
o = self._find_subseq(ids, self.ids_THINK_OPEN, p)
if o == -1 or o >= content_hi:
break
c = self._find_subseq(ids, self.ids_THINK_CLOSE, o + len(self.ids_THINK_OPEN))
if c == -1 or c > content_hi:
break
                        x_lo = o  # includes <think>
                        x_hi = c + len(self.ids_THINK_CLOSE)  # includes </think>
for t in range(x_lo, min(x_hi, content_hi)):
mask[t] = 0
p = x_hi
                # Continue with the next segment
i = b + 1
            # Skip the example if it has no supervised token at all
if not any(mask):
continue
            # ======== Truncation policy: keep the last supervised token as the endpoint ========
if len(ids) > self.seq_len:
last_on = max(idx for idx, v in enumerate(mask) if v == 1)
end = min(len(ids), last_on + 1)
start = max(0, end - self.seq_len)
ids = ids[start:end]
mask = mask[start:end]
            # ======== Left padding ========
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(ids)
if L < self.seq_len:
pad = self.seq_len - L
input_ids = [pad_id] * pad + ids
attention_mask = [0] * pad + [1] * L
labels = [-100] * pad + [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
else:
input_ids = ids
attention_mask = [1] * self.seq_len
labels = [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
            # >>> Debug printing (optional)
if dbg_on and seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(
f"[sample][host={host} RANK={rank} LRank={lrank}] "
f"toks={len(input_ids)} sup_toks={sup_tok} "
f"seq_len={self.seq_len} pad_id={pad_id}",
flush=True
)
seen += 1
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
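# Illustrative layout of one emitted sample (assumed seq_len=8 and a 5-token example,
# where only the last three tokens belong to assistant content):
#   input_ids      = [pad, pad, pad, t0,   t1,   t2, t3, t4]
#   attention_mask = [0,   0,   0,   1,    1,    1,  1,  1 ]
#   labels         = [-100,-100,-100,-100, -100, t2, t3, t4]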
# ----------------- Collator (consistent with upstream): pad -> label=-100, attn=0 -----------------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
self.tok = tokenizer
self.pad_to_length = pad_to_length
assert self.tok.pad_token_id is not None, "tokenizer.pad_token_id must be set"
def __call__(self, features):
if not features:
raise RuntimeError("Empty batch passed to collator")
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
attn_masks = [_to_list(f["attention_mask"]) for f in features]
labels_list = [_to_list(f["labels"]) for f in features]
max_len_in_batch = max(len(x) for x in input_ids)
target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
pad_id = self.tok.pad_token_id
batch_inp, batch_attn, batch_lab = [], [], []
for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
pad_len = target_len - len(inp)
if pad_len < 0:
inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
pad_len = 0
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
"labels": torch.stack(batch_lab, dim=0),
}
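# Sketch of the collator on two already-built samples of hypothetical lengths 5 and 3 with
# pad_to_length=None: the batch is padded to 5; the shorter sample gets pad_id on the right,
# attention_mask 0 and labels -100 at the padded positions, so no loss is computed there.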
# ----------------- Arguments -----------------
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name_or_path", type=str, required=True,
help="本地权重目录或 HF 名称(如 /home/test/Qwen3-8B")
ap.add_argument("--data_glob", type=str, required=True,
help="本地 jsonl 通配符(每台机器都需有同路径数据;每行应含 messages/可选 tools")
ap.add_argument("--output_dir", type=str, required=True,
help="本地输出目录(各节点各自本地写)")
ap.add_argument("--seq_len", type=int, default=4096)
ap.add_argument("--learning_rate", type=float, default=2e-5)
ap.add_argument("--weight_decay", type=float, default=0.1)
ap.add_argument("--warmup_ratio", type=float, default=0.02)
ap.add_argument("--num_train_epochs", type=float, default=1.0)
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true",
help="3090/A100/H100 等可开 bf16同时在 DS 配置里也要开")
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
ap.add_argument("--report_to", type=str, default="tensorboard",
choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
ap.add_argument("--eval_data_glob", type=str, default=None,
help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用")
ap.add_argument("--local_rank", type=int, default=-1,
help="for deepspeed/torchrun launcher; ignored by user code")
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
ap.add_argument("--deepspeed", type=str, default=None)
ap.add_argument("--eval_steps", type=int, default=10,
help="Evaluate every N optimizer steps when eval_dataset is provided")
ap.add_argument("--load_best_model_at_end", action="store_true",
help="训练结束时自动加载最优 checkpoint")
ap.add_argument("--metric_for_best_model", type=str, default="eval_loss",
help="用哪个指标选最优,默认 eval_loss")
ap.add_argument("--greater_is_better", action="store_true",
help="是否指标越大越好eval_loss 用 False默认不传即可")
ap.add_argument("--early_stopping_patience", type=int, default=0,
help=">0 启用早停;单位是 eval 轮次数(非 step 数)")
ap.add_argument("--early_stopping_threshold", type=float, default=0.0,
help="改进阈值0 表示严格变好才算改进")
return ap.parse_args()
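# Illustrative launch (hypothetical paths; the repo's run_ds.sh is expected to wrap something similar):
#   deepspeed --num_gpus 8 train_sft_ds.py \
#     --model_name_or_path /home/test/Qwen3-8B \
#     --data_glob "/data/sft/*.jsonl" \
#     --output_dir ./out_sft \
#     --deepspeed ds_config.json --bf16 --gradient_checkpointing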
# ----------------- Main -----------------
def main():
args = parse_args()
    # Only rank 0 reports to wandb; the other ranks do not report
if os.environ.get("RANK", "0") != "0" and args.report_to == "wandb":
print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
args.report_to = "none"
set_seed(args.seed)
host = socket.gethostname()
# ==== colored dbg (robust default to info) ====
try:
import colorama
colorama.just_fix_windows_console()
except Exception:
pass
def _use_color() -> bool:
if os.environ.get("NO_COLOR"): return False
if os.environ.get("FORCE_COLOR"): return True
return sys.stdout.isatty()
class _C:
reset = "\033[0m"
gray = "\033[90m"
green = "\033[32m"
yellow = "\033[33m"
red = "\033[31m"
cyan = "\033[36m"
def _paint(s, color):
return f"{color}{s}{_C.reset}" if _use_color() else s
_LEVEL_ALIAS = {
"": "info", None: "info",
"ok": "ok", "success": "ok", "pass": "ok",
"warn": "warn", "warning": "warn",
"err": "err", "error": "err", "fatal": "err", "fail": "err",
"info": "info", "information": "info"
}
_LEVEL_COLOR = {
"ok": _C.green,
"warn": _C.yellow,
"err": _C.red,
"info": _C.cyan,
}
def _norm_level(level) -> str:
        # Default to info
if level is None:
return "info"
        # Numeric levels for compatibility (logging-style)
if isinstance(level, (int, float)):
if level >= 40: return "err"
if level >= 30: return "warn"
return "info"
        # String aliases
if isinstance(level, str):
key = level.strip().lower()
return _LEVEL_ALIAS.get(key, "info")
return "info"
def dbg(msg, level=None):
        lvl = _norm_level(level)  # unspecified/invalid -> "info"
host = socket.gethostname()
rank = os.environ.get("RANK", "0")
lrank = os.environ.get("LOCAL_RANK", "-1")
prefix = f"[dbg][host={host} RANK={rank} LOCAL_RANK={lrank}] "
color = _LEVEL_COLOR.get(lvl, _C.cyan)
print(_paint(prefix, _C.gray) + _paint(str(msg), color), flush=True)
    # Convenience aliases (optional)
def dbg_ok(m): dbg(m, "ok")
def dbg_warn(m): dbg(m, "warn")
def dbg_err(m):
s = _paint(f"[dbg]{m}", _C.red)
print(s, flush=True, file=sys.stderr)
# def dbg(msg):
# print(
# f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
# f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
# flush=True
# )
    # Whether DeepSpeed is actually enabled (a config file was passed and it exists)
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
dschf = None
if use_ds:
try:
from transformers.integrations.deepspeed import HfDeepSpeedConfig
src = "transformers.integrations.deepspeed"
except Exception:
try:
                # Fallback: some versions expose it directly from transformers
from transformers import HfDeepSpeedConfig
src = "transformers"
except Exception as e:
                raise RuntimeError(
                    "This transformers version does not provide HfDeepSpeedConfig; "
                    "please upgrade/downgrade transformers") from e
dschf = HfDeepSpeedConfig(args.deepspeed)
dbg(f"HfDeepSpeedConfig loaded from {src}")
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
        # Pre-initialize W&B on rank 0 only
is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1")
if is_rank0:
import wandb
try:
                # Avoid hangs caused by a leftover external WANDB_RUN_ID forcing a resume
os.environ.pop("WANDB_RUN_ID", None)
                # Optional fields injected from the environment (used when present)
extra = {}
if os.getenv("WANDB_NAME"): extra["name"] = os.getenv("WANDB_NAME")
if os.getenv("WANDB_GROUP"): extra["group"] = os.getenv("WANDB_GROUP")
if os.getenv("WANDB_RESUME"): extra["resume"] = os.getenv("WANDB_RESUME") # 建议 'allow'
run = wandb.init(
project=args.wandb_project,
entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin",
settings=wandb.Settings(
base_url=os.getenv("WANDB_BASE_URL", "https://wandb.szaiai.com"),
init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT", "300")),
),
**extra,
)
print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True)
except Exception as e:
print(f"[wandb] init failed -> disable logging, reason={e}", flush=True)
os.environ["WANDB_DISABLED"] = "true"
args.report_to = "none"
else:
os.environ["WANDB_DISABLED"] = "true"
    # Versions, launch arguments, and key environment variables
import transformers as hf
try:
import deepspeed as ds
ds_ver = ds.__version__
except Exception:
ds_ver = "n/a"
dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
dbg(f"args={args}")
dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
os.environ.get("WORLD_SIZE"),
os.environ.get("RANK"),
os.environ.get("LOCAL_RANK", str(args.local_rank)),
os.environ.get("MASTER_ADDR"),
os.environ.get("MASTER_PORT"),
os.environ.get("CUDA_VISIBLE_DEVICES"),
))
dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
    # ---- Initialize distributed (used by the consistency probes) ----
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
if torch.cuda.is_available() and local_rank >= 0:
torch.cuda.set_device(local_rank)
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
dbg("no cuda or invalid local_rank; not calling set_device")
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dbg(f"init_process_group backend={backend} via env://")
dist.init_process_group(backend=backend, init_method="env://")
else:
dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
if dist.is_available() and dist.is_initialized():
try:
dbg(f"dist.get_backend()={dist.get_backend()} "
f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
except Exception as e:
dbg(f"dist query error: {e}")
    # 1) First fix up the tokenizer's pad token
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
    # Pad on the left to match the Dataset's left-padding strategy
try:
if getattr(tokenizer, "padding_side", None) != "left":
tokenizer.padding_side = "left"
except Exception:
pass
    # (Old approach) Force a fast tokenizer, since offset_mapping requires one:
# from transformers import PreTrainedTokenizerFast
# if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
# raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。")
    # A fast tokenizer is recommended (faster); offset_mapping is no longer needed
if not getattr(tokenizer, "is_fast", False):
print("[warn] using a slow tokenizer; masks are token-id based and still correct, just slower.", flush=True)
tokenizer.model_max_length = args.seq_len
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
    # 2) Before loading the model, determine the dtype
def _bf16_supported():
if not torch.cuda.is_available():
return False
        # Compatibility across torch versions: prefer the API, fall back to compute capability
if hasattr(torch.cuda, "is_bf16_supported"):
return torch.cuda.is_bf16_supported()
major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 0)  # Ampere and above
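    # E.g. an A100 (capability 8.0) or RTX 3090 (8.6) passes the check above, while a
    # V100 (7.0) does not, so the dtype below falls back to fp16 on such GPUs.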
use_bf16 = bool(args.bf16 and _bf16_supported())
dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32))
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
except Exception:
pass
    # Hand ZeRO-Init / sharded loading over to the plugin
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
attn_implementation="sdpa",
)
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
# dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
# f"use_cache={getattr(model.config,'use_cache',None)} "
# f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
    # 3) pad/alibi and related config
model.config.pad_token_id = tokenizer.pad_token_id
if getattr(model, "generation_config", None) is not None:
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
try:
        # Let PyTorch pick, or explicitly enable the efficient implementations (either works):
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
except Exception:
pass
    # Print the corrected values here
dbg(f"post-config: use_cache={model.config.use_cache} "
f"model.pad_token_id={model.config.pad_token_id} "
f"gen.pad_token_id={getattr(getattr(model,'generation_config',None),'pad_token_id',None)} "
f"tok.pad={tokenizer.pad_token}/{tokenizer.pad_token_id}")
assert tokenizer.pad_token_id is not None, "tokenizer.pad_token_id must not be None"
assert model.config.pad_token_id == tokenizer.pad_token_id, \
f"model.pad_token_id {model.config.pad_token_id} != tokenizer.pad_token_id {tokenizer.pad_token_id}"
    # ===== Data robustness checks (each machine runs them independently) =====
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
raise FileNotFoundError(
f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
"每台机器都必须在相同本地路径下放置数据;"
"可通过 DATA_GLOB=<your_glob> ./run_ds.sh 覆写。"
)
if is_main_process():
print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
    # ====== Small probe: sample structure ======
ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_probe():
for ex in ds_stream_probe:
yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
try:
sample = next(iter(train_stream_probe))
except StopIteration:
raise RuntimeError(
f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n"
"请确认每行 JSON 至少包含 'messages'(列表,含 user/assistant字段"
"若含 <think>…</think> 请确保不包含真实思维文本,或移除。\n"
"另外检查 seq_len 是否过小导致全部被裁。"
)
    # A more reliable self-check (replaces the previous two asserts)
ids, attn, labs = sample["input_ids"], sample["attention_mask"], sample["labels"]
assert (labs != -100).any(), "[fatal] no supervised tokens in first valid sample"
    # Padded positions must be excluded from supervision
assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"
    # ====== Actual training stream (no manual sharding; left to Accelerate/Trainer) ======
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
    # ====== Consistency probe (unsharded) ======
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
def has_at_least(stream, n: int):
it = iter(stream)
for _ in range(n):
try:
next(it)
except StopIteration:
return 0
return 1
need = max(1, args.gradient_accumulation_steps)
local_ok = has_at_least(probe_stream, need)
if dist.is_available() and dist.is_initialized():
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
print(
f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
flush=True
)
dist.barrier()
sys.exit(2)
else:
if local_ok == 0:
print(
f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
flush=True
)
sys.exit(2)
    # ---- Eval construction: prefer --eval_data_glob; otherwise sample via eval_ratio ----
eval_dataset: Optional[Dataset] = None
class ListDataset(Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
if args.eval_data_glob:
eval_files = sorted(glob.glob(args.eval_data_glob))
if len(eval_files) == 0:
raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
if is_main_process():
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
def ex_iter_eval():
for ex in ds_eval_stream:
yield ex
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_items: List[Dict[str, torch.Tensor]] = []
for sample in eval_iterable:
eval_items.append(sample)
if len(eval_items) == 0:
raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
eval_dataset = ListDataset(eval_items)
elif args.eval_ratio and args.eval_ratio > 0:
desired_eval_batches = 200
tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_eval2():
for ex in tmp_stream:
yield ex
eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
eval_samples = []
it = iter(eval_stream)
for _ in range(desired_eval_batches):
try:
eval_samples.append(next(it))
except StopIteration:
break
if len(eval_samples) > 0:
eval_dataset = ListDataset(eval_samples)
    # ---- Pad the eval set uniformly (ensures no empty batch ever appears) ----
if eval_dataset is not None:
ws = max(world_size, 1)
be = max(1, args.per_device_eval_batch_size)
global_bs = ws * be
r = len(eval_dataset) % global_bs
if r != 0:
pad_need = global_bs - r
eval_dataset.items += eval_dataset.items[:pad_need]
if is_main_process():
print(f"[eval] padded eval set to {len(eval_dataset)} "
f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
flush=True)
        # Sanity check again after padding
assert len(eval_dataset) % global_bs == 0, \
f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
    # Safer: do not force padding to 4096 while still debugging the pipeline
data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
os.makedirs(logging_dir, exist_ok=True)
    # ---- Compatibility: 4.51+ uses eval_strategy, older versions use evaluation_strategy ----
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
if eval_dataset is not None:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "steps"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "steps"
else:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "no"
ta_sig = inspect.signature(TrainingArguments.__init__).parameters
ta_kwargs2 = dict(
output_dir=args.output_dir,
logging_dir=logging_dir,
run_name=f"sft-{os.path.basename(args.output_dir)}-{socket.gethostname()}",
do_train=True,
do_eval=(eval_dataset is not None),
# eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
        # Use the user-specified eval_steps; None when there is no eval set
eval_steps=(args.eval_steps if eval_dataset is not None else None),
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=None,
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False,
dataloader_num_workers=0,
label_smoothing_factor=0.0,
per_device_eval_batch_size=args.per_device_eval_batch_size,
report_to=([] if args.report_to == "none" else [args.report_to]),
#bf16=args.bf16,
#fp16=(not args.bf16),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
save_on_each_node=True,
logging_first_step=True,
        **ta_kwargs,  # the eval_strategy compatibility entries built above
)
if "dataloader_pin_memory" in ta_sig:
ta_kwargs2["dataloader_pin_memory"] = False
if "torch_compile" in ta_sig:
ta_kwargs2["torch_compile"] = False
    # Before building TrainingArguments, reuse the use_bf16 decision from above
ta_kwargs2.update({
"bf16": use_bf16,
"fp16": (torch.cuda.is_available() and not use_bf16),
})
ta_sig = inspect.signature(TrainingArguments.__init__).parameters
if "save_strategy" in ta_sig:
ta_kwargs2["save_strategy"] = "steps"
ta_kwargs2.update({
"load_best_model_at_end": args.load_best_model_at_end,
"metric_for_best_model": args.metric_for_best_model,
"greater_is_better": args.greater_is_better, # 对 eval_loss 保持 False默认
# "save_strategy": "steps", # 与 eval_steps 对齐
})
if args.early_stopping_patience > 0 and eval_dataset is None:
print("[warn] early_stopping_patience>0 但未提供 eval 数据集;早停将不会触发。", flush=True)
training_args = TrainingArguments(**ta_kwargs2)
trainer_kwargs = {}
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
trainer_kwargs["processing_class"] = tokenizer
else:
trainer_kwargs["tokenizer"] = tokenizer
trainer = DebugTrainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
data_collator=data_collator,
**trainer_kwargs,
)
if args.early_stopping_patience and args.early_stopping_patience > 0:
trainer.add_callback(EarlyStoppingCallback(
early_stopping_patience=args.early_stopping_patience,
early_stopping_threshold=args.early_stopping_threshold
))
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
    # ==== Resume-from-checkpoint decision (safe for non-shared storage) ====
def last_step(path: str) -> int:
ck = get_last_checkpoint(path)
if ck is None:
return -1
base = os.path.basename(ck)
try:
return int(base.split("-")[-1])
except Exception:
return -1
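    # E.g. if output_dir contains checkpoint-500 and checkpoint-1000, last_step() returns 1000;
    # it returns -1 when no checkpoint-* directory exists locally.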
    local_last = last_step(args.output_dir)  # -1 means this machine has no checkpoint at all
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
resume_flag = None
if dist.is_available() and dist.is_initialized():
        # If any rank lacks a ckpt -> do not resume
has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
if has_local.item() == 1:
            # Everyone has one: gather each rank's last step and take the common minimum k (guaranteed to exist on every machine)
ts = torch.tensor(local_last, device=device)
world = dist.get_world_size()
buf = [torch.zeros_like(ts) for _ in range(world)]
dist.all_gather(buf, ts)
steps = [b.item() for b in buf]
k = min(steps)
if k >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
if is_main_process():
print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
else:
        # Single machine or distributed not initialized: resume from the local last step if one exists
if local_last >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
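    # Worked example (illustrative): if the gathered last steps are [1500, 1500, 1000],
    # k=1000 and every rank resumes from checkpoint-1000, the newest checkpoint that is
    # guaranteed to exist everywhere; if any rank has no checkpoint, training starts fresh.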
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
    # Global consistency check: if any rank is missing this ckpt, disable resume
if dist.is_available() and dist.is_initialized():
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
dist.all_reduce(present, op=dist.ReduceOp.MIN)
if present.item() == 0:
if is_main_process():
print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
resume_flag = None
dist.barrier()
else:
if resume_flag is not None and not os.path.isdir(resume_flag):
            # Single machine: if it is missing, just disable resume
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
resume_flag = None
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
print_once("***** Starting training *****")
dbg(f"allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB")
train_result = trainer.train(resume_from_checkpoint=resume_flag)
    trainer.save_model()  # with DeepSpeed stage3_gather_16bit_weights_on_model_save=true, the full model is gathered on rank 0
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
if eval_dataset is not None:
print_once("***** Running eval *****")
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
print_once("Done.")
if __name__ == "__main__":
main()