hailin 2025-09-09 23:37:15 +08:00
parent c3dfd25132
commit b343ac5529
1 changed file with 19 additions and 0 deletions

@@ -329,6 +329,11 @@ class SFTDataCollator:
            batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
            batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
        # ensure this batch has supervised tokens
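        # if every label is -100, CrossEntropyLoss(ignore_index=-100) has no valid targets and the mean loss becomes NaN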
        has_sup = any((lab != -100).any().item() for lab in batch_lab)
        if not has_sup:
            raise RuntimeError("batch has zero supervised tokens; check masking or dataset.")
        return {
            "input_ids": torch.stack(batch_inp, dim=0),
            "attention_mask": torch.stack(batch_attn, dim=0),
@@ -604,10 +609,24 @@ def main():
    )
    model = get_peft_model(model, lora_cfg)
    try:
        model.print_trainable_parameters()
    except Exception:
        pass
    # 3) configure gradient checkpointing again (calling it after LoRA injection is more reliable)
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
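        # use_reentrant=False selects the non-reentrant checkpointing path that recent PyTorch versions recommend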
        # key: make the inputs require gradients so they stay compatible with checkpointing
        try:
            model.enable_input_require_grads()
        except AttributeError:
            # fallback for older transformers: mark the embedding output as requiring grad
            emb = model.get_input_embeddings()
            if hasattr(emb, "register_forward_hook"):
                emb.register_forward_hook(lambda m, inp, out: out.requires_grad_(True))
    # 4) print the fraction of trainable parameters
    try:
        from peft import get_peft_model_state_dict