parent a1934ce954
commit e9e7626ae7
@@ -95,7 +95,7 @@ class QwenChatSFTDataset(IterableDataset):
             if not msgs or not isinstance(msgs, list):
                 continue
 
-            # Optional: filter out samples with a non-empty <think>…</think> (to avoid training on real CoT)
+            # Optional think filtering
             bad = False
             for m in msgs:
                 if m.get("role") == "assistant" and isinstance(m.get("content"), str):
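Note: the hunk boundary cuts off the actual <think> check that sets bad = True; a minimal sketch of the test the comment implies (the regex and helper name below are illustrative, not from this commit):

import re

_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)

def has_nonempty_think(text: str) -> bool:
    # True if any <think>…</think> block contains non-whitespace content
    return any(m.group(1).strip() for m in _THINK_RE.finditer(text))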
@@ -109,22 +109,16 @@ class QwenChatSFTDataset(IterableDataset):
             tools = ex.get("tools", None)
 
             # 1) Render with the model's built-in chat template
             rendered: str = self.tok.apply_chat_template(
-                msgs,
-                tools=tools,
-                add_generation_prompt=False,
-                tokenize=False
+                msgs, tools=tools, add_generation_prompt=False, tokenize=False
             )
             if not isinstance(rendered, str) or not rendered.strip():
                 continue
 
             # 2) Locate the character spans of the assistant segments
             spans = _assistant_char_spans(rendered)
             if not spans:
                 continue
 
             # 3) Tokenize + character/token alignment
             enc = self.tok(
                 rendered,
                 add_special_tokens=False,
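The helper _assistant_char_spans is called here but not defined in this diff; a minimal sketch of such a helper, assuming the rendered text uses ChatML-style markers (<|im_start|>assistant ... <|im_end|>) as Qwen chat templates typically do (the repo's real implementation may differ):

import re
from typing import List, Tuple

_ASSISTANT_RE = re.compile(r"<\|im_start\|>assistant\n(.*?)<\|im_end\|>", re.DOTALL)

def assistant_char_spans_sketch(rendered: str) -> List[Tuple[int, int]]:
    # [start, end) character ranges covering the body of each assistant turn
    return [(m.start(1), m.end(1)) for m in _ASSISTANT_RE.finditer(rendered)]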
@@ -133,11 +127,9 @@ class QwenChatSFTDataset(IterableDataset):
             input_ids: List[int] = enc["input_ids"]
             offsets: List[Tuple[int, int]] = enc["offset_mapping"]
 
             # Guard against empty samples: tokenized length is 0
             if not input_ids:
                 continue
 
             # 4) Compute loss on assistant tokens only
             labels = [-100] * len(input_ids)
 
             def in_any_span(lo: int, hi: int) -> bool:
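The hunk above initializes labels to -100 and declares in_any_span; the next hunk then restores labels for tokens that overlap an assistant span. A standalone sketch of that masking step (the overlap test here is one common choice; the body of in_any_span is not visible in this diff):

from typing import List, Tuple

def mask_non_assistant(input_ids: List[int],
                       offsets: List[Tuple[int, int]],
                       spans: List[Tuple[int, int]]) -> List[int]:
    # Keep loss only on tokens whose character range overlaps an assistant span.
    labels = [-100] * len(input_ids)

    def in_any_span(lo: int, hi: int) -> bool:
        # half-open interval overlap; assumed, since the commit's own rule is not shown
        return any(lo < e and hi > s for (s, e) in spans)

    for i, (lo, hi) in enumerate(offsets):
        if hi > lo and in_any_span(lo, hi):
            labels[i] = input_ids[i]
    return labels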
@@ -150,22 +142,40 @@ class QwenChatSFTDataset(IterableDataset):
                 if in_any_span(lo, hi):
                     labels[i] = input_ids[i]
 
-            # 5) Truncate over-long samples (keep the tail)
+            # ---- Fixed-length strategy: truncate first (keep the tail), then pad to a fixed seq_len at the Dataset level ----
+            # 1) Truncate to seq_len (keep the tail)
             if len(input_ids) > self.seq_len:
                 input_ids = input_ids[-self.seq_len:]
                 labels = labels[-self.seq_len:]
 
-            # Also skip if there are no trainable tokens (all labels are -100)
+            # 2) Left-pad to seq_len (so every sample has the same length)
+            pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
+            L = len(input_ids)
+            if L < self.seq_len:
+                pad = self.seq_len - L
+                input_ids = ([pad_id] * pad) + input_ids
+                labels = ([-100] * pad) + labels
+                attn_mask = [0] * pad + [1] * L
+            else:
+                # Exactly seq_len
+                attn_mask = [1] * self.seq_len
+
+            # Skip if there are no trainable tokens (all labels are -100)
             if all(v == -100 for v in labels):
                 continue
+
+            assert len(input_ids) == self.seq_len
+            assert len(labels) == self.seq_len
+            assert len(attn_mask) == self.seq_len
 
             yield {
                 "input_ids": torch.tensor(input_ids, dtype=torch.long),
-                "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
-                "labels": torch.tensor(labels, dtype=torch.long)
+                "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
+                "labels": torch.tensor(labels, dtype=torch.long),
             }
 
 
 # ----------------- Dedicated collator: pad inputs, pad labels=-100 -----------------
 class SFTDataCollator:
     def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
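Quick sanity check of the truncate-then-left-pad logic above, using toy values that are not from the commit:

seq_len, pad_id = 8, 0            # toy values
input_ids = [11, 12, 13]
labels    = [-100, 12, 13]

pad = seq_len - len(input_ids)                 # 5
input_ids = [pad_id] * pad + input_ids         # [0, 0, 0, 0, 0, 11, 12, 13]
labels    = [-100] * pad + labels              # [-100]*6 + [12, 13]
attn_mask = [0] * pad + [1] * 3                # [0, 0, 0, 0, 0, 1, 1, 1]
assert len(input_ids) == len(labels) == len(attn_mask) == seq_len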
@@ -387,7 +397,8 @@ def main():
     eval_dataset = ListDataset(eval_samples)
 
-    # More stable: don't force padding to 4096 during the integration-testing phase
-    data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
+    # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
+    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
 
     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
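The body of SFTDataCollator is not part of this diff; a minimal sketch that is consistent with its header comment (pad inputs, pad labels=-100) and the pad_to_length parameter, with every detail below assumed rather than taken from the repo:

import torch
from typing import Dict, List, Optional
from transformers import AutoTokenizer

class SFTDataCollatorSketch:
    # Illustrative stand-in: right-pad input_ids with the pad id, attention_mask with 0,
    # and labels with -100, either to the longest sequence in the batch or to pad_to_length.
    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        self.pad_to_length = pad_to_length

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        target = self.pad_to_length or max(len(f["input_ids"]) for f in features)
        out = {"input_ids": [], "attention_mask": [], "labels": []}
        for f in features:
            n = target - len(f["input_ids"])
            out["input_ids"].append(torch.cat([f["input_ids"], torch.full((n,), self.pad_id, dtype=torch.long)]))
            out["attention_mask"].append(torch.cat([f["attention_mask"], torch.zeros(n, dtype=torch.long)]))
            out["labels"].append(torch.cat([f["labels"], torch.full((n,), -100, dtype=torch.long)]))
        return {k: torch.stack(v) for k, v in out.items()}

With pad_to_length=args.seq_len and the dataset already emitting fixed seq_len tensors, such a collator effectively only stacks; the padding branch matters when pad_to_length=None and samples vary in length.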