diff --git a/mm-zero3.sh b/mm-zero3.sh
index a68ebc5..b58c261 100755
--- a/mm-zero3.sh
+++ b/mm-zero3.sh
@@ -31,7 +31,7 @@ deepspeed --hostfile hostfile \
--model_name_or_path /home/test/Qwen3-32B \
--data_glob "/home/test/datasets/my_corpus/train.jsonl" \
--output_dir /home/test/checkpoints/q3-32b-ds4 \
- --seq_len 512 \
+ --seq_len 1024 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \
diff --git a/train_sft_ds.py b/train_sft_ds.py
index fd1ece4..a4c07df 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -633,10 +633,30 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
# ----------------- 工具:提取所有 … 的字符区间(包含标签本身) -----------------
+# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
+# """
+# 纯 str.find 实现,不用正则。
+# 返回全局的 … 区间列表,坐标为 rendered 上的绝对位置。
+# """
+# spans: List[Tuple[int, int]] = []
+# open_tag = ""
+# close_tag = ""
+# pos = 0
+# while True:
+# a = rendered.find(open_tag, pos)
+# if a == -1:
+# break
+# b = rendered.find(close_tag, a + len(open_tag))
+# if b == -1:
+# break
+# spans.append((a, b + len(close_tag)))
+# pos = b + len(close_tag)
+# return spans
+
def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
"""
- 纯 str.find 实现,不用正则。
- 返回全局的 … 区间列表,坐标为 rendered 上的绝对位置。
+ 返回需要忽略监督的区间(仅 ... 的“内部”),
+ 标签本身 与 仍参与监督,以便模型学会闭合。
"""
spans: List[Tuple[int, int]] = []
open_tag = ""
@@ -649,11 +669,13 @@ def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
b = rendered.find(close_tag, a + len(open_tag))
if b == -1:
break
- spans.append((a, b + len(close_tag)))
+ # 只忽略内部,不忽略两侧标签
+ spans.append((a + len(open_tag), b))
pos = b + len(close_tag)
return spans
+
# ----------------- 仅监督 assistant 的数据集(忽略 …) -----------------
class QwenChatSFTDataset(IterableDataset):
"""
@@ -710,15 +732,17 @@ class QwenChatSFTDataset(IterableDataset):
# —— 样本级终止符:确保训练时每条样本以 eos 结束 ——
# if self.tok.eos_token and not rendered.endswith(self.tok.eos_token):
# rendered += self.tok.eos_token
+
+
# —— 样本级终止符:把 eos 插入到最后一个 assistant 的 <|im_end|>\n 之前 ——
- if self.tok.eos_token:
- open_tag = "<|im_start|>assistant\n"
- close_tag = "<|im_end|>\n"
- head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切
- if sep: # 找到了收尾标记
- # 仅当 assistant 文本末尾还没有 eos 才插入,避免重复
- if not head.endswith(self.tok.eos_token):
- rendered = head + self.tok.eos_token + sep + tail
+ # if self.tok.eos_token:
+ # open_tag = "<|im_start|>assistant\n"
+ # close_tag = "<|im_end|>\n"
+ # head, sep, tail = rendered.rpartition(close_tag) # 按最后一个 close_tag 切
+ # if sep: # 找到了收尾标记
+ # # 仅当 assistant 文本末尾还没有 eos 才插入,避免重复
+ # if not head.endswith(self.tok.eos_token):
+ # rendered = head + self.tok.eos_token + sep + tail
# 1) 找到所有 assistant 区间 & 全局 think 区间
@@ -753,9 +777,49 @@ class QwenChatSFTDataset(IterableDataset):
labels[i] = input_ids[i]
# —— 固定长度策略:先截尾(保留尾部),再左侧补齐 ——
+ # if len(input_ids) > self.seq_len:
+ # input_ids = input_ids[-self.seq_len:]
+ # labels = labels[-self.seq_len:]
+
+ # ======== 助手感知的截断策略:尽量保证“最后一个 assistant 片段”完整 ========
if len(input_ids) > self.seq_len:
- input_ids = input_ids[-self.seq_len:]
- labels = labels[-self.seq_len:]
+ # 取最后一个 assistant 的字符区间([s_last, e_last))
+ s_last, e_last = asst_spans[-1]
+
+ # 用 offsets 把字符区间映射到 token 索引区间 [j, k_excl)
+ j = 0
+ while j < len(offsets) and offsets[j][1] <= s_last:
+ j += 1
+ k_excl = j
+ while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
+ k_excl += 1
+
+ A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数
+
+ if A >= self.seq_len:
+ # 单个 assistant 本身超过窗口 —— 保“结尾”,避免切尾
+ start = max(0, k_excl - self.seq_len)
+ end = start + self.seq_len
+ else:
+ # 有空间容纳整个 assistant:让窗口覆盖完整 [j, k_excl)
+ start = max(0, min(j, len(input_ids) - self.seq_len))
+ end = start + self.seq_len
+ if end < k_excl:
+ end = k_excl
+ start = max(0, end - self.seq_len)
+
+ # 可选:尝试“居中”一点(给最后一个 assistant 左右留些上下文)
+ leftover = self.seq_len - A
+ left_wish = leftover // 2
+ start = max(0, min(j - left_wish, start))
+ end = start + self.seq_len
+ if end < k_excl:
+ end = k_excl
+ start = max(0, end - self.seq_len)
+
+ # 真正切片
+ input_ids = input_ids[start:end]
+ labels = labels[start:end]
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(input_ids)