parent c1eef90d1c
commit 8b0d8e0c5e
@@ -31,7 +31,7 @@ deepspeed --hostfile hostfile \
    --model_name_or_path /home/test/Qwen3-32B \
    --data_glob "/home/test/datasets/my_corpus/train.jsonl" \
    --output_dir /home/test/checkpoints/q3-32b-ds4 \
-   --seq_len 512 \
+   --seq_len 1024 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \
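Raising --seq_len from 512 to 1024 doubles the tokens consumed per optimizer step at the same batch settings. Below is a rough back-of-the-envelope sketch of what these flags imply; the world size depends on the hostfile, which is not shown in this diff, so the value used here is purely an assumption:

# Back-of-the-envelope arithmetic for the launch flags above (sketch only).
world_size = 16                       # assumed, e.g. 2 nodes x 8 GPUs; the hostfile is not part of this diff
per_device_train_batch_size = 1
gradient_accumulation_steps = 1
seq_len = 1024                        # was 512 before this commit

global_batch = per_device_train_batch_size * gradient_accumulation_steps * world_size
tokens_per_step = global_batch * seq_len
print(global_batch, tokens_per_step)  # 16 sequences, 16384 tokens per optimizer step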
@@ -633,10 +633,30 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:


# ----------------- Helper: extract the char spans of every <think>…</think> (including the tags themselves) -----------------
# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
#     """
#     Pure str.find implementation, no regex.
#     Returns the global list of <think>…</think> spans, as absolute positions in rendered.
#     """
#     spans: List[Tuple[int, int]] = []
#     open_tag = "<think>"
#     close_tag = "</think>"
#     pos = 0
#     while True:
#         a = rendered.find(open_tag, pos)
#         if a == -1:
#             break
#         b = rendered.find(close_tag, a + len(open_tag))
#         if b == -1:
#             break
#         spans.append((a, b + len(close_tag)))
#         pos = b + len(close_tag)
#     return spans

def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
    """
    Pure str.find implementation, no regex; coordinates are absolute positions in rendered.
    Returns the spans whose supervision should be ignored: only the *interior* of each
    <think>...</think> block. The <think> and </think> tags themselves still receive
    supervision so the model learns to close the block.
    """
    spans: List[Tuple[int, int]] = []
    open_tag = "<think>"
@@ -649,11 +669,13 @@ def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
        b = rendered.find(close_tag, a + len(open_tag))
        if b == -1:
            break
-        spans.append((a, b + len(close_tag)))
+        # Ignore only the interior, not the tags on either side
+        spans.append((a + len(open_tag), b))
        pos = b + len(close_tag)
    return spans


# ----------------- Dataset that supervises only assistant turns (ignoring <think>…</think>) -----------------
class QwenChatSFTDataset(IterableDataset):
    """
@@ -710,15 +732,17 @@ class QwenChatSFTDataset(IterableDataset):
        # -- Sample-level terminator: make sure every training sample ends with eos --
        # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
        #     rendered += self.tok.eos_token


        # -- Sample-level terminator: insert eos right before the last assistant's <|im_end|>\n --
        if self.tok.eos_token:
            open_tag = "<|im_start|>assistant\n"
            close_tag = "<|im_end|>\n"
            head, sep, tail = rendered.rpartition(close_tag)  # split on the last close_tag
            if sep:  # found the closing marker
                # insert only if the assistant text does not already end with eos, to avoid duplicates
                if not head.endswith(self.tok.eos_token):
                    rendered = head + self.tok.eos_token + sep + tail
        # if self.tok.eos_token:
        #     open_tag = "<|im_start|>assistant\n"
        #     close_tag = "<|im_end|>\n"
        #     head, sep, tail = rendered.rpartition(close_tag)  # split on the last close_tag
        #     if sep:  # found the closing marker
        #         # insert only if the assistant text does not already end with eos, to avoid duplicates
        #         if not head.endswith(self.tok.eos_token):
        #             rendered = head + self.tok.eos_token + sep + tail


        # 1) Find all assistant spans & the global think spans
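For reference, a toy illustration of what the rpartition-based insertion in this hunk does, outside the dataset class; the eos string below is an assumption, the real value comes from self.tok.eos_token:

eos = "<eos>"  # hypothetical stand-in for self.tok.eos_token
rendered = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nFinal answer.<|im_end|>\n"
head, sep, tail = rendered.rpartition("<|im_end|>\n")   # split on the last <|im_end|>\n
if sep and not head.endswith(eos):
    rendered = head + eos + sep + tail
# rendered now ends with "...Final answer.<eos><|im_end|>\n", i.e. eos sits inside
# the last assistant turn, immediately before its <|im_end|> marker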
@@ -753,9 +777,49 @@
                labels[i] = input_ids[i]

        # -- Fixed-length policy: truncate first (keep the tail), then pad on the left --
        # if len(input_ids) > self.seq_len:
        #     input_ids = input_ids[-self.seq_len:]
        #     labels = labels[-self.seq_len:]

        # ======== Assistant-aware truncation: try to keep the last assistant segment intact ========
        if len(input_ids) > self.seq_len:
-            input_ids = input_ids[-self.seq_len:]
-            labels = labels[-self.seq_len:]
            # Character span of the last assistant turn, [s_last, e_last)
            s_last, e_last = asst_spans[-1]

            # Use offsets to map the char span to a token index range [j, k_excl)
            j = 0
            while j < len(offsets) and offsets[j][1] <= s_last:
                j += 1
            k_excl = j
            while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
                k_excl += 1

            A = max(0, k_excl - j)  # number of tokens covered by the last assistant turn

            if A >= self.seq_len:
                # The assistant turn alone exceeds the window -- keep its end, avoid cutting off the tail
                start = max(0, k_excl - self.seq_len)
                end = start + self.seq_len
            else:
                # There is room for the whole assistant turn: place the window so it covers [j, k_excl)
                start = max(0, min(j, len(input_ids) - self.seq_len))
                end = start + self.seq_len
                if end < k_excl:
                    end = k_excl
                    start = max(0, end - self.seq_len)

                # Optional: bias toward "centering" (leave some context on both sides of the last assistant turn)
                leftover = self.seq_len - A
                left_wish = leftover // 2
                start = max(0, min(j - left_wish, start))
                end = start + self.seq_len
                if end < k_excl:
                    end = k_excl
                    start = max(0, end - self.seq_len)

            # Do the actual slicing
            input_ids = input_ids[start:end]
            labels = labels[start:end]

        pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
        L = len(input_ids)
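To make the window placement concrete, here is a toy walk-through with made-up numbers; the end < k_excl corrections never trigger for these values, so they are left out of the trace:

seq_len = 16          # window size
n_tokens = 40         # len(input_ids) before truncation
j, k_excl = 30, 38    # the last assistant turn covers tokens [30, 38)

A = k_excl - j                              # 8 tokens, fits inside the window
start = max(0, min(j, n_tokens - seq_len))  # min(30, 24) -> 24
end = start + seq_len                       # 40, window already covers [30, 38)
left_wish = (seq_len - A) // 2              # wish for 4 tokens of left context
start = max(0, min(j - left_wish, start))   # min(26, 24) -> 24, cannot shift further left
end = start + seq_len                       # final window [24, 40): 6 tokens of context before
                                            # the assistant turn, the full turn, 2 tokens after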