commit 8b0d8e0c5e
parent c1eef90d1c
@@ -31,7 +31,7 @@ deepspeed --hostfile hostfile \
     --model_name_or_path /home/test/Qwen3-32B \
     --data_glob "/home/test/datasets/my_corpus/train.jsonl" \
     --output_dir /home/test/checkpoints/q3-32b-ds4 \
-    --seq_len 512 \
+    --seq_len 1024 \
     --per_device_train_batch_size 1 \
     --gradient_accumulation_steps 1 \
     --learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \
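The only functional change in this hunk is raising --seq_len from 512 to 1024, which widens the fixed-length window that QwenChatSFTDataset truncates and pads to (see the truncation hunk further down). A quick way to check whether 1024 tokens covers typical samples is to render a conversation through the tokenizer's chat template and measure its length. A minimal sketch, assuming the tokenizer loads from the same local checkpoint path; the example conversation is a placeholder, not taken from the corpus:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/home/test/Qwen3-32B")
messages = [  # placeholder sample; real data lives in train.jsonl
    {"role": "user", "content": "Summarize the report."},
    {"role": "assistant", "content": "<think>outline first</think>Here is the summary."},
]
ids = tok.apply_chat_template(messages, tokenize=True)
print(len(ids))  # if this is <= 1024, the sample fits the new window without truncation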
@@ -633,10 +633,30 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:


 # ----------------- Utility: extract the char spans of every <think>…</think> (including the tags themselves) -----------------
+# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
+#     """
+#     Plain str.find implementation, no regex.
+#     Returns the global list of <think>…</think> spans, as absolute positions in rendered.
+#     """
+#     spans: List[Tuple[int, int]] = []
+#     open_tag = "<think>"
+#     close_tag = "</think>"
+#     pos = 0
+#     while True:
+#         a = rendered.find(open_tag, pos)
+#         if a == -1:
+#             break
+#         b = rendered.find(close_tag, a + len(open_tag))
+#         if b == -1:
+#             break
+#         spans.append((a, b + len(close_tag)))
+#         pos = b + len(close_tag)
+#     return spans
+
 def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
     """
-    Plain str.find implementation, no regex.
-    Returns the global list of <think>…</think> spans, as absolute positions in rendered.
+    Returns the spans whose supervision should be ignored (only the interior of <think>...</think>);
+    the <think> and </think> tags themselves stay supervised so the model learns to close them.
     """
     spans: List[Tuple[int, int]] = []
     open_tag = "<think>"
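The new docstring states the supervision policy: text inside <think>…</think> is excluded from the loss, while the tags themselves remain learnable (the body change that actually produces interior-only spans is in the next hunk). As a rough illustration of how such character spans are usually turned into ignored labels, here is a sketch that assumes a fast tokenizer with offset_mapping and -100 as the ignore index; the dataset's real masking loop is outside this diff:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/home/test/Qwen3-32B")
rendered = "<|im_start|>assistant\n<think>draft a plan</think>Final reply.<|im_end|>\n"
enc = tok(rendered, return_offsets_mapping=True, add_special_tokens=False)
input_ids, offsets = enc["input_ids"], enc["offset_mapping"]

interior = (rendered.find("<think>") + len("<think>"), rendered.find("</think>"))
labels = list(input_ids)
for i, (cs, ce) in enumerate(offsets):
    if cs < interior[1] and ce > interior[0]:  # token overlaps the think interior
        labels[i] = -100                       # drop it from the loss; tag tokens keep their labels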
@@ -649,11 +669,13 @@ def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
         b = rendered.find(close_tag, a + len(open_tag))
         if b == -1:
             break
-        spans.append((a, b + len(close_tag)))
+        # Ignore only the interior, not the surrounding tags
+        spans.append((a + len(open_tag), b))
         pos = b + len(close_tag)
     return spans


 # ----------------- Dataset that supervises only assistant turns (ignoring <think>…</think>) -----------------
 class QwenChatSFTDataset(IterableDataset):
     """
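The behavioural difference is easiest to see on a small string: the old helper returned spans that include the tags, the new one returns only the interior:

s = "<think>plan the answer</think>Final answer."
a = s.find("<think>")
b = s.find("</think>")
print((a, b + len("</think>")))   # (0, 30)  old span: tags included
print((a + len("<think>"), b))    # (7, 22)  new span: interior only
print(s[7:22])                    # 'plan the answer'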
@@ -710,15 +732,17 @@ class QwenChatSFTDataset(IterableDataset):
         # -- Sample-level terminator: make sure every training sample ends with eos --
         # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
         #     rendered += self.tok.eos_token

         # -- Sample-level terminator: insert eos right before the last assistant <|im_end|>\n --
-        if self.tok.eos_token:
-            open_tag = "<|im_start|>assistant\n"
-            close_tag = "<|im_end|>\n"
-            head, sep, tail = rendered.rpartition(close_tag)  # split at the last close_tag
-            if sep:  # the closing marker was found
-                # only insert eos if the assistant text does not already end with it, to avoid duplicates
-                if not head.endswith(self.tok.eos_token):
-                    rendered = head + self.tok.eos_token + sep + tail
+        # if self.tok.eos_token:
+        #     open_tag = "<|im_start|>assistant\n"
+        #     close_tag = "<|im_end|>\n"
+        #     head, sep, tail = rendered.rpartition(close_tag)  # split at the last close_tag
+        #     if sep:  # the closing marker was found
+        #         # only insert eos if the assistant text does not already end with it, to avoid duplicates
+        #         if not head.endswith(self.tok.eos_token):
+        #             rendered = head + self.tok.eos_token + sep + tail

         # 1) Find all assistant spans & the global think spans
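This hunk disables the in-string eos splice, leaving the chat-template output untouched. For reference, the disabled block used str.rpartition to cut at the last <|im_end|>\n and place eos just inside the final assistant turn; a toy run, with <|endoftext|> assumed as the eos string purely for illustration:

rendered = "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\nhello<|im_end|>\n"
eos = "<|endoftext|>"  # assumed eos_token, for illustration only
head, sep, tail = rendered.rpartition("<|im_end|>\n")
if sep and not head.endswith(eos):
    rendered = head + eos + sep + tail
print(rendered.endswith("hello" + eos + "<|im_end|>\n"))  # True: eos lands inside the last assistant turn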
@@ -753,9 +777,49 @@ class QwenChatSFTDataset(IterableDataset):
                 labels[i] = input_ids[i]

         # -- Fixed-length policy: truncate keeping the tail, then pad on the left --
-        if len(input_ids) > self.seq_len:
-            input_ids = input_ids[-self.seq_len:]
-            labels = labels[-self.seq_len:]
+        # if len(input_ids) > self.seq_len:
+        #     input_ids = input_ids[-self.seq_len:]
+        #     labels = labels[-self.seq_len:]
+
+        # ======== Assistant-aware truncation: try to keep the last assistant segment intact ========
+        if len(input_ids) > self.seq_len:
+            # Character span of the last assistant segment ([s_last, e_last))
+            s_last, e_last = asst_spans[-1]
+
+            # Map the character span to a token index range [j, k_excl) via offsets
+            j = 0
+            while j < len(offsets) and offsets[j][1] <= s_last:
+                j += 1
+            k_excl = j
+            while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
+                k_excl += 1
+
+            A = max(0, k_excl - j)  # number of tokens covered by the last assistant segment
+
+            if A >= self.seq_len:
+                # The assistant segment alone exceeds the window -- keep its ending, do not cut the tail
+                start = max(0, k_excl - self.seq_len)
+                end = start + self.seq_len
+            else:
+                # There is room for the whole assistant segment: cover the full [j, k_excl)
+                start = max(0, min(j, len(input_ids) - self.seq_len))
+                end = start + self.seq_len
+                if end < k_excl:
+                    end = k_excl
+                    start = max(0, end - self.seq_len)
+
+                # Optional: roughly center the window (leave some context on both sides of the last assistant)
+                leftover = self.seq_len - A
+                left_wish = leftover // 2
+                start = max(0, min(j - left_wish, start))
+                end = start + self.seq_len
+                if end < k_excl:
+                    end = k_excl
+                    start = max(0, end - self.seq_len)
+
+            # Do the actual slicing
+            input_ids = input_ids[start:end]
+            labels = labels[start:end]

         pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
         L = len(input_ids)
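The core of the new truncation is the character-to-token mapping: j is the first token that ends after the assistant span starts, and k_excl is one past the last token that starts before it ends. A standalone toy run with invented offsets (not real tokenizer output):

offsets = [(0, 5), (5, 9), (9, 14), (14, 20)]  # (char_start, char_end) per token, made up for the example
s_last, e_last = 9, 20                         # pretend the last assistant turn spans chars [9, 20)

j = 0
while j < len(offsets) and offsets[j][1] <= s_last:
    j += 1
k_excl = j
while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
    k_excl += 1

print(j, k_excl)  # 2 4 -> tokens 2 and 3 carry the assistant text, so the window is placed around them

After slicing, the sample is at most seq_len tokens long, and the left-padding mentioned in the old comment (pad_id on the left, presumably with ignored labels) brings it up to the fixed length; that code sits below this hunk.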