This commit is contained in:
hailin 2025-09-08 19:43:48 +08:00
parent c1eef90d1c
commit 8b0d8e0c5e
2 changed files with 78 additions and 14 deletions

View File

@ -31,7 +31,7 @@ deepspeed --hostfile hostfile \
--model_name_or_path /home/test/Qwen3-32B \ --model_name_or_path /home/test/Qwen3-32B \
--data_glob "/home/test/datasets/my_corpus/train.jsonl" \ --data_glob "/home/test/datasets/my_corpus/train.jsonl" \
--output_dir /home/test/checkpoints/q3-32b-ds4 \ --output_dir /home/test/checkpoints/q3-32b-ds4 \
--seq_len 512 \ --seq_len 1024 \
--per_device_train_batch_size 1 \ --per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \ --gradient_accumulation_steps 1 \
--learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \ --learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \

View File

@ -633,10 +633,30 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
# ----------------- 工具:提取所有 <think>…</think> 的字符区间(包含标签本身) ----------------- # ----------------- 工具:提取所有 <think>…</think> 的字符区间(包含标签本身) -----------------
# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
# """
# 纯 str.find 实现,不用正则。
# 返回全局的 <think>…</think> 区间列表,坐标为 rendered 上的绝对位置。
# """
# spans: List[Tuple[int, int]] = []
# open_tag = "<think>"
# close_tag = "</think>"
# pos = 0
# while True:
# a = rendered.find(open_tag, pos)
# if a == -1:
# break
# b = rendered.find(close_tag, a + len(open_tag))
# if b == -1:
# break
# spans.append((a, b + len(close_tag)))
# pos = b + len(close_tag)
# return spans
def _think_char_spans(rendered: str) -> List[Tuple[int, int]]: def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
""" """
str.find 实现不用正则 返回需要忽略监督的区间 <think>...</think> 内部
返回全局的 <think></think> 区间列表坐标为 rendered 上的绝对位置 标签本身 <think> </think> 仍参与监督以便模型学会闭合
""" """
spans: List[Tuple[int, int]] = [] spans: List[Tuple[int, int]] = []
open_tag = "<think>" open_tag = "<think>"
@ -649,11 +669,13 @@ def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
b = rendered.find(close_tag, a + len(open_tag)) b = rendered.find(close_tag, a + len(open_tag))
if b == -1: if b == -1:
break break
spans.append((a, b + len(close_tag))) # 只忽略内部,不忽略两侧标签
spans.append((a + len(open_tag), b))
pos = b + len(close_tag) pos = b + len(close_tag)
return spans return spans
# ----------------- 仅监督 assistant 的数据集(忽略 <think>…</think> ----------------- # ----------------- 仅监督 assistant 的数据集(忽略 <think>…</think> -----------------
class QwenChatSFTDataset(IterableDataset): class QwenChatSFTDataset(IterableDataset):
""" """
@ -710,15 +732,17 @@ class QwenChatSFTDataset(IterableDataset):
# —— 样本级终止符:确保训练时每条样本以 eos 结束 —— # —— 样本级终止符:确保训练时每条样本以 eos 结束 ——
# if self.tok.eos_token and not rendered.endswith(self.tok.eos_token): # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
# rendered += self.tok.eos_token # rendered += self.tok.eos_token
# —— 样本级终止符:把 eos 插入到最后一个 assistant 的 <|im_end|>\n 之前 —— # —— 样本级终止符:把 eos 插入到最后一个 assistant 的 <|im_end|>\n 之前 ——
if self.tok.eos_token: # if self.tok.eos_token:
open_tag = "<|im_start|>assistant\n" # open_tag = "<|im_start|>assistant\n"
close_tag = "<|im_end|>\n" # close_tag = "<|im_end|>\n"
head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切 # head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切
if sep: # 找到了收尾标记 # if sep: # 找到了收尾标记
# 仅当 assistant 文本末尾还没有 eos 才插入,避免重复 # # 仅当 assistant 文本末尾还没有 eos 才插入,避免重复
if not head.endswith(self.tok.eos_token): # if not head.endswith(self.tok.eos_token):
rendered = head + self.tok.eos_token + sep + tail # rendered = head + self.tok.eos_token + sep + tail
# 1) 找到所有 assistant 区间 & 全局 think 区间 # 1) 找到所有 assistant 区间 & 全局 think 区间
@ -753,9 +777,49 @@ class QwenChatSFTDataset(IterableDataset):
labels[i] = input_ids[i] labels[i] = input_ids[i]
# —— 固定长度策略:先截尾(保留尾部),再左侧补齐 —— # —— 固定长度策略:先截尾(保留尾部),再左侧补齐 ——
# if len(input_ids) > self.seq_len:
# input_ids = input_ids[-self.seq_len:]
# labels = labels[-self.seq_len:]
# ======== 助手感知的截断策略:尽量保证“最后一个 assistant 片段”完整 ========
if len(input_ids) > self.seq_len: if len(input_ids) > self.seq_len:
input_ids = input_ids[-self.seq_len:] # 取最后一个 assistant 的字符区间([s_last, e_last)
labels = labels[-self.seq_len:] s_last, e_last = asst_spans[-1]
# 用 offsets 把字符区间映射到 token 索引区间 [j, k_excl)
j = 0
while j < len(offsets) and offsets[j][1] <= s_last:
j += 1
k_excl = j
while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
k_excl += 1
A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数
if A >= self.seq_len:
# 单个 assistant 本身超过窗口 —— 保“结尾”,避免切尾
start = max(0, k_excl - self.seq_len)
end = start + self.seq_len
else:
# 有空间容纳整个 assistant让窗口覆盖完整 [j, k_excl)
start = max(0, min(j, len(input_ids) - self.seq_len))
end = start + self.seq_len
if end < k_excl:
end = k_excl
start = max(0, end - self.seq_len)
# 可选:尝试“居中”一点(给最后一个 assistant 左右留些上下文)
leftover = self.seq_len - A
left_wish = leftover // 2
start = max(0, min(j - left_wish, start))
end = start + self.seq_len
if end < k_excl:
end = k_excl
start = max(0, end - self.seq_len)
# 真正切片
input_ids = input_ids[start:end]
labels = labels[start:end]
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(input_ids) L = len(input_ids)