From e9e7626ae7ff7c2265e5ca588da492a308b633a5 Mon Sep 17 00:00:00 2001
From: hailin
Date: Tue, 26 Aug 2025 08:43:15 +0800
Subject: [PATCH] train_sft_ds: pad every sample to a fixed seq_len

---
 train_sft_ds.py | 41 ++++++++++++++++++++++++++---------------
 1 file changed, 26 insertions(+), 15 deletions(-)

diff --git a/train_sft_ds.py b/train_sft_ds.py
index b395c07..3753cb9 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -95,7 +95,7 @@ class QwenChatSFTDataset(IterableDataset):
             if not msgs or not isinstance(msgs, list):
                 continue
 
-            # Optional: filter out samples with non-empty think content (avoid training on real CoT)
+            # Optional think filter
             bad = False
             for m in msgs:
                 if m.get("role") == "assistant" and isinstance(m.get("content"), str):
@@ -109,22 +109,16 @@ class QwenChatSFTDataset(IterableDataset):
 
             tools = ex.get("tools", None)
 
-            # 1) Render with the model's built-in chat template
             rendered: str = self.tok.apply_chat_template(
-                msgs,
-                tools=tools,
-                add_generation_prompt=False,
-                tokenize=False
+                msgs, tools=tools, add_generation_prompt=False, tokenize=False
             )
             if not isinstance(rendered, str) or not rendered.strip():
                 continue
 
-            # 2) Locate the character spans of the assistant segments
             spans = _assistant_char_spans(rendered)
             if not spans:
                 continue
 
-            # 3) Tokenize with character/token alignment
             enc = self.tok(
                 rendered,
                 add_special_tokens=False,
@@ -133,11 +127,9 @@ class QwenChatSFTDataset(IterableDataset):
             input_ids: List[int] = enc["input_ids"]
             offsets: List[Tuple[int, int]] = enc["offset_mapping"]
 
-            # Guard against empty samples: zero tokens after tokenization
             if not input_ids:
                 continue
 
-            # 4) Compute loss on assistant tokens only
             labels = [-100] * len(input_ids)
 
             def in_any_span(lo: int, hi: int) -> bool:
@@ -150,22 +142,40 @@ class QwenChatSFTDataset(IterableDataset):
                 if in_any_span(lo, hi):
                     labels[i] = input_ids[i]
 
-            # 5) Truncate overlong sequences (keep the tail)
+            # -- Fixed-length strategy: truncate the tail first, then pad to a fixed seq_len at the Dataset level --
+            # 1) Truncate to seq_len (keep the tail)
             if len(input_ids) > self.seq_len:
                 input_ids = input_ids[-self.seq_len:]
                 labels = labels[-self.seq_len:]
 
-            # Also skip samples with no trainable tokens (labels are all -100)
+            # 2) Left-pad to seq_len (so every sample has the same length)
+            pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
+            L = len(input_ids)
+            if L < self.seq_len:
+                pad = self.seq_len - L
+                input_ids = ([pad_id] * pad) + input_ids
+                labels = ([-100] * pad) + labels
+                attn_mask = [0] * pad + [1] * L
+            else:
+                # Exactly seq_len
+                attn_mask = [1] * self.seq_len
+
+            # Skip samples with no trainable tokens (labels are all -100)
             if all(v == -100 for v in labels):
                 continue
 
+            assert len(input_ids) == self.seq_len
+            assert len(labels) == self.seq_len
+            assert len(attn_mask) == self.seq_len
+
             yield {
                 "input_ids": torch.tensor(input_ids, dtype=torch.long),
-                "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
-                "labels": torch.tensor(labels, dtype=torch.long)
+                "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
+                "labels": torch.tensor(labels, dtype=torch.long),
             }
 
+
 
 # ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
 class SFTDataCollator:
     def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
@@ -387,7 +397,8 @@ def main():
     eval_dataset = ListDataset(eval_samples)
 
     # More stable: do not force-pad to 4096 while debugging the pipeline
-    data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
+    # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
+    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
 
     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
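The think filter in the @@ -95 hunk sets `bad` inside a loop whose body falls outside the hunk context, so the patch never shows the actual check. A minimal sketch of one plausible shape for it, assuming <think>/</think> markers and a hypothetical helper name (neither appears in the patch):

import re

_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)

def has_real_cot(content: str) -> bool:
    # True when an assistant turn carries a non-empty think block; these
    # are the samples the dataset optionally drops, so real chain-of-thought
    # text never enters the SFT loss.
    m = _THINK_RE.search(content)
    return bool(m and m.group(1).strip())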
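_assistant_char_spans is called in the @@ -109 hunk but not defined anywhere in this patch. A minimal sketch, assuming Qwen's ChatML-style rendering where each assistant turn appears as "<|im_start|>assistant\n...<|im_end|>" (the markers are an assumption based on the Qwen naming, not taken from the patch):

import re
from typing import List, Tuple

_ASSISTANT_BLOCK_RE = re.compile(
    r"<\|im_start\|>assistant\n(.*?)<\|im_end\|>", re.DOTALL
)

def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
    # (start, end) character offsets of each assistant message body,
    # excluding the role header and the end-of-turn marker, so only the
    # reply text itself is eligible for loss.
    return [m.span(1) for m in _ASSISTANT_BLOCK_RE.finditer(rendered)]

Whether the end-of-turn marker should also receive loss (so the model learns to stop) is a separate choice this sketch leaves out.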
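The character/token alignment in the @@ -133 hunk can be exercised standalone. In the sketch below the checkpoint name is an assumption, the span is hard-coded to where "Hello!" lands in the rendered string, and the containment test only mirrors what in_any_span plausibly does (its body also sits outside the hunk context):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # assumed checkpoint
rendered = "<|im_start|>assistant\nHello!<|im_end|>"
spans = [(22, 28)]  # character span of "Hello!" in `rendered`

# Fast tokenizers return one (char_start, char_end) pair per token.
enc = tok(rendered, add_special_tokens=False, return_offsets_mapping=True)
labels = [
    tid if any(lo >= s and hi <= e for s, e in spans) else -100
    for tid, (lo, hi) in zip(enc["input_ids"], enc["offset_mapping"])
]
# Only tokens fully inside the assistant span keep their ids; the rest stay -100.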
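On the fixed-length strategy in the @@ -150 hunk: input_ids[-self.seq_len:] drops the front and keeps the tail, where the assistant reply sits, and left-padding then keeps that tail flush against the end of the window; pad positions carry attention_mask 0 and label -100, so they contribute nothing to attention or loss. Once every sample is exactly seq_len, batching reduces to a plain stack, which is also what makes pad_to_length=args.seq_len in the collator safe rather than wasteful. A sketch of that invariant, with a hypothetical helper name:

import torch

def stack_batch(features: list) -> dict:
    # Works only because the dataset already guarantees len == seq_len
    # for input_ids, attention_mask, and labels alike; no per-batch
    # padding decision remains.
    return {k: torch.stack([f[k] for f in features]) for k in features[0]}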