This commit is contained in:
hailin 2025-08-26 08:43:15 +08:00
parent a1934ce954
commit e9e7626ae7
1 changed file with 26 additions and 15 deletions

View File

@ -95,7 +95,7 @@ class QwenChatSFTDataset(IterableDataset):
if not msgs or not isinstance(msgs, list):
continue
# 可选过滤掉带有非空 <think>…</think> 的样本(避免训练真实 COT)
# 可选过滤 think
bad = False
for m in msgs:
if m.get("role") == "assistant" and isinstance(m.get("content"), str):
@ -109,22 +109,16 @@ class QwenChatSFTDataset(IterableDataset):
tools = ex.get("tools", None)
# 1) 按模型自带模板渲染
rendered: str = self.tok.apply_chat_template(
msgs,
tools=tools,
add_generation_prompt=False,
tokenize=False
msgs, tools=tools, add_generation_prompt=False, tokenize=False
)
if not isinstance(rendered, str) or not rendered.strip():
continue
# 2) 找出 assistant 片段的字符区间
spans = _assistant_char_spans(rendered)
if not spans:
continue
# 3) 分词 + 字符/Token 对齐
enc = self.tok(
rendered,
add_special_tokens=False,
@ -133,11 +127,9 @@ class QwenChatSFTDataset(IterableDataset):
input_ids: List[int] = enc["input_ids"]
offsets: List[Tuple[int, int]] = enc["offset_mapping"]
# 空样本防御:分词后长度为 0
if not input_ids:
continue
# 4) 仅 assistant 计损失
labels = [-100] * len(input_ids)
def in_any_span(lo: int, hi: int) -> bool:
@ -150,22 +142,40 @@ class QwenChatSFTDataset(IterableDataset):
if in_any_span(lo, hi):
labels[i] = input_ids[i]
# 5) 超长裁剪(保留尾部)
# —— 固定长度策略:先截尾,再在 Dataset 层补到固定 seq_len ——
# 1) 截断到 seq_len(保留尾部)
if len(input_ids) > self.seq_len:
input_ids = input_ids[-self.seq_len:]
labels = labels[-self.seq_len:]
# 若没有任何可训练 token(labels 全 -100)也跳过
# 2) 左侧补齐到 seq_len(保证所有样本长度一致)
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(input_ids)
if L < self.seq_len:
pad = self.seq_len - L
input_ids = ([pad_id] * pad) + input_ids
labels = ([-100] * pad) + labels
attn_mask = [0] * pad + [1] * L
else:
# 恰好等于 seq_len
attn_mask = [1] * self.seq_len
# 若没有任何可训练 token(labels 全 -100)跳过
if all(v == -100 for v in labels):
continue
assert len(input_ids) == self.seq_len
assert len(labels) == self.seq_len
assert len(attn_mask) == self.seq_len
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.ones(len(input_ids), dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long)
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
# ----------------- 专用 Collator(pad inputs, pad labels=-100) -----------------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
@ -387,7 +397,8 @@ def main():
eval_dataset = ListDataset(eval_samples)
# 更稳:联调阶段不强行 pad 到 4096
data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
# data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")