parent 603b23dab0
commit d1c6564255
@@ -14,12 +14,8 @@
     "offload_optimizer": { "device": "none" },
     "offload_param": { "device": "none" },
 
-    "stage3_gather_16bit_weights_on_model_save": false
-  },
-
-  "optimizer": {
-    "type": "AdamW",
-    "params": { "lr": 2e-5, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.1 }
-  },
+    "stage3_gather_16bit_weights_on_model_save": false,
+    "zero_allow_untested_optimizer": true
+  },
 
   "bf16": { "enabled": true },
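Note on this hunk: with the "optimizer" block removed from the DeepSpeed config, the optimizer is expected to come from the training script instead (a "client" optimizer, in DeepSpeed's terms), and "zero_allow_untested_optimizer": true is the switch that lets ZeRO run with a plain torch optimizer rather than one of DeepSpeed's fused implementations. A minimal sketch of that pairing outside the Trainer path; the filename ds_config.json and the toy model are assumptions for illustration:

# Sketch only, not part of this commit: a DeepSpeed config with no "optimizer"
# block, paired with a client-built torch optimizer. ds_config.json is a
# hypothetical filename standing in for the config edited above.
import json

import torch
import deepspeed

model = torch.nn.Linear(16, 16)

# Client-side optimizer; "zero_allow_untested_optimizer": true is what allows
# ZeRO to wrap torch.optim.AdamW instead of a DeepSpeed fused optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.1)

with open("ds_config.json") as f:
    ds_config = json.load(f)

# deepspeed.initialize wraps the client optimizer rather than building its own.
engine, engine_optimizer, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config=ds_config
)

In this script the wrapping happens inside the HF Trainer's DeepSpeed integration rather than via a direct deepspeed.initialize call, but the config-side contract is the same.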
@@ -24,7 +24,7 @@ from transformers import (
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
-
+from torch.optim import AdamW as TorchAdamW
 
 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import os, sys, site, shutil
@@ -883,6 +883,30 @@ def main():
     else:
         trainer_kwargs["tokenizer"] = tokenizer
 
+
+
+    decay_params, no_decay_params = [], []
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if any(nd in n for nd in ["bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight"]):
+            no_decay_params.append(p)
+        else:
+            decay_params.append(p)
+
+    optimizer_grouped_parameters = [
+        {"params": decay_params, "weight_decay": args.weight_decay},
+        {"params": no_decay_params, "weight_decay": 0.0},
+    ]
+
+    optimizer = TorchAdamW(
+        optimizer_grouped_parameters,
+        lr=args.learning_rate,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+    )
+
+
     trainer = DebugTrainer(
         model=model,
         args=training_args,
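The grouping added in this hunk follows the usual convention of exempting biases and normalization weights from weight decay by substring-matching parameter names. A quick, self-contained check of where tensors land under the same rule; the Tiny module and its attribute names are hypothetical:

# Hypothetical sanity check (not part of the commit): apply the same
# name-based rule as the hunk above to a toy module and print the groups.
import torch

NO_DECAY = ["bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight"]

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
        self.norm = torch.nn.LayerNorm(8)

decay, no_decay = [], []
for n, p in Tiny().named_parameters():
    if any(nd in n for nd in NO_DECAY):
        no_decay.append(n)
    else:
        decay.append(n)

print("decay:", decay)        # ['proj.weight']
print("no_decay:", no_decay)  # ['proj.bias', 'norm.weight', 'norm.bias']

One caveat: the match is purely by name, so a normalization layer registered under a name that contains none of the listed substrings (e.g. inside an nn.Sequential, where its weight becomes "1.weight") would silently fall into the decay group.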
@@ -891,6 +915,7 @@ def main():
         #tokenizer=tokenizer,
         #processing_class=tokenizer,
         data_collator=data_collator,
+        optimizers=(optimizer, None),
         **trainer_kwargs,
     )
 
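On optimizers=(optimizer, None): Trainer reads the tuple as (optimizer, lr_scheduler), so None in the second slot leaves the scheduler to be built by Trainer from its TrainingArguments. A hedged sketch of the same pattern in isolation; model, dataset, and the output_dir value are placeholders:

# Sketch, not from this commit: handing Trainer a pre-built optimizer while
# letting it create its default lr scheduler from TrainingArguments.
import torch
from transformers import Trainer, TrainingArguments

def build_trainer(model, train_dataset, data_collator):
    args = TrainingArguments(output_dir="out", learning_rate=2e-5)  # placeholder dir
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        optimizers=(optimizer, None),  # (optimizer, lr_scheduler)
    )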