This commit is contained in:
parent
603b23dab0
commit
d1c6564255
|
|
@ -14,12 +14,8 @@
|
|||
"offload_optimizer": { "device": "none" },
|
||||
"offload_param": { "device": "none" },
|
||||
|
||||
"stage3_gather_16bit_weights_on_model_save": false
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": { "lr": 2e-5, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.1 }
|
||||
"stage3_gather_16bit_weights_on_model_save": false,
|
||||
"zero_allow_untested_optimizer": true
|
||||
},
|
||||
|
||||
"bf16": { "enabled": true },
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from transformers import (
|
|||
)
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from torch.optim import AdamW as TorchAdamW
|
||||
|
||||
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
|
||||
import os, sys, site, shutil
|
||||
|
|
@ -883,6 +883,30 @@ def main():
|
|||
else:
|
||||
trainer_kwargs["tokenizer"] = tokenizer
|
||||
|
||||
|
||||
|
||||
decay_params, no_decay_params = [], []
|
||||
for n, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if any(nd in n for nd in ["bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight"]):
|
||||
no_decay_params.append(p)
|
||||
else:
|
||||
decay_params.append(p)
|
||||
|
||||
optimizer_grouped_parameters = [
|
||||
{"params": decay_params, "weight_decay": args.weight_decay},
|
||||
{"params": no_decay_params, "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
optimizer = TorchAdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
)
|
||||
|
||||
|
||||
trainer = DebugTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
|
@ -891,6 +915,7 @@ def main():
|
|||
#tokenizer=tokenizer,
|
||||
#processing_class=tokenizer,
|
||||
data_collator=data_collator,
|
||||
optimizers=(optimizer, None),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue