hailin 2025-09-01 19:18:24 +08:00
parent 603b23dab0
commit d1c6564255
2 changed files with 28 additions and 7 deletions

View File

@@ -14,12 +14,8 @@
     "offload_optimizer": { "device": "none" },
     "offload_param": { "device": "none" },
-    "stage3_gather_16bit_weights_on_model_save": false
-  },
-  "optimizer": {
-    "type": "AdamW",
-    "params": { "lr": 2e-5, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.1 }
+    "stage3_gather_16bit_weights_on_model_save": false,
+    "zero_allow_untested_optimizer": true
   },
   "bf16": { "enabled": true },

View File

@@ -24,7 +24,7 @@ from transformers import (
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
+from torch.optim import AdamW as TorchAdamW
 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import os, sys, site, shutil
@@ -883,6 +883,30 @@ def main():
     else:
         trainer_kwargs["tokenizer"] = tokenizer
+    decay_params, no_decay_params = [], []
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if any(nd in n for nd in ["bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight"]):
+            no_decay_params.append(p)
+        else:
+            decay_params.append(p)
+    optimizer_grouped_parameters = [
+        {"params": decay_params, "weight_decay": args.weight_decay},
+        {"params": no_decay_params, "weight_decay": 0.0},
+    ]
+    optimizer = TorchAdamW(
+        optimizer_grouped_parameters,
+        lr=args.learning_rate,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+    )
     trainer = DebugTrainer(
         model=model,
         args=training_args,
@@ -891,6 +915,7 @@ def main():
         #tokenizer=tokenizer,
         #processing_class=tokenizer,
         data_collator=data_collator,
+        optimizers=(optimizer, None),
         **trainer_kwargs,
     )
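
Taken together, the two changes implement the usual client-optimizer pattern for a Hugging Face Trainer running under DeepSpeed: group the parameters so that biases and norm weights get no weight decay, build a plain `torch.optim.AdamW` over those groups, and hand it to the Trainer as `optimizers=(optimizer, None)`, where the `None` tells the Trainer to create the learning-rate scheduler itself. `zero_allow_untested_optimizer` in the config keeps ZeRO from rejecting the externally built optimizer, and a plain torch optimizer like this is generally only workable while optimizer offload stays at "none", as it does in the config above. A condensed, self-contained sketch of the pattern; `build_optimizer` and the commented wiring are illustrative stand-ins, not the script's actual names (the script uses `DebugTrainer` and its own `args`):

```python
from torch.optim import AdamW as TorchAdamW
from transformers import Trainer, TrainingArguments  # Trainer stands in for DebugTrainer


def build_optimizer(model, learning_rate, weight_decay):
    """Split trainable parameters into decay / no-decay groups, as in the diff above."""
    no_decay_keys = ["bias", "LayerNorm.weight", "layer_norm.weight",
                     "norm.weight", "ln_f.weight"]
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(key in name for key in no_decay_keys):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    return TorchAdamW(groups, lr=learning_rate, betas=(0.9, 0.999), eps=1e-8)


# Illustrative wiring (placeholder objects, mirroring the hunks above):
# optimizer = build_optimizer(model, args.learning_rate, args.weight_decay)
# trainer = Trainer(
#     model=model,
#     args=training_args,            # TrainingArguments(..., deepspeed=<this JSON config>)
#     data_collator=data_collator,
#     optimizers=(optimizer, None),  # None -> Trainer builds the LR scheduler
# )
```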