diff --git a/train_sft_ds.py b/train_sft_ds.py index e3abcd7..1d7d1ee 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -1,20 +1,11 @@ -#!/usr/bin/env python3 import os -# 让 user-site 生效(deepspeed/torchrun 常把 PYTHONNOUSERSITE=1 带进来) os.environ.pop("PYTHONNOUSERSITE", None) - os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - - os.environ.setdefault("WANDB_START_METHOD", "thread") os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb") - -# ★ 新增:自建服务的 base_url(避免走默认的 cloud) os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com") -# (可选)某些版本支持这个 env;真正生效仍以下面的 Settings(init_timeout=...) 为准 os.environ.setdefault("WANDB_INIT_TIMEOUT", "300") - import glob import socket import argparse @@ -39,7 +30,7 @@ from transformers.trainer_utils import get_last_checkpoint from torch.optim import AdamW as TorchAdamW # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ==== -import os, sys, site, shutil +import site, shutil home = os.path.expanduser("~") want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"] @@ -59,14 +50,9 @@ if cuda_lib not in ld.split(":"): os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib # 可视化确认 -import torch print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True) -# ==== ensure python can see user site & set torch extensions dir ==== -import os, sys, site - # 1) 确保不会屏蔽用户站点包(ninja 安在 ~/.local 里) -# os.environ.pop("PYTHONNOUSERSITE", None) os.environ.pop("DS_BUILD_OPS", None) os.environ.pop("DS_SKIP_CUDA_BUILD", None) @@ -82,7 +68,6 @@ except Exception: os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext") os.environ.setdefault("MAX_JOBS", "12") -import shutil if shutil.which("ninja") is None: os.environ["USE_NINJA"] = "0" print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True) @@ -134,7 +119,13 @@ class DebugTrainer(Trainer): flush=True ) self._dbg_printed = True - return super().training_step(model, inputs, num_items_in_batch) + + try: + return super().training_step(model, inputs, num_items_in_batch=num_items_in_batch) + except TypeError: + return super().training_step(model, inputs) + + # return super().training_step(model, inputs, num_items_in_batch) # ----------------- 日志回调 ----------------- class CsvLossLogger(TrainerCallback): @@ -145,13 +136,6 @@ class CsvLossLogger(TrainerCallback): with open(self.csv_path, "w", encoding="utf-8") as f: f.write("step,loss,lr,total_flos\n") - # def on_log(self, args, state, control, logs=None, **kwargs): - # if not is_main_process() or logs is None: - # return - # with open(self.csv_path, "a", encoding="utf-8") as f: - # f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n") - - def on_train_begin(self, args, state, control, **kwargs): tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0) tot = tmp if isinstance(tmp, int) and tmp > 0 else 0 @@ -168,10 +152,6 @@ class CsvLossLogger(TrainerCallback): # ---- 控制台打印:所有 rank 都打当前步/总步 ---- cur = int(getattr(state, "global_step", 0) or 0) - - # if getattr(args, "logging_steps", None) and cur % args.logging_steps != 0: - # return - tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0) tot = tmp if isinstance(tmp, int) and tmp > 0 else 0 pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a") @@ -181,9 +161,6 @@ class CsvLossLogger(TrainerCallback): print(f"[{socket.gethostname()} 
rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True) self._tot_announced = True - # if not is_main_process(): - # return - rank = os.environ.get("RANK", "?") host = socket.gethostname() print( @@ -200,449 +177,40 @@ class CsvLossLogger(TrainerCallback): f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n" ) -# # ----------------- 仅监督 assistant 的数据集 ----------------- -# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: -# """ -# 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。 -# """ -# spans: List[Tuple[int, int]] = [] -# open_tag = "<|im_start|>assistant\n" -# close_tag = "<|im_end|>\n" -# pos = 0 -# while True: -# a = rendered.find(open_tag, pos) -# if a == -1: -# break -# start = a + len(open_tag) -# b = rendered.find(close_tag, start) -# if b == -1: -# break -# spans.append((start, b)) -# pos = b + len(close_tag) -# return spans - -# class QwenChatSFTDataset(IterableDataset): -# """ -# 期望 jsonl 每行形如: -# {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]} -# 可选包含工具: -# {"messages":[...], "tools":[{...}]} - -# 工作流: -# - 使用 tokenizer.apply_chat_template 渲染 -# - 仅对 assistant 片段计损失(其他 token 的 label = -100) -# - 超长序列保留尾部(通常包含回答) -# """ -# def __init__(self, -# ex_iter: Iterable[dict], -# tokenizer: AutoTokenizer, -# seq_len: int = 4096): -# self.ex_iter = ex_iter -# self.tok = tokenizer -# self.seq_len = seq_len - -# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: - -# # >>> DEBUG BEGIN -# dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1" -# if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0 -# dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3")) -# rank = int(os.environ.get("RANK", "0")) -# lrank = int(os.environ.get("LOCAL_RANK", "-1")) -# host = socket.gethostname() -# # >>> DEBUG END - -# for ex in self.ex_iter: -# msgs = ex.get("messages", None) -# if not msgs or not isinstance(msgs, list): -# continue - -# # 可选过滤 think -# bad = False -# for m in msgs: -# if m.get("role") == "assistant" and isinstance(m.get("content"), str): -# c = m["content"] -# if "" in c and "" in c: -# inner = c.split("")[-1].split("")[0].strip() -# if inner: -# bad = True; break -# # 注销这里就可以确保参与计算监督微调,打开就表示跳过 -# if bad: -# continue - -# tools = ex.get("tools", None) - -# # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况 -# try: -# rendered: str = self.tok.apply_chat_template( -# msgs, tools=tools, add_generation_prompt=False, tokenize=False -# ) -# except TypeError: -# rendered: str = self.tok.apply_chat_template( -# msgs, add_generation_prompt=False, tokenize=False -# ) - - -# if not isinstance(rendered, str) or not rendered.strip(): -# continue - -# spans = _assistant_char_spans(rendered) -# if not spans: -# continue - -# enc = self.tok( -# rendered, -# add_special_tokens=False, -# return_offsets_mapping=True -# ) -# input_ids: List[int] = enc["input_ids"] -# offsets: List[Tuple[int, int]] = enc["offset_mapping"] - -# if not input_ids: -# continue - -# labels = [-100] * len(input_ids) - -# def in_any_span(lo: int, hi: int) -> bool: -# for s, e in spans: -# if not (hi <= s or lo >= e): -# return True -# return False - -# for i, (lo, hi) in enumerate(offsets): -# if in_any_span(lo, hi): -# labels[i] = input_ids[i] - -# # —— 固定长度策略:先截尾,再在 Dataset 层补到固定 seq_len —— -# # 1) 截断到 seq_len(保留尾部) -# if len(input_ids) > self.seq_len: -# input_ids = input_ids[-self.seq_len:] -# labels = labels[-self.seq_len:] - -# # 2) 左侧补齐到 seq_len(保证所有样本长度一致) -# 
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id -# L = len(input_ids) -# if L < self.seq_len: -# pad = self.seq_len - L -# input_ids = ([pad_id] * pad) + input_ids -# labels = ([-100] * pad) + labels -# attn_mask = [0] * pad + [1] * L -# else: -# # 恰好等于 seq_len -# attn_mask = [1] * self.seq_len - -# # 若没有任何可训练 token(labels 全 -100),跳过 -# if all(v == -100 for v in labels): -# continue - -# assert len(input_ids) == self.seq_len -# assert len(labels) == self.seq_len -# assert len(attn_mask) == self.seq_len - -# # >>> DEBUG PRINT(此时变量已定义) -# if dbg_on and self._dbg_seen < dbg_limit: -# sup_tok = sum(1 for v in labels if v != -100) -# print( -# f"[sample][host={host} RANK={rank} LRank={lrank}] " -# f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} " -# f"seq_len={self.seq_len} pad_id={pad_id}", -# flush=True -# ) -# if sup_tok == 0: -# print( -# f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped", -# flush=True -# ) -# self._dbg_seen += 1 -# # <<< DEBUG PRINT - -# yield { -# "input_ids": torch.tensor(input_ids, dtype=torch.long), -# "attention_mask": torch.tensor(attn_mask, dtype=torch.long), -# "labels": torch.tensor(labels, dtype=torch.long), -# } - -# # ================================= 监督 ============================================ -# # ----------------- 工具:提取 assistant 字符区间 ----------------- -# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: -# """ -# 在 apply_chat_template 渲染后的纯文本中,返回所有 assistant 段的字符区间 [start, end) -# 这些区间覆盖了 assistant 的全部内容(包括 ... 标签与正文)。 -# """ -# spans: List[Tuple[int, int]] = [] -# open_tag = "<|im_start|>assistant\n" -# close_tag = "<|im_end|>\n" -# pos = 0 -# while True: -# a = rendered.find(open_tag, pos) -# if a == -1: -# break -# s = a + len(open_tag) -# b = rendered.find(close_tag, s) -# if b == -1: -# break -# spans.append((s, b)) -# pos = b + len(close_tag) -# return spans - - -# # ----------------- 数据集:SFT(监督 assistant 全段,含 标签与内容) ----------------- -# class QwenChatSFTDataset(IterableDataset): -# """ -# 期望 jsonl 每行形如: -# {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]} -# 可选包含工具: -# {"messages":[...], "tools":[{...}]} - -# 工作流: -# - 使用 tokenizer.apply_chat_template 渲染 -# - 仅对 assistant 片段计损失(其他 token 的 label = -100) -# - 截断时“优先确保最后一个 assistant 不被截断”;若其长度 > seq_len,则保留其“结尾”以避免切尾 -# """ -# def __init__(self, -# ex_iter: Iterable[dict], -# tokenizer: AutoTokenizer, -# seq_len: int = 4096): -# self.ex_iter = ex_iter -# self.tok = tokenizer -# self.seq_len = seq_len - -# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: - -# # >>> DEBUG BEGIN -# dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1" -# if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0 -# dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3")) -# rank = int(os.environ.get("RANK", "0")) -# lrank = int(os.environ.get("LOCAL_RANK", "-1")) -# host = socket.gethostname() -# # >>> DEBUG END - -# for ex in self.ex_iter: -# msgs = ex.get("messages", None) -# if not msgs or not isinstance(msgs, list): -# continue - -# # —— 不再过滤 :显式允许其参与监督(包括标签与正文) -# tools = ex.get("tools", None) - -# # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况 -# try: -# rendered: str = self.tok.apply_chat_template( -# msgs, tools=tools, add_generation_prompt=False, tokenize=False -# ) -# except TypeError: -# rendered: str = self.tok.apply_chat_template( -# msgs, add_generation_prompt=False, tokenize=False -# ) - -# if not 
isinstance(rendered, str) or not rendered.strip(): -# continue - -# spans = _assistant_char_spans(rendered) -# if not spans: -# continue - -# # 编码并拿到字符偏移,确保与 rendered 对齐 -# enc = self.tok( -# rendered, -# add_special_tokens=False, -# return_offsets_mapping=True -# ) -# input_ids: List[int] = enc["input_ids"] -# offsets: List[Tuple[int, int]] = enc["offset_mapping"] - -# if not input_ids: -# continue - -# # 先对“所有 assistant 片段”打标签;包含 标签与内容、以及回答正文 -# labels = [-100] * len(input_ids) - -# def in_any_span(lo: int, hi: int) -> bool: -# for s, e in spans: -# # 与任一 [s, e) 有交集即监督 -# if not (hi <= s or lo >= e): -# return True -# return False - -# for i, (lo, hi) in enumerate(offsets): -# if in_any_span(lo, hi): -# labels[i] = input_ids[i] - -# # 若没有任何可训练 token(labels 全 -100),跳过 -# if all(v == -100 for v in labels): -# continue - -# # ======== Assistant 感知的截断策略(保证“最后一个 assistant 不被截掉”)======== -# if len(input_ids) > self.seq_len: -# # 取“最后一个 assistant”的字符区间 -# s_last, e_last = spans[-1] - -# # 将字符区间映射到 token 索引区间 [j, k_excl) -# # j: 第一个 token,其右端 hi > s_last -# j = 0 -# while j < len(offsets) and offsets[j][1] <= s_last: -# j += 1 -# # k_excl: 第一个 token,其左端 lo >= e_last(即不再与 [s_last, e_last) 相交) -# k_excl = j -# while k_excl < len(offsets) and offsets[k_excl][0] < e_last: -# k_excl += 1 - -# A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数 - -# if A >= self.seq_len: -# # 单个 assistant 本身超过窗口 —— 保“结尾”,避免被切尾 -# start = max(0, k_excl - self.seq_len) -# end = start + self.seq_len -# else: -# # 有空间容纳整个 assistant:尽量把窗口对齐到包括完整 assistant -# # 先试图把窗口从 j 开始,但要保证 k_excl 也在窗口内 -# start = max(0, min(j, len(input_ids) - self.seq_len)) -# end = start + self.seq_len -# if end < k_excl: -# # 还没覆盖到 assistant 末尾,则右移窗口到恰好覆盖末尾 -# end = k_excl -# start = end - self.seq_len -# if start < 0: -# start = 0 -# end = self.seq_len - -# # 可选:尝试“居中”一点(留部分历史上下文),但仍需包含完整 [j, k_excl) -# leftover = self.seq_len - A -# # 把剩余的一半尽量分配给左侧上下文(不越界) -# left_wish = leftover // 2 -# start = max(0, min(j - left_wish, start)) -# end = start + self.seq_len -# if end < k_excl: -# # 若居中导致末尾又被排除,再纠正一次 -# end = k_excl -# start = end - self.seq_len -# if start < 0: -# start = 0 -# end = self.seq_len - -# # 真正切片 -# input_ids = input_ids[start:end] -# labels = labels[start:end] -# # 注意:offsets 后续不再使用(只为确定切片窗口),无需同步裁剪 - -# # 训练注意:这里的策略保证: -# # - 若最后一个 assistant <= seq_len:完整保留; -# # - 若 > seq_len:至少保证 assistant 的“结尾”在窗口内,不会“切尾”。 - -# # ======== 统一长度:左侧补齐到 seq_len ======== -# pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id -# L = len(input_ids) -# if L < self.seq_len: -# pad = self.seq_len - L -# input_ids = ([pad_id] * pad) + input_ids -# labels = ([-100] * pad) + labels -# attn_mask = [0] * pad + [1] * L -# else: -# attn_mask = [1] * self.seq_len - -# # Sanity -# assert len(input_ids) == self.seq_len -# assert len(labels) == self.seq_len -# assert len(attn_mask) == self.seq_len - -# # >>> DEBUG PRINT -# if dbg_on and self._dbg_seen < dbg_limit: -# sup_tok = sum(1 for v in labels if v != -100) -# print( -# f"[sample][host={host} RANK={rank} LRank={lrank}] " -# f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}", -# flush=True -# ) -# if sup_tok == 0: -# print( -# f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped", -# flush=True -# ) -# self._dbg_seen += 1 -# # <<< DEBUG PRINT - -# yield { -# "input_ids": torch.tensor(input_ids, dtype=torch.long), -# "attention_mask": torch.tensor(attn_mask, dtype=torch.long), -# "labels": 
torch.tensor(labels, dtype=torch.long), -# } - -# # ================================= end ============================================ - - -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import os -import socket from typing import List, Tuple, Iterable, Iterator, Dict -import torch -from torch.utils.data import IterableDataset -from transformers import AutoTokenizer # 仅作类型提示/引用,不强依赖 - - -# # ----------------- 工具:提取 assistant 字符区间 ----------------- +# ----------------- 工具:提取 assistant 字符区间 ----------------- # def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: -# """ -# 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。 -# (覆盖了 assistant 的全部内容,包括其中可能出现的 ) -# """ # spans: List[Tuple[int, int]] = [] # open_tag = "<|im_start|>assistant\n" -# close_tag = "<|im_end|>\n" +# close_token = "<|im_end|>" +# close_tag = close_token + "\n" # 常见模板带换行 + # pos = 0 # while True: # a = rendered.find(open_tag, pos) # if a == -1: # break # start = a + len(open_tag) + +# # 先找含换行版本,找不到再退化找不带换行的 # b = rendered.find(close_tag, start) # if b == -1: -# break +# b = rendered.find(close_token, start) +# if b == -1: +# break -# end = b + len("<|im_end|>") +# end = b + len(close_token) # 把 <|im_end|> 本体纳入监督 # spans.append((start, end)) -# # spans.append((start, b)) -# pos = b + len(close_tag) + +# # pos 跳过这一轮结束标记(带换行就多跳一格) +# pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token)) # return spans -# ----------------- 工具:提取 assistant 字符区间 ----------------- -def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: - spans: List[Tuple[int, int]] = [] - open_tag = "<|im_start|>assistant\n" - close_token = "<|im_end|>" - close_tag = close_token + "\n" # 常见模板带换行 - - pos = 0 - while True: - a = rendered.find(open_tag, pos) - if a == -1: - break - start = a + len(open_tag) - - # 先找含换行版本,找不到再退化找不带换行的 - b = rendered.find(close_tag, start) - if b == -1: - b = rendered.find(close_token, start) - if b == -1: - break - - end = b + len(close_token) # 把 <|im_end|> 本体纳入监督 - spans.append((start, end)) - - # pos 跳过这一轮结束标记(带换行就多跳一格) - pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token)) - return spans - - - -# ----------------- 工具:提取所有 的字符区间(包含标签本身) ----------------- # def _think_char_spans(rendered: str) -> List[Tuple[int, int]]: # """ -# 纯 str.find 实现,不用正则。 -# 返回全局的 区间列表,坐标为 rendered 上的绝对位置。 +# 返回需要忽略监督的区间(仅 ... 的“内部”), +# 标签本身 仍参与监督,以便模型学会闭合。 # """ # spans: List[Tuple[int, int]] = [] # open_tag = "" @@ -655,236 +223,214 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: # b = rendered.find(close_tag, a + len(open_tag)) # if b == -1: # break -# spans.append((a, b + len(close_tag))) +# # 只忽略内部,不忽略两侧标签 +# spans.append((a + len(open_tag), b)) # pos = b + len(close_tag) # return spans -def _think_char_spans(rendered: str) -> List[Tuple[int, int]]: - """ - 返回需要忽略监督的区间(仅 ... 
的“内部”), - 标签本身 仍参与监督,以便模型学会闭合。 - """ - spans: List[Tuple[int, int]] = [] - open_tag = "" - close_tag = "" - pos = 0 - while True: - a = rendered.find(open_tag, pos) - if a == -1: - break - b = rendered.find(close_tag, a + len(open_tag)) - if b == -1: - break - # 只忽略内部,不忽略两侧标签 - spans.append((a + len(open_tag), b)) - pos = b + len(close_tag) - return spans - - -# ----------------- 仅监督 assistant 的数据集(忽略 ) ----------------- +# ----------------- 仅监督 assistant 内容(token-id 级,不用 offsets) ----------------- class QwenChatSFTDataset(IterableDataset): """ - 期望 jsonl 每行形如: - {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]} - 可选包含工具: - {"messages":[...], "tools":[{...}]} - - 工作流: - - 使用 tokenizer.apply_chat_template 渲染(可带 tools) - - 仅对 assistant 片段计损失;凡落在 内的 token,labels 置 -100(不监督) - - 超长序列保留尾部(通常包含回答),再左侧补齐到固定长度 + - 通过 chat_template 得到 token ids + - 以 special token id 定位 assistant 片段(<|im_start|>assistant\n ... <|im_end|>) + - 只监督 assistant 内容本体;默认把 (含标签)整体屏蔽 + - 超长时保最后一个 assistant 片段完整,左侧补齐到 seq_len """ def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, - seq_len: int = 4096): + seq_len: int = 4096, + mask_think_and_tags: bool = True): self.ex_iter = ex_iter self.tok = tokenizer self.seq_len = seq_len + self.mask_think_and_tags = mask_think_and_tags + + # 关键标记的 token 序列 + self.id_START = self.tok.convert_tokens_to_ids("<|im_start|>") + self.id_END = self.tok.convert_tokens_to_ids("<|im_end|>") + # self.ids_ASSISTANT_NL = self.tok.encode("assistant\n", add_special_tokens=False) + # 支持两种常见写法:'assistant\\n' 或 'assistant' + self.ids_ASSISTANT_CANDIDATES = [ + self.tok.encode("assistant\n", add_special_tokens=False), + self.tok.encode("assistant", add_special_tokens=False), + ] + # 过滤空候选(极端 tokenizer 配置) + self.ids_ASSISTANT_CANDIDATES = [c for c in self.ids_ASSISTANT_CANDIDATES if len(c) > 0] + + if not self.ids_ASSISTANT_CANDIDATES: + raise RuntimeError("[fatal] no valid 'assistant' role token sequence found; check chat template/tokenizer.") + + + self.ids_THINK_OPEN = self.tok.encode("", add_special_tokens=False) + self.ids_THINK_CLOSE = self.tok.encode("", add_special_tokens=False) + + # 兜底:有些模型未注册这些特殊 id 时,直接 fail-fast + for name, val in { + "id_START": self.id_START, "id_END": self.id_END + }.items(): + if val is None or val == self.tok.unk_token_id: + raise RuntimeError(f"[fatal] tokenizer missing special token id for {name}") + + @staticmethod + def _find_subseq(hay: list, needle: list, start: int) -> int: + n = len(needle) + if n == 0: return start + for i in range(start, len(hay) - n + 1): + if hay[i:i+n] == needle: + return i + return -1 + + def _find_role_after_start(self, ids, j_start: int) -> Optional[Tuple[int, int]]: + """ + 从 j_start 开始,尝试匹配任一 'assistant' 角色 token 序列。 + 返回 (pos, length);匹配失败返回 None。 + """ + for cand in self.ids_ASSISTANT_CANDIDATES: + pos = self._find_subseq(ids, cand, j_start) + if pos == j_start: + return (pos, len(cand)) + return None def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: - - # >>> DEBUG 开关 + # 调试开关 dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1" - if not hasattr(self, "_dbg_seen"): - self._dbg_seen = 0 dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3")) + seen = 0 + host = socket.gethostname() rank = int(os.environ.get("RANK", "0")) lrank = int(os.environ.get("LOCAL_RANK", "-1")) - host = socket.gethostname() - # <<< DEBUG for ex in self.ex_iter: - msgs = ex.get("messages", None) + msgs = ex.get("messages") if not msgs or not isinstance(msgs, list): 
continue - tools = ex.get("tools", None) - # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况 + # 直接让模板 tokenization -> ids(避免 offset 落坑) try: - rendered: str = self.tok.apply_chat_template( - msgs, tools=tools, add_generation_prompt=False, tokenize=False + ids = self.tok.apply_chat_template( + msgs, tools=tools, add_generation_prompt=False, + tokenize=True, return_tensors=None ) + # 兼容老版本返回 dict 的情况 + if isinstance(ids, dict): + ids = ids["input_ids"] except TypeError: + # 极端回退:先渲染字符串再手动分词 rendered: str = self.tok.apply_chat_template( msgs, add_generation_prompt=False, tokenize=False ) + ids = self.tok(rendered, add_special_tokens=False)["input_ids"] - if not isinstance(rendered, str) or not rendered.strip(): + if not ids: continue - # —— 样本级终止符:确保训练时每条样本以 eos 结束 —— - # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token): - # rendered += self.tok.eos_token + # 构建监督掩码(0/1) + mask = [0] * len(ids) + i = 0 + while i < len(ids): + # 找到一个 <|im_start|> + try: + a = ids.index(self.id_START, i) + except ValueError: + break + # 必须是 assistant 角色(兼容 'assistant\\n' 或 'assistant') + j = a + 1 + role_match = self._find_role_after_start(ids, j) + if role_match is None: + i = a + 1 + continue + _, role_len = role_match + content_lo = j + role_len # 跳过角色 token 序列 - # —— 样本级终止符:把 eos 插入到最后一个 assistant 的 <|im_end|>\n 之前 —— - # if self.tok.eos_token: - # open_tag = "<|im_start|>assistant\n" - # close_tag = "<|im_end|>\n" - # head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切 - # if sep: # 找到了收尾标记 - # # 仅当 assistant 文本末尾还没有 eos 才插入,避免重复 - # if not head.endswith(self.tok.eos_token): - # rendered = head + self.tok.eos_token + sep + tail + # 找匹配的 <|im_end|> + try: + b = ids.index(self.id_END, content_lo) + except ValueError: + # 不闭合就放弃这个片段 + i = a + 1 + continue + content_hi = b # 不含 END + # 先把整个内容区间标 1(监督) + for t in range(content_lo, content_hi): + mask[t] = 1 - # 1) 找到所有 assistant 区间 & 全局 think 区间 - asst_spans = _assistant_char_spans(rendered) - if not asst_spans: - continue - think_spans = _think_char_spans(rendered) + # 可选:把 (含标签)整体屏蔽 + if self.mask_think_and_tags: + p = content_lo + while True: + o = self._find_subseq(ids, self.ids_THINK_OPEN, p) + if o == -1 or o >= content_hi: + break + c = self._find_subseq(ids, self.ids_THINK_CLOSE, o + len(self.ids_THINK_OPEN)) + if c == -1 or c > content_hi: + break + x_lo = o # 含 + x_hi = c + len(self.ids_THINK_CLOSE) # 含 + for t in range(x_lo, min(x_hi, content_hi)): + mask[t] = 0 + p = x_hi - # 2) 编码 & offset 对齐 - enc = self.tok( - rendered, - add_special_tokens=False, - return_offsets_mapping=True - ) - input_ids = enc["input_ids"] - offsets = enc["offset_mapping"] - if not input_ids: + # 继续找下一个片段 + i = b + 1 + + # 如果没有任何可监督 token,跳过 + if not any(mask): continue - # 3) 仅监督 assistant 片段,且排除落在 think 区间内的 token - labels = [-100] * len(input_ids) - - def in_any_span(lo: int, hi: int, intervals: List[Tuple[int, int]]) -> bool: - for s, e in intervals: - # 有交集即为 True - if not (hi <= s or lo >= e): - return True - return False - - for i, (lo, hi) in enumerate(offsets): - if in_any_span(lo, hi, asst_spans) and not in_any_span(lo, hi, think_spans): - labels[i] = input_ids[i] - - # —— 固定长度策略:先截尾(保留尾部),再左侧补齐 —— - # if len(input_ids) > self.seq_len: - # input_ids = input_ids[-self.seq_len:] - # labels = labels[-self.seq_len:] - - # ======== 助手感知的截断策略:尽量保证“最后一个 assistant 片段”完整 ======== - if len(input_ids) > self.seq_len: - # 取最后一个 assistant 的字符区间([s_last, e_last)) - s_last, e_last = asst_spans[-1] - - # 用 offsets 把字符区间映射到 token 索引区间 [j, 
k_excl) - j = 0 - while j < len(offsets) and offsets[j][1] <= s_last: - j += 1 - k_excl = j - while k_excl < len(offsets) and offsets[k_excl][0] < e_last: - k_excl += 1 - - A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数 - - if A >= self.seq_len: - # 单个 assistant 本身超过窗口 —— 保“结尾”,避免切尾 - start = max(0, k_excl - self.seq_len) - end = start + self.seq_len - else: - # 有空间容纳整个 assistant:让窗口覆盖完整 [j, k_excl) - start = max(0, min(j, len(input_ids) - self.seq_len)) - end = start + self.seq_len - if end < k_excl: - end = k_excl - start = max(0, end - self.seq_len) - - # 可选:尝试“居中”一点(给最后一个 assistant 左右留些上下文) - leftover = self.seq_len - A - left_wish = leftover // 2 - start = max(0, min(j - left_wish, start)) - end = start + self.seq_len - if end < k_excl: - end = k_excl - start = max(0, end - self.seq_len) - - # 真正切片 - input_ids = input_ids[start:end] - labels = labels[start:end] + # ======== 截断策略:优先保留“最后一个被监督 token”为终点 ======== + if len(ids) > self.seq_len: + last_on = max(idx for idx, v in enumerate(mask) if v == 1) + end = min(len(ids), last_on + 1) + start = max(0, end - self.seq_len) + ids = ids[start:end] + mask = mask[start:end] + # ======== 左侧 pad ======== pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id - L = len(input_ids) + L = len(ids) if L < self.seq_len: pad = self.seq_len - L - input_ids = ([pad_id] * pad) + input_ids - labels = ([-100] * pad) + labels - attn_mask = [0] * pad + [1] * L + input_ids = [pad_id] * pad + ids + attention_mask = [0] * pad + [1] * L + labels = [-100] * pad + [tok if m == 1 else -100 for tok, m in zip(ids, mask)] else: - attn_mask = [1] * self.seq_len + input_ids = ids + attention_mask = [1] * self.seq_len + labels = [tok if m == 1 else -100 for tok, m in zip(ids, mask)] - # 若没有任何可训练 token(labels 全 -100),跳过 - if all(v == -100 for v in labels): - continue - - # Sanity - assert len(input_ids) == self.seq_len - assert len(labels) == self.seq_len - assert len(attn_mask) == self.seq_len - - # >>> DEBUG - if dbg_on and self._dbg_seen < dbg_limit: + # >>> 调试打印(可选) + if dbg_on and seen < dbg_limit: sup_tok = sum(1 for v in labels if v != -100) print( f"[sample][host={host} RANK={rank} LRank={lrank}] " - f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} " + f"toks={len(input_ids)} sup_toks={sup_tok} " f"seq_len={self.seq_len} pad_id={pad_id}", flush=True ) - if sup_tok == 0: - print( - f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped", - flush=True - ) - self._dbg_seen += 1 - # <<< DEBUG + seen += 1 yield { "input_ids": torch.tensor(input_ids, dtype=torch.long), - "attention_mask": torch.tensor(attn_mask, dtype=torch.long), + "attention_mask": torch.tensor(attention_mask, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long), } - - -# ----------------- 专用 Collator:pad inputs, pad labels=-100 ----------------- +# ----------------- Collator(保持与上游一致:pad->label=-100, attn=0) ----------------- class SFTDataCollator: def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None): self.tok = tokenizer self.pad_to_length = pad_to_length - assert self.tok.pad_token_id is not None + assert self.tok.pad_token_id is not None, "tokenizer.pad_token_id must be set" - def __call__(self, features): + def __call__(self, features): if not features: - raise RuntimeError( - f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. " - f"Check dataset sharding/streaming." 
- ) + raise RuntimeError("Empty batch passed to collator") def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x) input_ids = [_to_list(f["input_ids"]) for f in features] @@ -905,24 +451,13 @@ class SFTDataCollator: batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long)) batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long)) - # >>> DEBUG BEGIN - dbg_on = os.environ.get("DBG_COLLATE", "0") == "1" - if dbg_on: - rank = int(os.environ.get("RANK", "0")) - host = socket.gethostname() - bs = len(features) - first_len = len(input_ids[0]) if bs > 0 else None - print( - f"[collate][host={host} RANK={rank}] features={bs} " - f"target_len={target_len} first_len={first_len}", - flush=True - ) return { "input_ids": torch.stack(batch_inp, dim=0), "attention_mask": torch.stack(batch_attn, dim=0), "labels": torch.stack(batch_lab, dim=0), } + # ----------------- 参数 ----------------- def parse_args(): ap = argparse.ArgumentParser() @@ -942,7 +477,6 @@ def parse_args(): ap.add_argument("--save_steps", type=int, default=500) ap.add_argument("--eval_ratio", type=float, default=0.0) ap.add_argument("--seed", type=int, default=1337) - #ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json") ap.add_argument("--gradient_checkpointing", action="store_true") ap.add_argument("--bf16", action="store_true", help="3090/A100/H100 等可开 bf16;同时在 DS 配置里也要开") @@ -1101,9 +635,14 @@ def main(): pass # 强制要求 fast tokenizer(offset_mapping 依赖 fast) - from transformers import PreTrainedTokenizerFast - if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False): - raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。") + # from transformers import PreTrainedTokenizerFast + # if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False): + # raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。") + + # 建议使用 fast 分词器(更快);不再依赖 offset_mapping + if not getattr(tokenizer, "is_fast", False): + print("[warn] using a slow tokenizer; masks are token-id based and still correct, just slower.", flush=True) + tokenizer.model_max_length = args.seq_len @@ -1124,6 +663,15 @@ def main(): dtype = (torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)) + + try: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + except Exception: + pass + + # 交给插件做 ZeRO-Init/分片加载 model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, @@ -1140,14 +688,14 @@ def main(): # 3) pad/alibi 等配置 model.config.pad_token_id = tokenizer.pad_token_id + + if getattr(model, "generation_config", None) is not None: + model.generation_config.pad_token_id = tokenizer.pad_token_id + model.config.use_cache = False if args.gradient_checkpointing: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) try: - # torch.backends.cuda.enable_flash_sdp(False) - # torch.backends.cuda.enable_mem_efficient_sdp(False) - # torch.backends.cuda.enable_math_sdp(True) - # 让 PyTorch 自己选,或显式打开高效实现(任选其一): torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) @@ -1172,8 +720,9 @@ def main(): for ex in ds_stream_probe: yield ex train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len) + try: - _ = next(iter(train_stream_probe)) + sample = 
next(iter(train_stream_probe)) except StopIteration: raise RuntimeError( f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n" @@ -1182,36 +731,17 @@ def main(): "另外检查 seq_len 是否过小导致全部被裁。" ) - # # ====== 正式训练流 ====== - # ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) - # if world_size > 1 and len(files) >= world_size: - # # 多文件,按文件连续分片 - # ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True) - # train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) - # else: - # # 单文件或文件数不足,按样本取模轮转 - # def ex_iter2(): - # for i, ex in enumerate(ds_stream2): - # if i % max(world_size, 1) == rank: - # yield ex - # train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len) + # 更靠谱的自检(替换你现在的两行 assert) + ids, attn, labs = sample["input_ids"], sample["attention_mask"], sample["labels"] + assert (labs != -100).any(), "[fatal] no supervised tokens in first valid sample" + # pad 区必须被忽略监督 + assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100" + # ====== 正式训练流(不做任何手动分片,交给 Accelerate/Trainer)====== ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed) train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) - # # ====== 一致性探针(与上面保持同逻辑)===== - # ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) - # if world_size > 1 and len(files) >= world_size: - # ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True) - # probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len) - # else: - # def ex_iter2_probe(): - # for i, ex in enumerate(ds_stream_probe2): - # if i % max(world_size, 1) == rank: - # yield ex - # probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len) - # ====== 一致性探针(不分片)====== ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len) @@ -1318,7 +848,7 @@ def main(): f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}" # 更稳:联调阶段不强行 pad 到 4096 - data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len) + data_collator = SFTDataCollator(tokenizer, pad_to_length=None) os.makedirs(args.output_dir, exist_ok=True) logging_dir = os.path.join(args.output_dir, "logs") @@ -1342,7 +872,6 @@ def main(): ta_kwargs2 = dict( output_dir=args.output_dir, logging_dir=logging_dir, - # ★ 新增:自定义 run_name,避免等于 output_dir 的 warning run_name=f"sft-{os.path.basename(args.output_dir)}-{socket.gethostname()}", do_train=True, do_eval=(eval_dataset is not None), @@ -1360,22 +889,21 @@ def main(): logging_steps=args.log_interval, save_steps=args.save_steps, save_total_limit=2, - # deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None), deepspeed=(args.deepspeed if use_ds else None), dataloader_drop_last=False, dataloader_num_workers=0, + label_smoothing_factor=0.0, per_device_eval_batch_size=args.per_device_eval_batch_size, report_to=([] if args.report_to == "none" else [args.report_to]), - bf16=args.bf16, - fp16=(not args.bf16), + #bf16=args.bf16, + #fp16=(not args.bf16), gradient_checkpointing=args.gradient_checkpointing, remove_unused_columns=False, save_on_each_node=True, 
logging_first_step=True, **ta_kwargs, # 你之前构造的 eval_strategy 兼容项 ) - # if "dataloader_prefetch_factor" in ta_sig: - # ta_kwargs2["dataloader_prefetch_factor"] = None + if "dataloader_pin_memory" in ta_sig: ta_kwargs2["dataloader_pin_memory"] = False if "torch_compile" in ta_sig: @@ -1401,8 +929,6 @@ def main(): args=training_args, train_dataset=train_stream, eval_dataset=eval_dataset, - #tokenizer=tokenizer, - #processing_class=tokenizer, data_collator=data_collator, **trainer_kwargs, ) @@ -1467,7 +993,6 @@ def main(): print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True) resume_flag = None - print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}") print_once("***** Starting training *****")
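
A quick, optional way to see what the new token-id based masking actually supervises is to decode only the label positions of one sample. The sketch below is illustrative rather than part of the patch: it assumes QwenChatSFTDataset can be imported from train_sft_ds on a machine where the module's env setup runs cleanly, and the checkpoint name and example messages are placeholders.

# sanity_check_masking.py  (hypothetical helper, run manually after applying the patch)
from transformers import AutoTokenizer
from train_sft_ds import QwenChatSFTDataset  # assumption: the module imports cleanly here

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder checkpoint
examples = [{
    "messages": [
        {"role": "user", "content": "1+1=?"},
        {"role": "assistant", "content": "<think>trivial arithmetic</think>The answer is 2."},
    ]
}]

ds = QwenChatSFTDataset(iter(examples), tok, seq_len=256, mask_think_and_tags=True)
sample = next(iter(ds))

# keep only positions whose label is not -100 and decode them
sup_ids = [int(t) for t, l in zip(sample["input_ids"], sample["labels"]) if int(l) != -100]
print(tok.decode(sup_ids))
# expected: the assistant answer text only, with no user/system turns and no <think>...</think> block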