commit b343ac5529
parent c3dfd25132
@@ -329,6 +329,11 @@ class SFTDataCollator:
             batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
             batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
 
+        # ensure this batch has supervised tokens
+        has_sup = any((lab != -100).any().item() for lab in batch_lab)
+        if not has_sup:
+            raise RuntimeError("batch has zero supervised tokens; check masking or dataset.")
+
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
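A minimal standalone sketch (toy tensor shapes, not the actual collator output) of why this guard is worth raising on: when every label in a batch is -100, the mean cross-entropy averages over zero tokens and comes out NaN, which silently poisons training, so failing fast in the collator points straight at the masking or dataset bug.

    import torch
    import torch.nn.functional as F

    # a batch in which every position is masked out of the loss
    logits = torch.randn(2, 8, 32)                    # (batch, seq_len, vocab), toy sizes
    labels = torch.full((2, 8), -100, dtype=torch.long)

    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
    print(loss)                                       # tensor(nan): zero supervised tokens to average over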
@@ -604,10 +609,24 @@ def main():
     )
     model = get_peft_model(model, lora_cfg)
+
+    try:
+        model.print_trainable_parameters()
+    except Exception:
+        pass
 
     # 3) Configure gradient checkpointing again (calling it after adapter injection is more reliable)
     if args.gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+
+        # Key: make the inputs require gradients so they stay compatible with checkpointing
+        try:
+            model.enable_input_require_grads()
+        except AttributeError:
+            # Fallback for older transformers: set requires_grad on the embedding output
+            emb = model.get_input_embeddings()
+            if hasattr(emb, "register_forward_hook"):
+                emb.register_forward_hook(lambda m, inp, out: out.requires_grad_(True))
 
     # 4) Print the proportion of trainable parameters
     try:
         from peft import get_peft_model_state_dict
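The AttributeError fallback above can be exercised in isolation with a toy nn.Embedding (a sketch, not the real model): when the embedding weight is frozen, as it is under LoRA, the embedding output does not require grad, so the checkpointed blocks see no grad-requiring input; the registered hook flips that flag on the output, which is essentially what enable_input_require_grads does on newer transformers.

    import torch
    import torch.nn as nn

    emb = nn.Embedding(100, 16)
    emb.weight.requires_grad_(False)        # frozen base weights, as under LoRA

    ids = torch.tensor([[1, 2, 3]])
    print(emb(ids).requires_grad)           # False: nothing upstream for checkpointing to differentiate

    # the same hook the fallback registers: mark the embedding output as requiring grad
    emb.register_forward_hook(lambda m, inp, out: out.requires_grad_(True))
    print(emb(ids).requires_grad)           # True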