From 8b0d8e0c5e049d8c5ff8e533df976a3fb6710980 Mon Sep 17 00:00:00 2001 From: hailin Date: Mon, 8 Sep 2025 19:43:48 +0800 Subject: [PATCH] . --- mm-zero3.sh | 2 +- train_sft_ds.py | 90 ++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 78 insertions(+), 14 deletions(-) diff --git a/mm-zero3.sh b/mm-zero3.sh index a68ebc5..b58c261 100755 --- a/mm-zero3.sh +++ b/mm-zero3.sh @@ -31,7 +31,7 @@ deepspeed --hostfile hostfile \ --model_name_or_path /home/test/Qwen3-32B \ --data_glob "/home/test/datasets/my_corpus/train.jsonl" \ --output_dir /home/test/checkpoints/q3-32b-ds4 \ - --seq_len 512 \ + --seq_len 1024 \ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 1 \ --learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \ diff --git a/train_sft_ds.py b/train_sft_ds.py index fd1ece4..a4c07df 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -633,10 +633,30 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: # ----------------- 工具:提取所有 的字符区间(包含标签本身) ----------------- +# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]: +# """ +# 纯 str.find 实现,不用正则。 +# 返回全局的 区间列表,坐标为 rendered 上的绝对位置。 +# """ +# spans: List[Tuple[int, int]] = [] +# open_tag = "" +# close_tag = "" +# pos = 0 +# while True: +# a = rendered.find(open_tag, pos) +# if a == -1: +# break +# b = rendered.find(close_tag, a + len(open_tag)) +# if b == -1: +# break +# spans.append((a, b + len(close_tag))) +# pos = b + len(close_tag) +# return spans + def _think_char_spans(rendered: str) -> List[Tuple[int, int]]: """ - 纯 str.find 实现,不用正则。 - 返回全局的 区间列表,坐标为 rendered 上的绝对位置。 + 返回需要忽略监督的区间(仅 ... 的“内部”), + 标签本身 仍参与监督,以便模型学会闭合。 """ spans: List[Tuple[int, int]] = [] open_tag = "" @@ -649,11 +669,13 @@ def _think_char_spans(rendered: str) -> List[Tuple[int, int]]: b = rendered.find(close_tag, a + len(open_tag)) if b == -1: break - spans.append((a, b + len(close_tag))) + # 只忽略内部,不忽略两侧标签 + spans.append((a + len(open_tag), b)) pos = b + len(close_tag) return spans + # ----------------- 仅监督 assistant 的数据集(忽略 ) ----------------- class QwenChatSFTDataset(IterableDataset): """ @@ -710,15 +732,17 @@ class QwenChatSFTDataset(IterableDataset): # —— 样本级终止符:确保训练时每条样本以 eos 结束 —— # if self.tok.eos_token and not rendered.endswith(self.tok.eos_token): # rendered += self.tok.eos_token + + # —— 样本级终止符:把 eos 插入到最后一个 assistant 的 <|im_end|>\n 之前 —— - if self.tok.eos_token: - open_tag = "<|im_start|>assistant\n" - close_tag = "<|im_end|>\n" - head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切 - if sep: # 找到了收尾标记 - # 仅当 assistant 文本末尾还没有 eos 才插入,避免重复 - if not head.endswith(self.tok.eos_token): - rendered = head + self.tok.eos_token + sep + tail + # if self.tok.eos_token: + # open_tag = "<|im_start|>assistant\n" + # close_tag = "<|im_end|>\n" + # head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切 + # if sep: # 找到了收尾标记 + # # 仅当 assistant 文本末尾还没有 eos 才插入,避免重复 + # if not head.endswith(self.tok.eos_token): + # rendered = head + self.tok.eos_token + sep + tail # 1) 找到所有 assistant 区间 & 全局 think 区间 @@ -753,9 +777,49 @@ class QwenChatSFTDataset(IterableDataset): labels[i] = input_ids[i] # —— 固定长度策略:先截尾(保留尾部),再左侧补齐 —— + # if len(input_ids) > self.seq_len: + # input_ids = input_ids[-self.seq_len:] + # labels = labels[-self.seq_len:] + + # ======== 助手感知的截断策略:尽量保证“最后一个 assistant 片段”完整 ======== if len(input_ids) > self.seq_len: - input_ids = input_ids[-self.seq_len:] - labels = labels[-self.seq_len:] + # 取最后一个 assistant 的字符区间([s_last, e_last)) + s_last, e_last = asst_spans[-1] + + # 用 offsets 把字符区间映射到 token 索引区间 [j, k_excl) + j = 0 + while j < len(offsets) and offsets[j][1] <= s_last: + j += 1 + k_excl = j + while k_excl < len(offsets) and offsets[k_excl][0] < e_last: + k_excl += 1 + + A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数 + + if A >= self.seq_len: + # 单个 assistant 本身超过窗口 —— 保“结尾”,避免切尾 + start = max(0, k_excl - self.seq_len) + end = start + self.seq_len + else: + # 有空间容纳整个 assistant:让窗口覆盖完整 [j, k_excl) + start = max(0, min(j, len(input_ids) - self.seq_len)) + end = start + self.seq_len + if end < k_excl: + end = k_excl + start = max(0, end - self.seq_len) + + # 可选:尝试“居中”一点(给最后一个 assistant 左右留些上下文) + leftover = self.seq_len - A + left_wish = leftover // 2 + start = max(0, min(j - left_wish, start)) + end = start + self.seq_len + if end < k_excl: + end = k_excl + start = max(0, end - self.seq_len) + + # 真正切片 + input_ids = input_ids[start:end] + labels = labels[start:end] pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id L = len(input_ids)