diff --git a/train_sft_ds.py b/train_sft_ds.py
index e3abcd7..1d7d1ee 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -1,20 +1,11 @@
-#!/usr/bin/env python3
import os
-# 让 user-site 生效(deepspeed/torchrun 常把 PYTHONNOUSERSITE=1 带进来)
os.environ.pop("PYTHONNOUSERSITE", None)
-
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
-
-
os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
-
-# ★ 新增:自建服务的 base_url(避免走默认的 cloud)
os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com")
-# (可选)某些版本支持这个 env;真正生效仍以下面的 Settings(init_timeout=...) 为准
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")
-
import glob
import socket
import argparse
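The removed comments above note that WANDB_INIT_TIMEOUT is only honored by some wandb versions and that the reliable knob is Settings(init_timeout=...). A minimal standalone sketch (not part of the patch) of passing it explicitly at init time; the project name and timeout value are placeholders:

import wandb

wandb.init(
    project="sft-run",                          # example project name
    settings=wandb.Settings(init_timeout=300),  # seconds to wait for the wandb service
)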
@@ -39,7 +30,7 @@ from transformers.trainer_utils import get_last_checkpoint
from torch.optim import AdamW as TorchAdamW
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
-import os, sys, site, shutil
+import site, shutil
home = os.path.expanduser("~")
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
@@ -59,14 +50,9 @@ if cuda_lib not in ld.split(":"):
os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
# 可视化确认
-import torch
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)
-# ==== ensure python can see user site & set torch extensions dir ====
-import os, sys, site
-
# 1) 确保不会屏蔽用户站点包(ninja 安在 ~/.local 里)
-# os.environ.pop("PYTHONNOUSERSITE", None)
os.environ.pop("DS_BUILD_OPS", None)
os.environ.pop("DS_SKIP_CUDA_BUILD", None)
@@ -82,7 +68,6 @@ except Exception:
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "12")
-import shutil
if shutil.which("ninja") is None:
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
@@ -134,7 +119,13 @@ class DebugTrainer(Trainer):
flush=True
)
self._dbg_printed = True
- return super().training_step(model, inputs, num_items_in_batch)
+
+ try:
+ return super().training_step(model, inputs, num_items_in_batch=num_items_in_batch)
+ except TypeError:
+ return super().training_step(model, inputs)
+
+ # return super().training_step(model, inputs, num_items_in_batch)
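The try/except TypeError above papers over the transformers API change that added num_items_in_batch to Trainer.training_step. A minimal, self-contained sketch of the same compatibility idea, using stand-in classes rather than the real Trainer, and probing the parent signature instead of catching TypeError:

import inspect

class Base:  # stands in for an older library version
    def training_step(self, model, inputs):
        return f"old-api({model}, {inputs})"

class Derived(Base):
    def training_step(self, model, inputs, num_items_in_batch=None):
        # Pass the new kwarg only if the parent actually accepts it.
        params = inspect.signature(super().training_step).parameters
        if "num_items_in_batch" in params:
            return super().training_step(model, inputs, num_items_in_batch=num_items_in_batch)
        return super().training_step(model, inputs)

print(Derived().training_step("m", "x", num_items_in_batch=8))  # -> old-api(m, x)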
# ----------------- 日志回调 -----------------
class CsvLossLogger(TrainerCallback):
@@ -145,13 +136,6 @@ class CsvLossLogger(TrainerCallback):
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
- # def on_log(self, args, state, control, logs=None, **kwargs):
- # if not is_main_process() or logs is None:
- # return
- # with open(self.csv_path, "a", encoding="utf-8") as f:
- # f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
-
-
def on_train_begin(self, args, state, control, **kwargs):
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
@@ -168,10 +152,6 @@ class CsvLossLogger(TrainerCallback):
# ---- 控制台打印:所有 rank 都打当前步/总步 ----
cur = int(getattr(state, "global_step", 0) or 0)
-
- # if getattr(args, "logging_steps", None) and cur % args.logging_steps != 0:
- # return
-
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
@@ -181,9 +161,6 @@ class CsvLossLogger(TrainerCallback):
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
self._tot_announced = True
- # if not is_main_process():
- # return
-
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(
@@ -200,449 +177,40 @@ class CsvLossLogger(TrainerCallback):
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
)
-# # ----------------- 仅监督 assistant 的数据集 -----------------
-# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
-# """
-# 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。
-# """
-# spans: List[Tuple[int, int]] = []
-# open_tag = "<|im_start|>assistant\n"
-# close_tag = "<|im_end|>\n"
-# pos = 0
-# while True:
-# a = rendered.find(open_tag, pos)
-# if a == -1:
-# break
-# start = a + len(open_tag)
-# b = rendered.find(close_tag, start)
-# if b == -1:
-# break
-# spans.append((start, b))
-# pos = b + len(close_tag)
-# return spans
-
-# class QwenChatSFTDataset(IterableDataset):
-# """
-# 期望 jsonl 每行形如:
-# {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
-# 可选包含工具:
-# {"messages":[...], "tools":[{...}]}
-
-# 工作流:
-# - 使用 tokenizer.apply_chat_template 渲染
-# - 仅对 assistant 片段计损失(其他 token 的 label = -100)
-# - 超长序列保留尾部(通常包含回答)
-# """
-# def __init__(self,
-# ex_iter: Iterable[dict],
-# tokenizer: AutoTokenizer,
-# seq_len: int = 4096):
-# self.ex_iter = ex_iter
-# self.tok = tokenizer
-# self.seq_len = seq_len
-
-# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
-
-# # >>> DEBUG BEGIN
-# dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
-# if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
-# dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
-# rank = int(os.environ.get("RANK", "0"))
-# lrank = int(os.environ.get("LOCAL_RANK", "-1"))
-# host = socket.gethostname()
-# # >>> DEBUG END
-
-# for ex in self.ex_iter:
-# msgs = ex.get("messages", None)
-# if not msgs or not isinstance(msgs, list):
-# continue
-
-# # 可选过滤 think
-# bad = False
-# for m in msgs:
-# if m.get("role") == "assistant" and isinstance(m.get("content"), str):
-# c = m["content"]
-# if "" in c and "" in c:
-# inner = c.split("")[-1].split("")[0].strip()
-# if inner:
-# bad = True; break
-# # 注销这里就可以确保参与计算监督微调,打开就表示跳过
-# if bad:
-# continue
-
-# tools = ex.get("tools", None)
-
-# # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况
-# try:
-# rendered: str = self.tok.apply_chat_template(
-# msgs, tools=tools, add_generation_prompt=False, tokenize=False
-# )
-# except TypeError:
-# rendered: str = self.tok.apply_chat_template(
-# msgs, add_generation_prompt=False, tokenize=False
-# )
-
-
-# if not isinstance(rendered, str) or not rendered.strip():
-# continue
-
-# spans = _assistant_char_spans(rendered)
-# if not spans:
-# continue
-
-# enc = self.tok(
-# rendered,
-# add_special_tokens=False,
-# return_offsets_mapping=True
-# )
-# input_ids: List[int] = enc["input_ids"]
-# offsets: List[Tuple[int, int]] = enc["offset_mapping"]
-
-# if not input_ids:
-# continue
-
-# labels = [-100] * len(input_ids)
-
-# def in_any_span(lo: int, hi: int) -> bool:
-# for s, e in spans:
-# if not (hi <= s or lo >= e):
-# return True
-# return False
-
-# for i, (lo, hi) in enumerate(offsets):
-# if in_any_span(lo, hi):
-# labels[i] = input_ids[i]
-
-# # —— 固定长度策略:先截尾,再在 Dataset 层补到固定 seq_len ——
-# # 1) 截断到 seq_len(保留尾部)
-# if len(input_ids) > self.seq_len:
-# input_ids = input_ids[-self.seq_len:]
-# labels = labels[-self.seq_len:]
-
-# # 2) 左侧补齐到 seq_len(保证所有样本长度一致)
-# pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
-# L = len(input_ids)
-# if L < self.seq_len:
-# pad = self.seq_len - L
-# input_ids = ([pad_id] * pad) + input_ids
-# labels = ([-100] * pad) + labels
-# attn_mask = [0] * pad + [1] * L
-# else:
-# # 恰好等于 seq_len
-# attn_mask = [1] * self.seq_len
-
-# # 若没有任何可训练 token(labels 全 -100),跳过
-# if all(v == -100 for v in labels):
-# continue
-
-# assert len(input_ids) == self.seq_len
-# assert len(labels) == self.seq_len
-# assert len(attn_mask) == self.seq_len
-
-# # >>> DEBUG PRINT(此时变量已定义)
-# if dbg_on and self._dbg_seen < dbg_limit:
-# sup_tok = sum(1 for v in labels if v != -100)
-# print(
-# f"[sample][host={host} RANK={rank} LRank={lrank}] "
-# f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
-# f"seq_len={self.seq_len} pad_id={pad_id}",
-# flush=True
-# )
-# if sup_tok == 0:
-# print(
-# f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
-# flush=True
-# )
-# self._dbg_seen += 1
-# # <<< DEBUG PRINT
-
-# yield {
-# "input_ids": torch.tensor(input_ids, dtype=torch.long),
-# "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
-# "labels": torch.tensor(labels, dtype=torch.long),
-# }
-
-# # ================================= 监督 ============================================
-# # ----------------- 工具:提取 assistant 字符区间 -----------------
-# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
-# """
-# 在 apply_chat_template 渲染后的纯文本中,返回所有 assistant 段的字符区间 [start, end)
-# 这些区间覆盖了 assistant 的全部内容(包括 ... 标签与正文)。
-# """
-# spans: List[Tuple[int, int]] = []
-# open_tag = "<|im_start|>assistant\n"
-# close_tag = "<|im_end|>\n"
-# pos = 0
-# while True:
-# a = rendered.find(open_tag, pos)
-# if a == -1:
-# break
-# s = a + len(open_tag)
-# b = rendered.find(close_tag, s)
-# if b == -1:
-# break
-# spans.append((s, b))
-# pos = b + len(close_tag)
-# return spans
-
-
-# # ----------------- 数据集:SFT(监督 assistant 全段,含 标签与内容) -----------------
-# class QwenChatSFTDataset(IterableDataset):
-# """
-# 期望 jsonl 每行形如:
-# {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
-# 可选包含工具:
-# {"messages":[...], "tools":[{...}]}
-
-# 工作流:
-# - 使用 tokenizer.apply_chat_template 渲染
-# - 仅对 assistant 片段计损失(其他 token 的 label = -100)
-# - 截断时“优先确保最后一个 assistant 不被截断”;若其长度 > seq_len,则保留其“结尾”以避免切尾
-# """
-# def __init__(self,
-# ex_iter: Iterable[dict],
-# tokenizer: AutoTokenizer,
-# seq_len: int = 4096):
-# self.ex_iter = ex_iter
-# self.tok = tokenizer
-# self.seq_len = seq_len
-
-# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
-
-# # >>> DEBUG BEGIN
-# dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
-# if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
-# dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
-# rank = int(os.environ.get("RANK", "0"))
-# lrank = int(os.environ.get("LOCAL_RANK", "-1"))
-# host = socket.gethostname()
-# # >>> DEBUG END
-
-# for ex in self.ex_iter:
-# msgs = ex.get("messages", None)
-# if not msgs or not isinstance(msgs, list):
-# continue
-
-# # —— 不再过滤 :显式允许其参与监督(包括标签与正文)
-# tools = ex.get("tools", None)
-
-# # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况
-# try:
-# rendered: str = self.tok.apply_chat_template(
-# msgs, tools=tools, add_generation_prompt=False, tokenize=False
-# )
-# except TypeError:
-# rendered: str = self.tok.apply_chat_template(
-# msgs, add_generation_prompt=False, tokenize=False
-# )
-
-# if not isinstance(rendered, str) or not rendered.strip():
-# continue
-
-# spans = _assistant_char_spans(rendered)
-# if not spans:
-# continue
-
-# # 编码并拿到字符偏移,确保与 rendered 对齐
-# enc = self.tok(
-# rendered,
-# add_special_tokens=False,
-# return_offsets_mapping=True
-# )
-# input_ids: List[int] = enc["input_ids"]
-# offsets: List[Tuple[int, int]] = enc["offset_mapping"]
-
-# if not input_ids:
-# continue
-
-# # 先对“所有 assistant 片段”打标签;包含 标签与内容、以及回答正文
-# labels = [-100] * len(input_ids)
-
-# def in_any_span(lo: int, hi: int) -> bool:
-# for s, e in spans:
-# # 与任一 [s, e) 有交集即监督
-# if not (hi <= s or lo >= e):
-# return True
-# return False
-
-# for i, (lo, hi) in enumerate(offsets):
-# if in_any_span(lo, hi):
-# labels[i] = input_ids[i]
-
-# # 若没有任何可训练 token(labels 全 -100),跳过
-# if all(v == -100 for v in labels):
-# continue
-
-# # ======== Assistant 感知的截断策略(保证“最后一个 assistant 不被截掉”)========
-# if len(input_ids) > self.seq_len:
-# # 取“最后一个 assistant”的字符区间
-# s_last, e_last = spans[-1]
-
-# # 将字符区间映射到 token 索引区间 [j, k_excl)
-# # j: 第一个 token,其右端 hi > s_last
-# j = 0
-# while j < len(offsets) and offsets[j][1] <= s_last:
-# j += 1
-# # k_excl: 第一个 token,其左端 lo >= e_last(即不再与 [s_last, e_last) 相交)
-# k_excl = j
-# while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
-# k_excl += 1
-
-# A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数
-
-# if A >= self.seq_len:
-# # 单个 assistant 本身超过窗口 —— 保“结尾”,避免被切尾
-# start = max(0, k_excl - self.seq_len)
-# end = start + self.seq_len
-# else:
-# # 有空间容纳整个 assistant:尽量把窗口对齐到包括完整 assistant
-# # 先试图把窗口从 j 开始,但要保证 k_excl 也在窗口内
-# start = max(0, min(j, len(input_ids) - self.seq_len))
-# end = start + self.seq_len
-# if end < k_excl:
-# # 还没覆盖到 assistant 末尾,则右移窗口到恰好覆盖末尾
-# end = k_excl
-# start = end - self.seq_len
-# if start < 0:
-# start = 0
-# end = self.seq_len
-
-# # 可选:尝试“居中”一点(留部分历史上下文),但仍需包含完整 [j, k_excl)
-# leftover = self.seq_len - A
-# # 把剩余的一半尽量分配给左侧上下文(不越界)
-# left_wish = leftover // 2
-# start = max(0, min(j - left_wish, start))
-# end = start + self.seq_len
-# if end < k_excl:
-# # 若居中导致末尾又被排除,再纠正一次
-# end = k_excl
-# start = end - self.seq_len
-# if start < 0:
-# start = 0
-# end = self.seq_len
-
-# # 真正切片
-# input_ids = input_ids[start:end]
-# labels = labels[start:end]
-# # 注意:offsets 后续不再使用(只为确定切片窗口),无需同步裁剪
-
-# # 训练注意:这里的策略保证:
-# # - 若最后一个 assistant <= seq_len:完整保留;
-# # - 若 > seq_len:至少保证 assistant 的“结尾”在窗口内,不会“切尾”。
-
-# # ======== 统一长度:左侧补齐到 seq_len ========
-# pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
-# L = len(input_ids)
-# if L < self.seq_len:
-# pad = self.seq_len - L
-# input_ids = ([pad_id] * pad) + input_ids
-# labels = ([-100] * pad) + labels
-# attn_mask = [0] * pad + [1] * L
-# else:
-# attn_mask = [1] * self.seq_len
-
-# # Sanity
-# assert len(input_ids) == self.seq_len
-# assert len(labels) == self.seq_len
-# assert len(attn_mask) == self.seq_len
-
-# # >>> DEBUG PRINT
-# if dbg_on and self._dbg_seen < dbg_limit:
-# sup_tok = sum(1 for v in labels if v != -100)
-# print(
-# f"[sample][host={host} RANK={rank} LRank={lrank}] "
-# f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}",
-# flush=True
-# )
-# if sup_tok == 0:
-# print(
-# f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
-# flush=True
-# )
-# self._dbg_seen += 1
-# # <<< DEBUG PRINT
-
-# yield {
-# "input_ids": torch.tensor(input_ids, dtype=torch.long),
-# "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
-# "labels": torch.tensor(labels, dtype=torch.long),
-# }
-
-# # ================================= end ============================================
-
-
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import os
-import socket
from typing import List, Tuple, Iterable, Iterator, Dict
-import torch
-from torch.utils.data import IterableDataset
-from transformers import AutoTokenizer # 仅作类型提示/引用,不强依赖
-
-
-# # ----------------- 工具:提取 assistant 字符区间 -----------------
+# ----------------- Helper: extract assistant character spans -----------------
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
-# """
-# 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。
-# (覆盖了 assistant 的全部内容,包括其中可能出现的 …)
-# """
# spans: List[Tuple[int, int]] = []
# open_tag = "<|im_start|>assistant\n"
-# close_tag = "<|im_end|>\n"
+# close_token = "<|im_end|>"
+# close_tag = close_token + "\n" # 常见模板带换行
+
# pos = 0
# while True:
# a = rendered.find(open_tag, pos)
# if a == -1:
# break
# start = a + len(open_tag)
+
+# # 先找含换行版本,找不到再退化找不带换行的
# b = rendered.find(close_tag, start)
# if b == -1:
-# break
+# b = rendered.find(close_token, start)
+# if b == -1:
+# break
-# end = b + len("<|im_end|>")
+# end = b + len(close_token) # 把 <|im_end|> 本体纳入监督
# spans.append((start, end))
-# # spans.append((start, b))
-# pos = b + len(close_tag)
+
+# # pos 跳过这一轮结束标记(带换行就多跳一格)
+# pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
# return spans
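The commented-out helper above keeps the older string-based approach: scan the rendered chat text for each <|im_start|>assistant\n ... <|im_end|> block. A standalone toy demo (not part of the patch; the rendered string below is made up, not tokenizer output) of what those spans cover:

def assistant_char_spans(rendered):
    spans, pos = [], 0
    open_tag, close_token = "<|im_start|>assistant\n", "<|im_end|>"
    while True:
        a = rendered.find(open_tag, pos)
        if a == -1:
            break
        start = a + len(open_tag)
        b = rendered.find(close_token, start)
        if b == -1:
            break
        spans.append((start, b + len(close_token)))  # include <|im_end|> itself
        pos = b + len(close_token)
    return spans

demo = ("<|im_start|>user\nHi<|im_end|>\n"
        "<|im_start|>assistant\nHello!<|im_end|>\n")
for s, e in assistant_char_spans(demo):
    print(repr(demo[s:e]))   # -> 'Hello!<|im_end|>'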
-# ----------------- 工具:提取 assistant 字符区间 -----------------
-def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
- spans: List[Tuple[int, int]] = []
- open_tag = "<|im_start|>assistant\n"
- close_token = "<|im_end|>"
- close_tag = close_token + "\n" # 常见模板带换行
-
- pos = 0
- while True:
- a = rendered.find(open_tag, pos)
- if a == -1:
- break
- start = a + len(open_tag)
-
- # 先找含换行版本,找不到再退化找不带换行的
- b = rendered.find(close_tag, start)
- if b == -1:
- b = rendered.find(close_token, start)
- if b == -1:
- break
-
- end = b + len(close_token) # 把 <|im_end|> 本体纳入监督
- spans.append((start, end))
-
- # pos 跳过这一轮结束标记(带换行就多跳一格)
- pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
- return spans
-
-
-
-# ----------------- 工具:提取所有 … 的字符区间(包含标签本身) -----------------
# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
# """
-# 纯 str.find 实现,不用正则。
-# 返回全局的 … 区间列表,坐标为 rendered 上的绝对位置。
+#     Return the spans to exclude from supervision (only the inside of <think>...</think>);
+#     the <think> and </think> tags themselves stay supervised so the model learns to close them.
# """
# spans: List[Tuple[int, int]] = []
# open_tag = ""
@@ -655,236 +223,214 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
# b = rendered.find(close_tag, a + len(open_tag))
# if b == -1:
# break
-# spans.append((a, b + len(close_tag)))
+# # 只忽略内部,不忽略两侧标签
+# spans.append((a + len(open_tag), b))
# pos = b + len(close_tag)
# return spans
-def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
- """
- 返回需要忽略监督的区间(仅 ... 的“内部”),
- 标签本身 与 仍参与监督,以便模型学会闭合。
- """
- spans: List[Tuple[int, int]] = []
- open_tag = ""
- close_tag = ""
- pos = 0
- while True:
- a = rendered.find(open_tag, pos)
- if a == -1:
- break
- b = rendered.find(close_tag, a + len(open_tag))
- if b == -1:
- break
- # 只忽略内部,不忽略两侧标签
- spans.append((a + len(open_tag), b))
- pos = b + len(close_tag)
- return spans
-
-
-# ----------------- 仅监督 assistant 的数据集(忽略 …) -----------------
+# ----------------- Supervise only assistant content (token-id level, no offsets) -----------------
class QwenChatSFTDataset(IterableDataset):
"""
- 期望 jsonl 每行形如:
- {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
- 可选包含工具:
- {"messages":[...], "tools":[{...}]}
-
- 工作流:
- - 使用 tokenizer.apply_chat_template 渲染(可带 tools)
- - 仅对 assistant 片段计损失;凡落在 … 内的 token,labels 置 -100(不监督)
- - 超长序列保留尾部(通常包含回答),再左侧补齐到固定长度
+    - Tokenize via the chat_template directly into token ids
+    - Locate assistant segments by special token ids (<|im_start|>assistant\n ... <|im_end|>)
+    - Supervise only the assistant content itself; by default mask the whole <think>...</think> block (tags included)
+    - When a sample is too long, keep the last assistant segment intact and left-pad to seq_len (a usage sketch follows after this class)
"""
def __init__(self,
ex_iter: Iterable[dict],
tokenizer: AutoTokenizer,
- seq_len: int = 4096):
+ seq_len: int = 4096,
+ mask_think_and_tags: bool = True):
self.ex_iter = ex_iter
self.tok = tokenizer
self.seq_len = seq_len
+ self.mask_think_and_tags = mask_think_and_tags
+
+        # Token ids / sequences for the key markers
+ self.id_START = self.tok.convert_tokens_to_ids("<|im_start|>")
+ self.id_END = self.tok.convert_tokens_to_ids("<|im_end|>")
+ # self.ids_ASSISTANT_NL = self.tok.encode("assistant\n", add_special_tokens=False)
+        # Support both common renderings: 'assistant\n' and 'assistant'
+ self.ids_ASSISTANT_CANDIDATES = [
+ self.tok.encode("assistant\n", add_special_tokens=False),
+ self.tok.encode("assistant", add_special_tokens=False),
+ ]
+        # Drop empty candidates (can happen with unusual tokenizer configs)
+ self.ids_ASSISTANT_CANDIDATES = [c for c in self.ids_ASSISTANT_CANDIDATES if len(c) > 0]
+
+ if not self.ids_ASSISTANT_CANDIDATES:
+ raise RuntimeError("[fatal] no valid 'assistant' role token sequence found; check chat template/tokenizer.")
+
+
+        self.ids_THINK_OPEN  = self.tok.encode("<think>", add_special_tokens=False)
+        self.ids_THINK_CLOSE = self.tok.encode("</think>", add_special_tokens=False)
+
+        # Fail fast if the tokenizer has not registered these special token ids
+ for name, val in {
+ "id_START": self.id_START, "id_END": self.id_END
+ }.items():
+ if val is None or val == self.tok.unk_token_id:
+ raise RuntimeError(f"[fatal] tokenizer missing special token id for {name}")
+
+ @staticmethod
+ def _find_subseq(hay: list, needle: list, start: int) -> int:
+ n = len(needle)
+ if n == 0: return start
+ for i in range(start, len(hay) - n + 1):
+ if hay[i:i+n] == needle:
+ return i
+ return -1
+
+ def _find_role_after_start(self, ids, j_start: int) -> Optional[Tuple[int, int]]:
+ """
+        Starting at j_start, try to match any of the 'assistant' role token sequences.
+        Return (pos, length); return None if nothing matches.
+ """
+ for cand in self.ids_ASSISTANT_CANDIDATES:
+ pos = self._find_subseq(ids, cand, j_start)
+ if pos == j_start:
+ return (pos, len(cand))
+ return None
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
-
- # >>> DEBUG 开关
+        # Debug switches
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
- if not hasattr(self, "_dbg_seen"):
- self._dbg_seen = 0
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
+ seen = 0
+ host = socket.gethostname()
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
- host = socket.gethostname()
- # <<< DEBUG
for ex in self.ex_iter:
- msgs = ex.get("messages", None)
+ msgs = ex.get("messages")
if not msgs or not isinstance(msgs, list):
continue
-
tools = ex.get("tools", None)
- # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况
+            # Let the chat template tokenize straight to ids (avoids offset-mapping pitfalls)
try:
- rendered: str = self.tok.apply_chat_template(
- msgs, tools=tools, add_generation_prompt=False, tokenize=False
+ ids = self.tok.apply_chat_template(
+ msgs, tools=tools, add_generation_prompt=False,
+ tokenize=True, return_tensors=None
)
+                # Older transformers versions may return a dict here
+ if isinstance(ids, dict):
+ ids = ids["input_ids"]
except TypeError:
+                # Last-resort fallback: render to a string, then tokenize manually
rendered: str = self.tok.apply_chat_template(
msgs, add_generation_prompt=False, tokenize=False
)
+ ids = self.tok(rendered, add_special_tokens=False)["input_ids"]
- if not isinstance(rendered, str) or not rendered.strip():
+ if not ids:
continue
- # —— 样本级终止符:确保训练时每条样本以 eos 结束 ——
- # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
- # rendered += self.tok.eos_token
+            # Build the supervision mask (0/1)
+ mask = [0] * len(ids)
+ i = 0
+ while i < len(ids):
+                # Find the next <|im_start|>
+ try:
+ a = ids.index(self.id_START, i)
+ except ValueError:
+ break
+                # Must be the assistant role (accept 'assistant\n' or 'assistant')
+ j = a + 1
+ role_match = self._find_role_after_start(ids, j)
+ if role_match is None:
+ i = a + 1
+ continue
+ _, role_len = role_match
+                content_lo = j + role_len  # skip past the role token sequence
- # —— 样本级终止符:把 eos 插入到最后一个 assistant 的 <|im_end|>\n 之前 ——
- # if self.tok.eos_token:
- # open_tag = "<|im_start|>assistant\n"
- # close_tag = "<|im_end|>\n"
- # head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切
- # if sep: # 找到了收尾标记
- # # 仅当 assistant 文本末尾还没有 eos 才插入,避免重复
- # if not head.endswith(self.tok.eos_token):
- # rendered = head + self.tok.eos_token + sep + tail
+                # Find the matching <|im_end|>
+ try:
+ b = ids.index(self.id_END, content_lo)
+ except ValueError:
+                    # Unclosed segment: skip it
+ i = a + 1
+ continue
+                content_hi = b  # exclusive of the END token
+                # First mark the whole content range as supervised (1)
+ for t in range(content_lo, content_hi):
+ mask[t] = 1
- # 1) 找到所有 assistant 区间 & 全局 think 区间
- asst_spans = _assistant_char_spans(rendered)
- if not asst_spans:
- continue
- think_spans = _think_char_spans(rendered)
+                # Optionally mask the whole <think>...</think> block (tags included)
+ if self.mask_think_and_tags:
+ p = content_lo
+ while True:
+ o = self._find_subseq(ids, self.ids_THINK_OPEN, p)
+ if o == -1 or o >= content_hi:
+ break
+ c = self._find_subseq(ids, self.ids_THINK_CLOSE, o + len(self.ids_THINK_OPEN))
+ if c == -1 or c > content_hi:
+ break
+                        x_lo = o  # include <think>
+                        x_hi = c + len(self.ids_THINK_CLOSE)  # include </think>
+ for t in range(x_lo, min(x_hi, content_hi)):
+ mask[t] = 0
+ p = x_hi
- # 2) 编码 & offset 对齐
- enc = self.tok(
- rendered,
- add_special_tokens=False,
- return_offsets_mapping=True
- )
- input_ids = enc["input_ids"]
- offsets = enc["offset_mapping"]
- if not input_ids:
+                # Move on to the next segment
+ i = b + 1
+
+            # Skip samples with no supervised tokens at all
+ if not any(mask):
continue
- # 3) 仅监督 assistant 片段,且排除落在 think 区间内的 token
- labels = [-100] * len(input_ids)
-
- def in_any_span(lo: int, hi: int, intervals: List[Tuple[int, int]]) -> bool:
- for s, e in intervals:
- # 有交集即为 True
- if not (hi <= s or lo >= e):
- return True
- return False
-
- for i, (lo, hi) in enumerate(offsets):
- if in_any_span(lo, hi, asst_spans) and not in_any_span(lo, hi, think_spans):
- labels[i] = input_ids[i]
-
- # —— 固定长度策略:先截尾(保留尾部),再左侧补齐 ——
- # if len(input_ids) > self.seq_len:
- # input_ids = input_ids[-self.seq_len:]
- # labels = labels[-self.seq_len:]
-
- # ======== 助手感知的截断策略:尽量保证“最后一个 assistant 片段”完整 ========
- if len(input_ids) > self.seq_len:
- # 取最后一个 assistant 的字符区间([s_last, e_last))
- s_last, e_last = asst_spans[-1]
-
- # 用 offsets 把字符区间映射到 token 索引区间 [j, k_excl)
- j = 0
- while j < len(offsets) and offsets[j][1] <= s_last:
- j += 1
- k_excl = j
- while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
- k_excl += 1
-
- A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数
-
- if A >= self.seq_len:
- # 单个 assistant 本身超过窗口 —— 保“结尾”,避免切尾
- start = max(0, k_excl - self.seq_len)
- end = start + self.seq_len
- else:
- # 有空间容纳整个 assistant:让窗口覆盖完整 [j, k_excl)
- start = max(0, min(j, len(input_ids) - self.seq_len))
- end = start + self.seq_len
- if end < k_excl:
- end = k_excl
- start = max(0, end - self.seq_len)
-
- # 可选:尝试“居中”一点(给最后一个 assistant 左右留些上下文)
- leftover = self.seq_len - A
- left_wish = leftover // 2
- start = max(0, min(j - left_wish, start))
- end = start + self.seq_len
- if end < k_excl:
- end = k_excl
- start = max(0, end - self.seq_len)
-
- # 真正切片
- input_ids = input_ids[start:end]
- labels = labels[start:end]
+            # ======== Truncation: end the window at the last supervised token ========
+ if len(ids) > self.seq_len:
+ last_on = max(idx for idx, v in enumerate(mask) if v == 1)
+ end = min(len(ids), last_on + 1)
+ start = max(0, end - self.seq_len)
+ ids = ids[start:end]
+ mask = mask[start:end]
+            # ======== Left-pad to seq_len ========
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
- L = len(input_ids)
+ L = len(ids)
if L < self.seq_len:
pad = self.seq_len - L
- input_ids = ([pad_id] * pad) + input_ids
- labels = ([-100] * pad) + labels
- attn_mask = [0] * pad + [1] * L
+ input_ids = [pad_id] * pad + ids
+ attention_mask = [0] * pad + [1] * L
+ labels = [-100] * pad + [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
else:
- attn_mask = [1] * self.seq_len
+ input_ids = ids
+ attention_mask = [1] * self.seq_len
+ labels = [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
- # 若没有任何可训练 token(labels 全 -100),跳过
- if all(v == -100 for v in labels):
- continue
-
- # Sanity
- assert len(input_ids) == self.seq_len
- assert len(labels) == self.seq_len
- assert len(attn_mask) == self.seq_len
-
- # >>> DEBUG
- if dbg_on and self._dbg_seen < dbg_limit:
+            # >>> Optional debug printing
+ if dbg_on and seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(
f"[sample][host={host} RANK={rank} LRank={lrank}] "
- f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
+ f"toks={len(input_ids)} sup_toks={sup_tok} "
f"seq_len={self.seq_len} pad_id={pad_id}",
flush=True
)
- if sup_tok == 0:
- print(
- f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped",
- flush=True
- )
- self._dbg_seen += 1
- # <<< DEBUG
+ seen += 1
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
- "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
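A rough usage sketch for the dataset above (standalone, not part of the patch). It assumes local or network access to a Qwen-style checkpoint; the model id below is only an example. It pulls one sample and checks that loss is restricted to the assistant reply. With mask_think_and_tags=True, any <think>...</think> block inside the reply would additionally be masked, provided those markers tokenize to the same ids in context.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example id only
examples = [{"messages": [
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "The answer is 4."},
]}]

ds = QwenChatSFTDataset(iter(examples), tok, seq_len=256)
sample = next(iter(ds))
sup = sample["labels"] != -100
print("supervised tokens:", int(sup.sum()))
print(repr(tok.decode(sample["input_ids"][sup])))  # expected: roughly 'The answer is 4.'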
-
-
-# ----------------- 专用 Collator:pad inputs, pad labels=-100 -----------------
+# ----------------- Collator (kept consistent with upstream: pad -> label=-100, attn=0) -----------------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
self.tok = tokenizer
self.pad_to_length = pad_to_length
- assert self.tok.pad_token_id is not None
+ assert self.tok.pad_token_id is not None, "tokenizer.pad_token_id must be set"
- def __call__(self, features):
+ def __call__(self, features):
if not features:
- raise RuntimeError(
- f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
- f"Check dataset sharding/streaming."
- )
+ raise RuntimeError("Empty batch passed to collator")
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
@@ -905,24 +451,13 @@ class SFTDataCollator:
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
- # >>> DEBUG BEGIN
- dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
- if dbg_on:
- rank = int(os.environ.get("RANK", "0"))
- host = socket.gethostname()
- bs = len(features)
- first_len = len(input_ids[0]) if bs > 0 else None
- print(
- f"[collate][host={host} RANK={rank}] features={bs} "
- f"target_len={target_len} first_len={first_len}",
- flush=True
- )
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
"labels": torch.stack(batch_lab, dim=0),
}
+
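A quick self-contained check of the collator's padding behavior (standalone sketch, not part of the patch). A stub object stands in for the tokenizer, since the collator only appears to need pad_token_id; the expected outputs assume it pads input_ids with pad_token_id, attention_mask with 0, and labels with -100, as the visible tail of __call__ suggests.

from types import SimpleNamespace

stub_tok = SimpleNamespace(pad_token_id=0)   # stand-in, not a real AutoTokenizer
collator = SFTDataCollator(stub_tok, pad_to_length=None)
batch = collator([
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [-100, 6, 7]},
    {"input_ids": [8, 9],    "attention_mask": [1, 1],    "labels": [8, 9]},
])
print(batch["input_ids"].shape)             # expected: torch.Size([2, 3])
print(batch["labels"][1].tolist())          # expected: [8, 9, -100]
print(batch["attention_mask"][1].tolist())  # expected: [1, 1, 0]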
# ----------------- 参数 -----------------
def parse_args():
ap = argparse.ArgumentParser()
@@ -942,7 +477,6 @@ def parse_args():
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
- #ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true",
help="3090/A100/H100 等可开 bf16;同时在 DS 配置里也要开")
@@ -1101,9 +635,14 @@ def main():
pass
# 强制要求 fast tokenizer(offset_mapping 依赖 fast)
- from transformers import PreTrainedTokenizerFast
- if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
- raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。")
+ # from transformers import PreTrainedTokenizerFast
+ # if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
+ # raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。")
+
+    # A fast tokenizer is recommended (faster); offset_mapping is no longer required
+ if not getattr(tokenizer, "is_fast", False):
+ print("[warn] using a slow tokenizer; masks are token-id based and still correct, just slower.", flush=True)
+
tokenizer.model_max_length = args.seq_len
@@ -1124,6 +663,15 @@ def main():
dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32))
+
+ try:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.set_float32_matmul_precision("high")
+ except Exception:
+ pass
+
+
# 交给插件做 ZeRO-Init/分片加载
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
@@ -1140,14 +688,14 @@ def main():
# 3) pad/alibi 等配置
model.config.pad_token_id = tokenizer.pad_token_id
+
+ if getattr(model, "generation_config", None) is not None:
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
+
model.config.use_cache = False
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
try:
- # torch.backends.cuda.enable_flash_sdp(False)
- # torch.backends.cuda.enable_mem_efficient_sdp(False)
- # torch.backends.cuda.enable_math_sdp(True)
-
# 让 PyTorch 自己选,或显式打开高效实现(任选其一):
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
@@ -1172,8 +720,9 @@ def main():
for ex in ds_stream_probe:
yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
+
try:
- _ = next(iter(train_stream_probe))
+ sample = next(iter(train_stream_probe))
except StopIteration:
raise RuntimeError(
f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n"
@@ -1182,36 +731,17 @@ def main():
"另外检查 seq_len 是否过小导致全部被裁。"
)
- # # ====== 正式训练流 ======
- # ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
- # if world_size > 1 and len(files) >= world_size:
- # # 多文件,按文件连续分片
- # ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
- # train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
- # else:
- # # 单文件或文件数不足,按样本取模轮转
- # def ex_iter2():
- # for i, ex in enumerate(ds_stream2):
- # if i % max(world_size, 1) == rank:
- # yield ex
- # train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
+    # A more reliable self-check (replaces the previous two asserts)
+ ids, attn, labs = sample["input_ids"], sample["attention_mask"], sample["labels"]
+ assert (labs != -100).any(), "[fatal] no supervised tokens in first valid sample"
+    # Padded positions must be excluded from supervision
+ assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"
+
# ====== 正式训练流(不做任何手动分片,交给 Accelerate/Trainer)======
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
- # # ====== 一致性探针(与上面保持同逻辑)=====
- # ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
- # if world_size > 1 and len(files) >= world_size:
- # ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
- # probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
- # else:
- # def ex_iter2_probe():
- # for i, ex in enumerate(ds_stream_probe2):
- # if i % max(world_size, 1) == rank:
- # yield ex
- # probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
-
# ====== 一致性探针(不分片)======
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
@@ -1318,7 +848,7 @@ def main():
f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
# 更稳:联调阶段不强行 pad 到 4096
- data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
+ data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
@@ -1342,7 +872,6 @@ def main():
ta_kwargs2 = dict(
output_dir=args.output_dir,
logging_dir=logging_dir,
- # ★ 新增:自定义 run_name,避免等于 output_dir 的 warning
run_name=f"sft-{os.path.basename(args.output_dir)}-{socket.gethostname()}",
do_train=True,
do_eval=(eval_dataset is not None),
@@ -1360,22 +889,21 @@ def main():
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
- # deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False,
dataloader_num_workers=0,
+ label_smoothing_factor=0.0,
per_device_eval_batch_size=args.per_device_eval_batch_size,
report_to=([] if args.report_to == "none" else [args.report_to]),
- bf16=args.bf16,
- fp16=(not args.bf16),
+ #bf16=args.bf16,
+ #fp16=(not args.bf16),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
save_on_each_node=True,
logging_first_step=True,
**ta_kwargs, # 你之前构造的 eval_strategy 兼容项
)
- # if "dataloader_prefetch_factor" in ta_sig:
- # ta_kwargs2["dataloader_prefetch_factor"] = None
+
if "dataloader_pin_memory" in ta_sig:
ta_kwargs2["dataloader_pin_memory"] = False
if "torch_compile" in ta_sig:
@@ -1401,8 +929,6 @@ def main():
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
- #tokenizer=tokenizer,
- #processing_class=tokenizer,
data_collator=data_collator,
**trainer_kwargs,
)
@@ -1467,7 +993,6 @@ def main():
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
resume_flag = None
-
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
print_once("***** Starting training *****")