diff --git a/train_sft_ds.py b/train_sft_ds.py
index 20672f9..34777d5 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -577,29 +577,61 @@ from torch.utils.data import IterableDataset
 from transformers import AutoTokenizer  # type hint/reference only; not a hard dependency

+# # ----------------- Helper: extract assistant char spans -----------------
+# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
+#     """
+#     In text rendered by apply_chat_template, return the char spans [start, end)
+#     of all assistant content. (Covers the assistant's entire content, including
+#     any <think> block inside it.)
+#     """
+#     spans: List[Tuple[int, int]] = []
+#     open_tag = "<|im_start|>assistant\n"
+#     close_tag = "<|im_end|>\n"
+#     pos = 0
+#     while True:
+#         a = rendered.find(open_tag, pos)
+#         if a == -1:
+#             break
+#         start = a + len(open_tag)
+#         b = rendered.find(close_tag, start)
+#         if b == -1:
+#             break
+
+#         end = b + len("<|im_end|>")
+#         spans.append((start, end))
+#         # spans.append((start, b))
+#         pos = b + len(close_tag)
+#     return spans
+
 # ----------------- Helper: extract assistant char spans -----------------
 def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
-    """
-    In text rendered by apply_chat_template, return the char spans [start, end)
-    of all assistant content. (Covers the assistant's entire content, including
-    any <think> block inside it.)
-    """
     spans: List[Tuple[int, int]] = []
     open_tag = "<|im_start|>assistant\n"
-    close_tag = "<|im_end|>\n"
+    close_token = "<|im_end|>"
+    close_tag = close_token + "\n"  # common templates append a newline
+
     pos = 0
     while True:
         a = rendered.find(open_tag, pos)
         if a == -1:
             break
         start = a + len(open_tag)
+
+        # Look for the newline variant first; fall back to the bare token.
         b = rendered.find(close_tag, start)
         if b == -1:
-            break
-        spans.append((start, b))
-        pos = b + len(close_tag)
+            b = rendered.find(close_token, start)
+            if b == -1:
+                break
+
+        end = b + len(close_token)  # include the <|im_end|> token itself in supervision
+        spans.append((start, end))
+
+        # Advance pos past this closing marker (one extra char when it has the newline).
+        pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
     return spans

+
 # ----------------- Helper: extract all <think>...</think> char spans (tags included) -----------------
 def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
     """
@@ -675,6 +707,20 @@ class QwenChatSFTDataset(IterableDataset):
             if not isinstance(rendered, str) or not rendered.strip():
                 continue

+            # -- Sample-level terminator: make sure every training sample ends with eos --
+            # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
+            #     rendered += self.tok.eos_token
+            # -- Sample-level terminator: insert eos before the last assistant's <|im_end|>\n --
+            if self.tok.eos_token:
+                open_tag = "<|im_start|>assistant\n"
+                close_tag = "<|im_end|>\n"
+                head, sep, tail = rendered.rpartition(close_tag)  # split at the last close_tag
+                if sep:  # closing marker found
+                    # Insert only if the assistant text does not already end with eos, to avoid duplicates.
+                    if not head.endswith(self.tok.eos_token):
+                        rendered = head + self.tok.eos_token + sep + tail
+
+
             # 1) Locate all assistant spans & the global think spans
             asst_spans = _assistant_char_spans(rendered)
             if not asst_spans:
@@ -1047,7 +1093,7 @@ def main():
     # train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)

     # ====== Actual training stream (no manual sharding; left to Accelerate/Trainer) ======
-    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
     train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)

     # # ====== Consistency probe (same logic as above) =====
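
A minimal sketch of how the revised _assistant_char_spans is expected to behave. The rendered string below is a hand-written stand-in for tokenizer.apply_chat_template(..., tokenize=False) under a Qwen-style template, not real tokenizer output; the final turn deliberately omits the trailing newline to exercise the new fallback path.

    rendered = (
        "<|im_start|>user\nhi<|im_end|>\n"
        "<|im_start|>assistant\nhello<|im_end|>\n"
        "<|im_start|>user\nbye<|im_end|>\n"
        "<|im_start|>assistant\nsee you<|im_end|>"  # no trailing \n: triggers the bare-token fallback
    )

    for s, e in _assistant_char_spans(rendered):
        print(repr(rendered[s:e]))
    # Expected output -- each span now ends just past <|im_end|>, so the stop
    # token itself is supervised:
    #   'hello<|im_end|>'
    #   'see you<|im_end|>'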
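
The eos-insertion hunk relies on str.rpartition splitting at the last occurrence of the close tag. A self-contained sketch, assuming eos_token is "<|endoftext|>" purely for illustration (the real value comes from self.tok.eos_token and varies by checkpoint):

    eos = "<|endoftext|>"  # assumption for the demo; read from the tokenizer in practice
    close_tag = "<|im_end|>\n"

    rendered = "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\nhello<|im_end|>\n"
    head, sep, tail = rendered.rpartition(close_tag)  # splits at the LAST close_tag
    if sep and not head.endswith(eos):
        rendered = head + eos + sep + tail

    print(rendered.splitlines()[-1])
    # hello<|endoftext|><|im_end|>

Because eos now sits before <|im_end|> rather than after it, it falls inside the assistant span returned by _assistant_char_spans and therefore receives a label instead of being masked out.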
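
On the .shuffle(buffer_size=50000, seed=args.seed) change: with streaming=True, load_dataset returns a datasets.IterableDataset, whose shuffle is approximate and buffer-based -- it keeps a reservoir of buffer_size examples and emits random picks from it, so a larger buffer mixes the JSON files more thoroughly at the cost of host memory. A rough self-contained illustration of the same idea (not the library's actual implementation):

    import random
    from typing import Iterable, Iterator, TypeVar

    T = TypeVar("T")

    def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
        rng = random.Random(seed)
        buf: list[T] = []
        for x in items:
            if len(buf) < buffer_size:
                buf.append(x)    # fill the reservoir first
            else:
                i = rng.randrange(buffer_size)
                yield buf[i]     # emit a random buffered element...
                buf[i] = x       # ...and replace it with the incoming one
        rng.shuffle(buf)         # drain whatever is left in random order
        yield from buf

    # e.g. list(buffered_shuffle(range(10), buffer_size=4, seed=0))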