hailin 2025-09-01 19:18:24 +08:00
parent 603b23dab0
commit d1c6564255
2 changed files with 28 additions and 7 deletions

View File

@@ -14,12 +14,8 @@
     "offload_optimizer": { "device": "none" },
     "offload_param": { "device": "none" },
-    "stage3_gather_16bit_weights_on_model_save": false
-  },
-  "optimizer": {
-    "type": "AdamW",
-    "params": { "lr": 2e-5, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.1 }
+    "stage3_gather_16bit_weights_on_model_save": false,
+    "zero_allow_untested_optimizer": true
   },
   "bf16": { "enabled": true },

View File

@@ -24,7 +24,7 @@ from transformers import (
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
+from torch.optim import AdamW as TorchAdamW
 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import os, sys, site, shutil
@@ -883,6 +883,30 @@ def main():
     else:
         trainer_kwargs["tokenizer"] = tokenizer
+    decay_params, no_decay_params = [], []
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if any(nd in n for nd in ["bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight"]):
+            no_decay_params.append(p)
+        else:
+            decay_params.append(p)
+    optimizer_grouped_parameters = [
+        {"params": decay_params, "weight_decay": args.weight_decay},
+        {"params": no_decay_params, "weight_decay": 0.0},
+    ]
+    optimizer = TorchAdamW(
+        optimizer_grouped_parameters,
+        lr=args.learning_rate,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+    )
     trainer = DebugTrainer(
         model=model,
         args=training_args,
@@ -891,6 +915,7 @@ def main():
         #tokenizer=tokenizer,
         #processing_class=tokenizer,
         data_collator=data_collator,
+        optimizers=(optimizer, None),
         **trainer_kwargs,
     )
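
Taken together, the two files now follow the client-optimizer pattern: the script builds a torch AdamW with decay / no-decay parameter groups and hands it to the Trainer via optimizers=(optimizer, None), while the DeepSpeed config only has to allow the untested optimizer. A condensed, self-contained sketch of that pattern follows; the model id, output_dir, and batch size are placeholders, the lr and weight_decay values are the ones removed from the config above, and the repo's DebugTrainer, dataset wiring, and CLI argument handling are not reproduced.

    # Condensed sketch of the client-optimizer pattern used in this commit.
    from torch.optim import AdamW
    from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model id

    # Biases and normalization weights get no weight decay.
    no_decay = ("bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight")
    decay, skip = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (skip if any(nd in name for nd in no_decay) else decay).append(param)

    optimizer = AdamW(
        [{"params": decay, "weight_decay": 0.1},   # lr / weight_decay as in the removed config block
         {"params": skip, "weight_decay": 0.0}],
        lr=2e-5, betas=(0.9, 0.999), eps=1e-8,
    )

    args = TrainingArguments(output_dir="out", per_device_train_batch_size=1)  # placeholders

    # (optimizer, None): Trainer uses the supplied optimizer and creates the LR scheduler
    # itself; when DeepSpeed is active, the client optimizer is wrapped by ZeRO, which is
    # what "zero_allow_untested_optimizer": true permits.
    trainer = Trainer(model=model, args=args, optimizers=(optimizer, None))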