parent 7f4932689f
commit a820de7aaa
@@ -577,29 +577,61 @@ from torch.utils.data import IterableDataset
from transformers import AutoTokenizer  # type hint / reference only, not a hard dependency


# # ----------------- Helper: extract assistant character spans -----------------
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
#     """
#     In the text rendered by apply_chat_template, return the character spans [start, end) of all assistant content.
#     (Covers the full assistant content, including any <think>…</think> that may appear inside it.)
#     """
#     spans: List[Tuple[int, int]] = []
#     open_tag = "<|im_start|>assistant\n"
#     close_tag = "<|im_end|>\n"
#     pos = 0
#     while True:
#         a = rendered.find(open_tag, pos)
#         if a == -1:
#             break
#         start = a + len(open_tag)
#         b = rendered.find(close_tag, start)
#         if b == -1:
#             break
#
#         end = b + len("<|im_end|>")
#         spans.append((start, end))
#         # spans.append((start, b))
#         pos = b + len(close_tag)
#     return spans


# ----------------- Helper: extract assistant character spans -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
    """
    In the text rendered by apply_chat_template, return the character spans [start, end) of all assistant content.
    (Covers the full assistant content, including any <think>…</think> that may appear inside it.)
    """
    spans: List[Tuple[int, int]] = []
    open_tag = "<|im_start|>assistant\n"
    close_tag = "<|im_end|>\n"
    close_token = "<|im_end|>"
    close_tag = close_token + "\n"  # common templates put a newline after it

    pos = 0
    while True:
        a = rendered.find(open_tag, pos)
        if a == -1:
            break
        start = a + len(open_tag)

        # Look for the newline version first; if it is absent, fall back to the bare token.
        b = rendered.find(close_tag, start)
        if b == -1:
            break
        spans.append((start, b))
        pos = b + len(close_tag)
        b = rendered.find(close_token, start)
        if b == -1:
            break

        end = b + len(close_token)  # include <|im_end|> itself in the supervised span
        spans.append((start, end))

        # Advance pos past this round's closing marker (skip one more char when it carries the newline).
        pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
    return spans
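

# Illustrative only: a tiny probe of the span extraction above. The sample string is
# hypothetical and merely mimics the Qwen-style chat markers assumed in this file.
def _demo_assistant_char_spans() -> None:
    sample = (
        "<|im_start|>user\nHi<|im_end|>\n"
        "<|im_start|>assistant\n<think>plan</think>Hello!<|im_end|>\n"
    )
    for s, e in _assistant_char_spans(sample):
        print(repr(sample[s:e]))  # inspect each extracted assistant span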


# ----------------- Helper: extract the character spans of all <think>…</think> blocks (tags themselves included) -----------------
def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
    """
@@ -675,6 +707,20 @@ class QwenChatSFTDataset(IterableDataset):
            if not isinstance(rendered, str) or not rendered.strip():
                continue

            # —— Per-sample terminator: make sure every training sample ends with eos ——
            # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
            #     rendered += self.tok.eos_token
            # —— Per-sample terminator: insert eos right before the last assistant's <|im_end|>\n ——
            if self.tok.eos_token:
                open_tag = "<|im_start|>assistant\n"
                close_tag = "<|im_end|>\n"
                head, sep, tail = rendered.rpartition(close_tag)  # split at the last close_tag
                if sep:  # the closing marker was found
                    # Insert only if the assistant text does not already end with eos, to avoid duplicates.
                    if not head.endswith(self.tok.eos_token):
                        rendered = head + self.tok.eos_token + sep + tail
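            # For illustration (hypothetical strings; assumes eos_token == "<|endoftext|>"):
            #   before: ...<|im_start|>assistant\nHello!<|im_end|>\n
            #   after:  ...<|im_start|>assistant\nHello!<|endoftext|><|im_end|>\n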

            # 1) Find all assistant spans & the global think spans
            asst_spans = _assistant_char_spans(rendered)
            if not asst_spans:
@@ -1047,7 +1093,7 @@ def main():
    # train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)

    # ====== Main training stream (no manual sharding; leave that to Accelerate/Trainer) ======
    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
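    # Note: with streaming=True, .shuffle(buffer_size=50000) gives an approximate, buffer-based shuffle rather than a full global shuffle.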
    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)

    # # ====== Consistency probe (same logic as above) ======