This commit is contained in:
hailin 2025-09-08 14:45:12 +08:00
parent 7f4932689f
commit a820de7aaa
1 changed files with 55 additions and 9 deletions

View File

@ -577,29 +577,61 @@ from torch.utils.data import IterableDataset
from transformers import AutoTokenizer # 仅作类型提示/引用,不强依赖
# # ----------------- 工具:提取 assistant 字符区间 -----------------
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
# """
# 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。
# (覆盖了 assistant 的全部内容,包括其中可能出现的 <think>…</think>
# """
# spans: List[Tuple[int, int]] = []
# open_tag = "<|im_start|>assistant\n"
# close_tag = "<|im_end|>\n"
# pos = 0
# while True:
# a = rendered.find(open_tag, pos)
# if a == -1:
# break
# start = a + len(open_tag)
# b = rendered.find(close_tag, start)
# if b == -1:
# break
# end = b + len("<|im_end|>")
# spans.append((start, end))
# # spans.append((start, b))
# pos = b + len(close_tag)
# return spans
# ----------------- 工具:提取 assistant 字符区间 -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
"""
apply_chat_template 渲染后的文本中返回所有 assistant 内容的字符区间 [start, end)
覆盖了 assistant 的全部内容包括其中可能出现的 <think></think>
"""
spans: List[Tuple[int, int]] = []
open_tag = "<|im_start|>assistant\n"
close_tag = "<|im_end|>\n"
close_token = "<|im_end|>"
close_tag = close_token + "\n" # 常见模板带换行
pos = 0
while True:
a = rendered.find(open_tag, pos)
if a == -1:
break
start = a + len(open_tag)
# 先找含换行版本,找不到再退化找不带换行的
b = rendered.find(close_tag, start)
if b == -1:
break
spans.append((start, b))
pos = b + len(close_tag)
b = rendered.find(close_token, start)
if b == -1:
break
end = b + len(close_token) # 把 <|im_end|> 本体纳入监督
spans.append((start, end))
# pos 跳过这一轮结束标记(带换行就多跳一格)
pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
return spans
# ----------------- 工具:提取所有 <think>…</think> 的字符区间(包含标签本身) -----------------
def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
"""
@ -675,6 +707,20 @@ class QwenChatSFTDataset(IterableDataset):
if not isinstance(rendered, str) or not rendered.strip():
continue
# —— 样本级终止符:确保训练时每条样本以 eos 结束 ——
# if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
# rendered += self.tok.eos_token
# —— 样本级终止符:把 eos 插入到最后一个 assistant 的 <|im_end|>\n 之前 ——
if self.tok.eos_token:
open_tag = "<|im_start|>assistant\n"
close_tag = "<|im_end|>\n"
head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切
if sep: # 找到了收尾标记
# 仅当 assistant 文本末尾还没有 eos 才插入,避免重复
if not head.endswith(self.tok.eos_token):
rendered = head + self.tok.eos_token + sep + tail
# 1) 找到所有 assistant 区间 & 全局 think 区间
asst_spans = _assistant_char_spans(rendered)
if not asst_spans:
@ -1047,7 +1093,7 @@ def main():
# train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
# ====== 正式训练流(不做任何手动分片,交给 Accelerate/Trainer======
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
# # ====== 一致性探针(与上面保持同逻辑)=====