#!/usr/bin/env python3
import os
# Make user-site packages visible (deepspeed/torchrun often injects PYTHONNOUSERSITE=1)
os.environ.pop("PYTHONNOUSERSITE", None)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
# Base URL of the self-hosted W&B server (avoid the default cloud endpoint)
os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com")
# (Optional) Some wandb versions honor this env var; what actually takes effect is Settings(init_timeout=...) below
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")
import glob
import socket
import argparse
import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
from torch.optim import AdamW as TorchAdamW
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
import os, sys, site, shutil
home = os.path.expanduser("~")
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
cur = os.environ.get("PATH", "").split(":")
new = [d for d in want if d and d not in cur] + cur
os.environ["PATH"] = ":".join(new)
# Visibility print so the logs confirm whether each node (e.g. tn06) picked up the PATH
print(f"[env] PATH={os.environ['PATH']}", flush=True)
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
ld = os.environ.get("LD_LIBRARY_PATH", "")
cuda_lib = "/usr/local/cuda-11.8/lib64"
if cuda_lib not in ld.split(":"):
os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
# Sanity-check print
import torch
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)
# ==== ensure python can see user site & set torch extensions dir ====
import os, sys, site
# 1) Make sure user site-packages are not blocked (ninja lives in ~/.local); PYTHONNOUSERSITE was already popped above
os.environ.pop("DS_BUILD_OPS", None)
os.environ.pop("DS_SKIP_CUDA_BUILD", None)
# 2) Insert the user site directory into sys.path (e.g. /home/test/.local/lib/python3.10/site-packages)
try:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.insert(0, user_site)
except Exception:
pass
# 3) Unify the JIT cache directory (optional, but more robust; the logs currently show ~/.cache/torch_extensions)
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "12")
import shutil
if shutil.which("ninja") is None:
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
# 4) Verify ninja and the CPUAdam JIT build right away (if this fails, the logs immediately show which host/rank has a broken environment)
try:
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
    # Fallback when the ninja executable is missing: disable ninja and build with setuptools (slower the first time, but reliable)
if "Ninja is required to load C++ extensions" in str(e):
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
else:
import socket
print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
raise
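# The pre-check above uses the same call a node could run by hand before launching, e.g. (illustrative):
#   python -c "from deepspeed.ops.op_builder import CPUAdamBuilder; CPUAdamBuilder().load()"
# If that one-liner fails on a node, the distributed run would fail there too once ZeRO-Offload
# needs the CPUAdam kernel.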
# ----------------- Process helpers -----------------
def is_main_process():
return int(os.environ.get("RANK", "0")) == 0
def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
class DebugTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
if not hasattr(self, "_dbg_printed"):
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
ids = inputs["input_ids"]
msk = inputs["attention_mask"]
labs = inputs["labels"]
print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
f"supervised={(labs!=-100).sum().item()}",
flush=True)
print(
f"[step0][host={host} RANK={rank}] "
f"input_ids.shape={tuple(ids.shape)} "
f"attention_mask.shape={tuple(msk.shape)} "
f"labels.shape={tuple(labs.shape)} "
f"num_items_in_batch={num_items_in_batch}",
flush=True
)
self._dbg_printed = True
return super().training_step(model, inputs, num_items_in_batch)
# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
self.csv_path = csv_path
if is_main_process():
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
def on_train_begin(self, args, state, control, **kwargs):
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(f"[{host} rank={rank}] total_steps={tot}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
        # ---- Console print: every rank reports current step / total steps ----
cur = int(getattr(state, "global_step", 0) or 0)
# if getattr(args, "logging_steps", None) and cur % args.logging_steps != 0:
# return
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
        # Once tot is known, announce the total step count once more (printed only once)
if tot and not hasattr(self, "_tot_announced"):
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
self._tot_announced = True
# if not is_main_process():
# return
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(
f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
flush=True
)
        # ---- Only the main process writes the CSV, avoiding concurrent writes ----
if not is_main_process():
return
with open(self.csv_path, "a", encoding="utf-8") as f:
f.write(
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
)
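# Illustrative shape of the CSV written above (values are made up):
#   step,loss,lr,total_flos
#   10,1.8423,1.9e-05,
#   20,1.7310,1.9e-05,
# total_flos usually only appears in the final train metrics, so the last column is often empty.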
# ----------------- Helper: extract assistant character spans -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
spans: List[Tuple[int, int]] = []
open_tag = "<|im_start|>assistant\n"
close_token = "<|im_end|>"
    close_tag = close_token + "\n"  # most templates append a trailing newline
pos = 0
while True:
a = rendered.find(open_tag, pos)
if a == -1:
break
start = a + len(open_tag)
        # Look for the newline variant first; fall back to the bare closing token
b = rendered.find(close_tag, start)
if b == -1:
b = rendered.find(close_token, start)
if b == -1:
break
        end = b + len(close_token)  # include the <|im_end|> token itself in supervision
spans.append((start, end))
        # Advance pos past this closing marker (one extra char if it carried the trailing newline)
pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
return spans
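# Illustrative example (assuming the standard Qwen chat-template markers used above):
#   rendered = "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\nhello<|im_end|>\n"
#   _assistant_char_spans(rendered) -> one span covering "hello<|im_end|>"
# i.e. each span includes the assistant text plus the closing <|im_end|> token, but not the
# "<|im_start|>assistant\n" header.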
# ----------------- Helper: extract the character spans of think-block interiors -----------------
def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
    """
    Return the spans to exclude from supervision (only the *interior* of <think>...</think>);
    the <think> and </think> tags themselves stay supervised so the model learns to open and
    close the block.
    """
    spans: List[Tuple[int, int]] = []
    open_tag = "<think>"
    close_tag = "</think>"
    pos = 0
    while True:
        a = rendered.find(open_tag, pos)
        if a == -1:
            break
        b = rendered.find(close_tag, a + len(open_tag))
        if b == -1:
            break
        # Ignore only the interior, not the surrounding tags
        spans.append((a + len(open_tag), b))
        pos = b + len(close_tag)
    return spans
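# Illustrative example: for 'A<think>draft</think>B' this returns the span covering only "draft",
# so the <think> / </think> tags remain supervised while the draft text inside is masked out.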
# ----------------- Dataset supervising only assistant spans (think interiors ignored) -----------------
class QwenChatSFTDataset(IterableDataset):
"""
期望 jsonl 每行形如:
{"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
可选包含工具:
{"messages":[...], "tools":[{...}]}
工作流:
- 使用 tokenizer.apply_chat_template 渲染(可带 tools)
- 仅对 assistant 片段计损失;凡落在 … 内的 token,labels 置 -100(不监督)
- 超长序列保留尾部(通常包含回答),再左侧补齐到固定长度
"""
def __init__(self,
ex_iter: Iterable[dict],
tokenizer: AutoTokenizer,
seq_len: int = 4096):
self.ex_iter = ex_iter
self.tok = tokenizer
self.seq_len = seq_len
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # >>> DEBUG switches
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
if not hasattr(self, "_dbg_seen"):
self._dbg_seen = 0
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
host = socket.gethostname()
# <<< DEBUG
for ex in self.ex_iter:
msgs = ex.get("messages", None)
if not msgs or not isinstance(msgs, list):
continue
tools = ex.get("tools", None)
            # Compatibility with older tokenizers whose apply_chat_template has no tools parameter
try:
rendered: str = self.tok.apply_chat_template(
msgs, tools=tools, add_generation_prompt=False, tokenize=False
)
except TypeError:
rendered: str = self.tok.apply_chat_template(
msgs, add_generation_prompt=False, tokenize=False
)
if not isinstance(rendered, str) or not rendered.strip():
continue
            # (Disabled) Sample-level terminator: earlier variants either appended eos to the
            # rendered text, or inserted eos just before the last assistant's "<|im_end|>\n" so
            # that every training sample ends with eos. Both are currently turned off.
            # 1) Find all assistant spans and the global think spans
asst_spans = _assistant_char_spans(rendered)
if not asst_spans:
continue
think_spans = _think_char_spans(rendered)
            # 2) Encode and get character offsets aligned with `rendered`
enc = self.tok(
rendered,
add_special_tokens=False,
return_offsets_mapping=True
)
input_ids = enc["input_ids"]
offsets = enc["offset_mapping"]
if not input_ids:
continue
            # 3) Supervise only assistant spans, excluding tokens that fall inside think spans
labels = [-100] * len(input_ids)
def in_any_span(lo: int, hi: int, intervals: List[Tuple[int, int]]) -> bool:
for s, e in intervals:
                    # Any overlap counts
if not (hi <= s or lo >= e):
return True
return False
for i, (lo, hi) in enumerate(offsets):
if in_any_span(lo, hi, asst_spans) and not in_any_span(lo, hi, think_spans):
labels[i] = input_ids[i]
            # Fixed-length policy: assistant-aware truncation first (below), then left padding.
            # (A simpler "keep the tail" truncation was used here previously; it is superseded by
            # the assistant-aware window selection that follows.)
            # ======== Assistant-aware truncation: keep the last assistant span intact whenever possible ========
if len(input_ids) > self.seq_len:
                # Character span of the last assistant segment: [s_last, e_last)
s_last, e_last = asst_spans[-1]
                # Map the character span to a token index range [j, k_excl) using offsets
j = 0
while j < len(offsets) and offsets[j][1] <= s_last:
j += 1
k_excl = j
while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
k_excl += 1
                A = max(0, k_excl - j)  # number of tokens covered by the last assistant span
if A >= self.seq_len:
                    # The assistant span alone exceeds the window: keep its tail so the ending survives
start = max(0, k_excl - self.seq_len)
end = start + self.seq_len
else:
                    # The whole assistant span fits: position the window to cover all of [j, k_excl)
start = max(0, min(j, len(input_ids) - self.seq_len))
end = start + self.seq_len
if end < k_excl:
end = k_excl
start = max(0, end - self.seq_len)
                    # Optional: shift the window toward the center so the last assistant keeps some context on both sides
leftover = self.seq_len - A
left_wish = leftover // 2
start = max(0, min(j - left_wish, start))
end = start + self.seq_len
if end < k_excl:
end = k_excl
start = max(0, end - self.seq_len)
                # Apply the slice
input_ids = input_ids[start:end]
labels = labels[start:end]
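                # Worked example of the window math above (illustrative numbers):
                #   seq_len=8, len(input_ids)=20, last assistant covers tokens [j=10, k_excl=14), so A=4.
                #   First pass: start = min(10, 20-8) = 10, end = 18 (already covers k_excl=14).
                #   Centering: leftover = 4, left_wish = 2 -> start = min(10-2, 10) = 8, end = 16.
                #   Final window [8, 16) keeps the whole assistant span with 2 tokens of context on each side.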
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(input_ids)
if L < self.seq_len:
pad = self.seq_len - L
input_ids = ([pad_id] * pad) + input_ids
labels = ([-100] * pad) + labels
attn_mask = [0] * pad + [1] * L
else:
attn_mask = [1] * self.seq_len
            # Skip samples with no trainable tokens (labels are all -100)
if all(v == -100 for v in labels):
continue
# Sanity
assert len(input_ids) == self.seq_len
assert len(labels) == self.seq_len
assert len(attn_mask) == self.seq_len
# >>> DEBUG
if dbg_on and self._dbg_seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(
f"[sample][host={host} RANK={rank} LRank={lrank}] "
f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
f"seq_len={self.seq_len} pad_id={pad_id}",
flush=True
)
if sup_tok == 0:
print(
f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped",
flush=True
)
self._dbg_seen += 1
# <<< DEBUG
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
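# Illustrative layout of one yielded sample (seq_len=8, 5 real tokens t1..t5, of which t4,t5 are
# supervised assistant tokens):
#   input_ids      = [pad, pad, pad,  t1,   t2,   t3,  t4, t5]
#   attention_mask = [  0,   0,   0,   1,    1,    1,   1,  1]
#   labels         = [-100,-100,-100,-100, -100, -100,  t4, t5]
# Left padding keeps the supervised answer at the tail of every fixed-length sequence.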
# ----------------- Dedicated collator: pad inputs with pad_id, pad labels with -100 -----------------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
self.tok = tokenizer
self.pad_to_length = pad_to_length
assert self.tok.pad_token_id is not None
def __call__(self, features):
if not features:
raise RuntimeError(
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
f"Check dataset sharding/streaming."
)
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
attn_masks = [_to_list(f["attention_mask"]) for f in features]
labels_list = [_to_list(f["labels"]) for f in features]
max_len_in_batch = max(len(x) for x in input_ids)
target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
pad_id = self.tok.pad_token_id
batch_inp, batch_attn, batch_lab = [], [], []
for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
pad_len = target_len - len(inp)
if pad_len < 0:
inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
pad_len = 0
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
# >>> DEBUG BEGIN
dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
if dbg_on:
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
bs = len(features)
first_len = len(input_ids[0]) if bs > 0 else None
print(
f"[collate][host={host} RANK={rank}] features={bs} "
f"target_len={target_len} first_len={first_len}",
flush=True
)
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
"labels": torch.stack(batch_lab, dim=0),
}
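# With pad_to_length=seq_len (as used below) every sample already arrives at the target length,
# so the collator effectively just stacks tensors; the per-sample padding path above only matters
# when pad_to_length is None and batch lengths differ. Illustrative result for a batch of 2 with
# seq_len=4096: input_ids / attention_mask / labels each of shape (2, 4096).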
# ----------------- Arguments -----------------
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name_or_path", type=str, required=True,
help="本地权重目录或 HF 名称(如 /home/test/Qwen3-8B)")
ap.add_argument("--data_glob", type=str, required=True,
help="本地 jsonl 通配符(每台机器都需有同路径数据;每行应含 messages/可选 tools)")
ap.add_argument("--output_dir", type=str, required=True,
help="本地输出目录(各节点各自本地写)")
ap.add_argument("--seq_len", type=int, default=4096)
ap.add_argument("--learning_rate", type=float, default=2e-5)
ap.add_argument("--weight_decay", type=float, default=0.1)
ap.add_argument("--warmup_ratio", type=float, default=0.02)
ap.add_argument("--num_train_epochs", type=float, default=1.0)
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
#ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true",
help="3090/A100/H100 等可开 bf16;同时在 DS 配置里也要开")
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
ap.add_argument("--report_to", type=str, default="tensorboard",
choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
ap.add_argument("--eval_data_glob", type=str, default=None,
help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用")
ap.add_argument("--local_rank", type=int, default=-1,
help="for deepspeed/torchrun launcher; ignored by user code")
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
ap.add_argument("--deepspeed", type=str, default=None)
ap.add_argument("--eval_steps", type=int, default=10,
help="Evaluate every N optimizer steps when eval_dataset is provided")
return ap.parse_args()
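# Illustrative single-node launch (the argument names are the ones defined above; the script
# name and all paths are placeholders, adjust to the actual setup):
#   torchrun --nproc_per_node=8 train_sft.py \
#     --model_name_or_path /home/test/Qwen3-8B \
#     --data_glob "/data/sft/*.jsonl" \
#     --output_dir ./out_sft \
#     --deepspeed ds_config_zero3.json \
#     --bf16 --gradient_checkpointing \
#     --per_device_train_batch_size 1 --gradient_accumulation_steps 64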
# ----------------- Main -----------------
def main():
args = parse_args()
    # Only rank 0 reports to wandb; all other ranks disable reporting
if os.environ.get("RANK", "0") != "0" and args.report_to == "wandb":
print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
args.report_to = "none"
set_seed(args.seed)
host = socket.gethostname()
def dbg(msg):
print(
f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
flush=True
)
    # Whether DeepSpeed is actually enabled (a config path was given and the file exists)
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
dschf = None
if use_ds:
try:
from transformers.integrations.deepspeed import HfDeepSpeedConfig
src = "transformers.integrations.deepspeed"
except Exception:
try:
            # Fallback: some versions expose it directly from transformers
from transformers import HfDeepSpeedConfig
src = "transformers"
except Exception as e:
                raise RuntimeError(
                    "This transformers version does not provide HfDeepSpeedConfig; "
                    "please upgrade or downgrade transformers") from e
dschf = HfDeepSpeedConfig(args.deepspeed)
dbg(f"HfDeepSpeedConfig loaded from {src}")
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
        # Pre-initialize W&B only on rank 0
is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1")
if is_rank0:
import wandb
try:
                # Avoid hangs caused by a stale WANDB_RUN_ID forcing a resume
os.environ.pop("WANDB_RUN_ID", None)
# 可选字段从环境注入(有就用)
extra = {}
if os.getenv("WANDB_NAME"): extra["name"] = os.getenv("WANDB_NAME")
if os.getenv("WANDB_GROUP"): extra["group"] = os.getenv("WANDB_GROUP")
if os.getenv("WANDB_RESUME"): extra["resume"] = os.getenv("WANDB_RESUME") # 建议 'allow'
run = wandb.init(
project=args.wandb_project,
entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin",
settings=wandb.Settings(
base_url=os.getenv("WANDB_BASE_URL", "https://wandb.szaiai.com"),
init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT", "300")),
),
**extra,
)
print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True)
except Exception as e:
print(f"[wandb] init failed -> disable logging, reason={e}", flush=True)
os.environ["WANDB_DISABLED"] = "true"
args.report_to = "none"
else:
os.environ["WANDB_DISABLED"] = "true"
    # Versions, launch arguments, and key environment variables
import transformers as hf
try:
import deepspeed as ds
ds_ver = ds.__version__
except Exception:
ds_ver = "n/a"
dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
dbg(f"args={args}")
dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
os.environ.get("WORLD_SIZE"),
os.environ.get("RANK"),
os.environ.get("LOCAL_RANK", str(args.local_rank)),
os.environ.get("MASTER_ADDR"),
os.environ.get("MASTER_PORT"),
os.environ.get("CUDA_VISIBLE_DEVICES"),
))
dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
    # ---- Initialize distributed (needed by the consistency probes below) ----
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
if torch.cuda.is_available() and local_rank >= 0:
torch.cuda.set_device(local_rank)
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
dbg("no cuda or invalid local_rank; not calling set_device")
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dbg(f"init_process_group backend={backend} via env://")
dist.init_process_group(backend=backend, init_method="env://")
else:
dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
if dist.is_available() and dist.is_initialized():
try:
dbg(f"dist.get_backend()={dist.get_backend()} "
f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
except Exception as e:
dbg(f"dist query error: {e}")
    # 1) Make sure the tokenizer has a pad token
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
    # Left padding to match the dataset's left-pad policy
try:
if getattr(tokenizer, "padding_side", None) != "left":
tokenizer.padding_side = "left"
except Exception:
pass
    # Require a fast tokenizer (offset_mapping needs the fast implementation)
from transformers import PreTrainedTokenizerFast
if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。")
tokenizer.model_max_length = args.seq_len
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
    # 2) Decide the dtype before loading the model
def _bf16_supported():
if not torch.cuda.is_available():
return False
        # Handle different torch versions: prefer the API, fall back to compute capability
if hasattr(torch.cuda, "is_bf16_supported"):
return torch.cuda.is_bf16_supported()
major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 0)  # Ampere or newer
use_bf16 = bool(args.bf16 and _bf16_supported())
dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32))
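    # Illustrative dtype outcomes: --bf16 on an A100 (capability 8.0) -> torch.bfloat16;
    # --bf16 on a 2080 Ti (capability 7.5) -> falls back to torch.float16; CPU-only -> torch.float32.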
    # Let the DeepSpeed integration handle ZeRO-Init / sharded loading
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
attn_implementation="sdpa",
)
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
f"use_cache={getattr(model.config,'use_cache',None)} "
f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
    # 3) pad token / cache configuration
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
try:
        # Explicitly enable the efficient SDP backends (flash / mem-efficient) and disable the
        # math fallback; alternatively, leave all three alone and let PyTorch choose on its own.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
except Exception:
pass
    # ===== Data robustness check (each node runs it against its own local data) =====
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
        raise FileNotFoundError(
            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
            "Every node must place the data under the same local path; "
            "it can be overridden via DATA_GLOB= ./run_ds.sh."
        )
if is_main_process():
print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
    # ====== Quick probe: sample structure ======
ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_probe():
for ex in ds_stream_probe:
yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
try:
_ = next(iter(train_stream_probe))
except StopIteration:
        raise RuntimeError(
            f"[host={host} rank={rank}] Data files matched, but no trainable samples were produced.\n"
            "Make sure every JSON line contains a 'messages' list (with user/assistant turns); "
            "if samples contain <think>...</think>, make sure no real reasoning text is left inside, or remove it.\n"
            "Also check whether seq_len is so small that everything gets truncated away."
        )
    # ====== Training stream (no manual sharding; leave it to Accelerate/Trainer) ======
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
    # ====== Consistency probe (no sharding) ======
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
def has_at_least(stream, n: int):
it = iter(stream)
for _ in range(n):
try:
next(it)
except StopIteration:
return 0
return 1
need = max(1, args.gradient_accumulation_steps)
local_ok = has_at_least(probe_stream, need)
if dist.is_available() and dist.is_initialized():
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
print(
f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
flush=True
)
dist.barrier()
sys.exit(2)
else:
if local_ok == 0:
print(
f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
flush=True
)
sys.exit(2)
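    # Example of what the probe above guards against: with per_device_train_batch_size=1 and
    # gradient_accumulation_steps=64, every rank must be able to draw at least 64 samples from
    # its stream for a single optimizer step; a rank that runs dry mid-step would stall the
    # collective ops on the other ranks.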
    # ---- Eval construction: prefer --eval_data_glob; otherwise sample via eval_ratio ----
eval_dataset: Optional[Dataset] = None
class ListDataset(Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
if args.eval_data_glob:
eval_files = sorted(glob.glob(args.eval_data_glob))
if len(eval_files) == 0:
raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
if is_main_process():
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
def ex_iter_eval():
for ex in ds_eval_stream:
yield ex
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_items: List[Dict[str, torch.Tensor]] = []
for sample in eval_iterable:
eval_items.append(sample)
if len(eval_items) == 0:
raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
eval_dataset = ListDataset(eval_items)
elif args.eval_ratio and args.eval_ratio > 0:
desired_eval_batches = 200
tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_eval2():
for ex in tmp_stream:
yield ex
eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
eval_samples = []
it = iter(eval_stream)
for _ in range(desired_eval_batches):
try:
eval_samples.append(next(it))
except StopIteration:
break
if len(eval_samples) > 0:
eval_dataset = ListDataset(eval_samples)
    # ---- Pad the eval set uniformly (so no rank ever sees an empty batch) ----
if eval_dataset is not None:
ws = max(world_size, 1)
be = max(1, args.per_device_eval_batch_size)
global_bs = ws * be
r = len(eval_dataset) % global_bs
if r != 0:
pad_need = global_bs - r
eval_dataset.items += eval_dataset.items[:pad_need]
if is_main_process():
print(f"[eval] padded eval set to {len(eval_dataset)} "
f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
flush=True)
        # Sanity check after padding
assert len(eval_dataset) % global_bs == 0, \
f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
    # The collator pads every batch to the fixed seq_len (dataset samples already come out at seq_len)
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
os.makedirs(logging_dir, exist_ok=True)
    # ---- Compatibility with 4.51 (eval_strategy) and older versions (evaluation_strategy) ----
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
if eval_dataset is not None:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "steps"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "steps"
else:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "no"
ta_sig = inspect.signature(TrainingArguments.__init__).parameters
ta_kwargs2 = dict(
output_dir=args.output_dir,
logging_dir=logging_dir,
        # Custom run_name to avoid the warning about run_name being equal to output_dir
run_name=f"sft-{os.path.basename(args.output_dir)}-{socket.gethostname()}",
do_train=True,
do_eval=(eval_dataset is not None),
        # Use the user-specified eval_steps; None when there is no eval set
eval_steps=(args.eval_steps if eval_dataset is not None else None),
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
# deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False,
dataloader_num_workers=0,
per_device_eval_batch_size=args.per_device_eval_batch_size,
report_to=([] if args.report_to == "none" else [args.report_to]),
bf16=args.bf16,
fp16=(not args.bf16),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
save_on_each_node=True,
logging_first_step=True,
        **ta_kwargs,  # the eval_strategy compatibility kwargs built above
)
# if "dataloader_prefetch_factor" in ta_sig:
# ta_kwargs2["dataloader_prefetch_factor"] = None
if "dataloader_pin_memory" in ta_sig:
ta_kwargs2["dataloader_pin_memory"] = False
if "torch_compile" in ta_sig:
ta_kwargs2["torch_compile"] = False
    # Reuse the use_bf16 decision from above when finalizing TrainingArguments
ta_kwargs2.update({
"bf16": use_bf16,
"fp16": (torch.cuda.is_available() and not use_bf16),
})
training_args = TrainingArguments(**ta_kwargs2)
trainer_kwargs = {}
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
trainer_kwargs["processing_class"] = tokenizer
else:
trainer_kwargs["tokenizer"] = tokenizer
trainer = DebugTrainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
#tokenizer=tokenizer,
#processing_class=tokenizer,
data_collator=data_collator,
**trainer_kwargs,
)
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
    # ==== Checkpoint-resume decision (safe when nodes do not share a filesystem) ====
def last_step(path: str) -> int:
ck = get_last_checkpoint(path)
if ck is None:
return -1
base = os.path.basename(ck)
try:
return int(base.split("-")[-1])
except Exception:
return -1
    local_last = last_step(args.output_dir)  # -1 means this node has no checkpoint at all
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
resume_flag = None
if dist.is_available() and dist.is_initialized():
        # If any rank has no checkpoint -> do not resume
has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
if has_local.item() == 1:
            # Every rank has one: gather each rank's last step and take the common minimum k (guaranteed to exist on every node)
ts = torch.tensor(local_last, device=device)
world = dist.get_world_size()
buf = [torch.zeros_like(ts) for _ in range(world)]
dist.all_gather(buf, ts)
steps = [b.item() for b in buf]
k = min(steps)
if k >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
if is_main_process():
print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
else:
        # Single node or uninitialized distributed: resume from the local last step if one exists
if local_last >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
    # Global consistency check: if any rank is missing this checkpoint, disable resume
if dist.is_available() and dist.is_initialized():
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
dist.all_reduce(present, op=dist.ReduceOp.MIN)
if present.item() == 0:
if is_main_process():
print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
resume_flag = None
dist.barrier()
else:
if resume_flag is not None and not os.path.isdir(resume_flag):
            # Single node: disable resume if the checkpoint is missing locally
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
resume_flag = None
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
print_once("***** Starting training *****")
dbg(f"allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB")
train_result = trainer.train(resume_from_checkpoint=resume_flag)
    trainer.save_model()  # With DeepSpeed stage3_gather_16bit_weights_on_model_save=true, the full model is gathered on rank 0
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
if eval_dataset is not None:
print_once("***** Running eval *****")
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
print_once("Done.")
if __name__ == "__main__":
main()