diff --git a/train_sft_ds.py b/train_sft_ds.py
index 20672f9..34777d5 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -577,29 +577,61 @@ from torch.utils.data import IterableDataset
from transformers import AutoTokenizer # type hints/reference only; not a hard dependency
# ----------------- Helper: extract assistant char spans -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
- """
- In the text rendered by apply_chat_template, return the char spans [start, end) of all assistant content.
- (Covers the assistant's full content, including any <think>…</think> blocks that may appear inside it.)
- """
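+ """
+ Return the [start, end) char spans of every assistant turn in the rendered chat text.
+ The end offset now includes the closing <|im_end|> token itself, e.g. for
+ "<|im_start|>assistant\nHi<|im_end|>\n" the span covers "Hi<|im_end|>".
+ """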
spans: List[Tuple[int, int]] = []
open_tag = "<|im_start|>assistant\n"
- close_tag = "<|im_end|>\n"
+ close_token = "<|im_end|>"
+ close_tag = close_token + "\n" # common chat templates append a trailing newline
+
pos = 0
while True:
a = rendered.find(open_tag, pos)
if a == -1:
break
start = a + len(open_tag)
+
+ # look for the newline-terminated variant first; fall back to the bare token if it is absent
b = rendered.find(close_tag, start)
if b == -1:
- break
- spans.append((start, b))
- pos = b + len(close_tag)
+ b = rendered.find(close_token, start)
+ if b == -1:
+ break
+
+ end = b + len(close_token) # include the <|im_end|> token itself in the supervised span
+ spans.append((start, end))
+
+ # advance pos past this closing marker (one extra char when the newline variant matched)
+ pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
return spans
+
# ----------------- Helper: extract the char spans of all <think>…</think> blocks (tags included) -----------------
def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
"""
@@ -675,6 +707,20 @@ class QwenChatSFTDataset(IterableDataset):
if not isinstance(rendered, str) or not rendered.strip():
continue
+ # Sample-level terminator (disabled variant): just ensure each rendered sample ends with eos
+ # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
+ # rendered += self.tok.eos_token
+ # Sample-level terminator: insert eos right before the final assistant's <|im_end|>\n
+ if self.tok.eos_token:
+ close_tag = "<|im_end|>\n"
+ head, sep, tail = rendered.rpartition(close_tag) # split on the last occurrence of close_tag
+ if sep: # the closing marker was found
+ # insert only if the assistant text does not already end with eos, to avoid duplicates
+ if not head.endswith(self.tok.eos_token):
+ rendered = head + self.tok.eos_token + sep + tail
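+ # note: eos goes before <|im_end|>, so it stays inside the assistant span returned by _assistant_char_spans and is supervised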
+
# 1) Find all assistant spans & global think spans
asst_spans = _assistant_char_spans(rendered)
if not asst_spans:
@@ -1047,7 +1093,7 @@ def main():
# train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
# ====== Main training stream (no manual sharding; left to Accelerate/Trainer) ======
- ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
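+ # approximate shuffle for the streaming source: a 50k-example buffer, seeded for reproducibility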
+ ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
# # ====== Consistency probe (same logic as above) ======