From d1c6564255ea978e39ce19cd384b436adec00d87 Mon Sep 17 00:00:00 2001
From: hailin
Date: Mon, 1 Sep 2025 19:18:24 +0800
Subject: [PATCH] .

---
 ds_config_zero3.json |  8 ++------
 train_sft_ds.py      | 27 ++++++++++++++++++++++++++-
 2 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/ds_config_zero3.json b/ds_config_zero3.json
index 87f998c..defa626 100644
--- a/ds_config_zero3.json
+++ b/ds_config_zero3.json
@@ -14,12 +14,8 @@
     "offload_optimizer": { "device": "none" },
     "offload_param": { "device": "none" },
-    "stage3_gather_16bit_weights_on_model_save": false
-  },
-
-  "optimizer": {
-    "type": "AdamW",
-    "params": { "lr": 2e-5, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.1 }
+    "stage3_gather_16bit_weights_on_model_save": false,
+    "zero_allow_untested_optimizer": true
   },
 
   "bf16": { "enabled": true },
 
diff --git a/train_sft_ds.py b/train_sft_ds.py
index 5038b91..62e5368 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -24,7 +24,7 @@ from transformers import (
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
-
+from torch.optim import AdamW as TorchAdamW
 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import os, sys, site, shutil
 
@@ -883,6 +883,30 @@ def main():
     else:
         trainer_kwargs["tokenizer"] = tokenizer
 
+
+
+    decay_params, no_decay_params = [], []
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if any(nd in n for nd in ["bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight"]):
+            no_decay_params.append(p)
+        else:
+            decay_params.append(p)
+
+    optimizer_grouped_parameters = [
+        {"params": decay_params, "weight_decay": args.weight_decay},
+        {"params": no_decay_params, "weight_decay": 0.0},
+    ]
+
+    optimizer = TorchAdamW(
+        optimizer_grouped_parameters,
+        lr=args.learning_rate,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+    )
+
+
     trainer = DebugTrainer(
         model=model,
         args=training_args,
@@ -891,6 +915,7 @@
         #tokenizer=tokenizer,
         #processing_class=tokenizer,
         data_collator=data_collator,
+        optimizers=(optimizer, None),
         **trainer_kwargs,
     )
 
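
Note: the following is a minimal standalone sketch, not part of the patch, of the optimizer pattern the diff introduces in train_sft_ds.py: parameters are split by name into weight-decay and no-decay groups, a plain torch.optim.AdamW is built over those groups, and the resulting optimizer is handed to the trainer via optimizers=(optimizer, None), while the DeepSpeed JSON no longer defines its own optimizer (hence "zero_allow_untested_optimizer"). The helper name build_grouped_adamw, the toy model, and the example hyperparameters are assumptions for illustration, not code from the repository.

# Sketch of the decay / no-decay grouping used in the patch; names and toy
# model below are illustrative assumptions only.
from torch import nn
from torch.optim import AdamW

# Same name markers the patch uses to exempt biases and norm weights from decay.
NO_DECAY_MARKERS = ["bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight"]

def build_grouped_adamw(model: nn.Module, lr: float, weight_decay: float) -> AdamW:
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(marker in name for marker in NO_DECAY_MARKERS):
            no_decay.append(param)
        else:
            decay.append(param)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return AdamW(groups, lr=lr, betas=(0.9, 0.999), eps=1e-8)

class Toy(nn.Module):
    # Stand-in for the fine-tuned model loaded in train_sft_ds.py.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

if __name__ == "__main__":
    optimizer = build_grouped_adamw(Toy(), lr=2e-5, weight_decay=0.1)
    # proj.weight decays; proj.bias, norm.weight, norm.bias do not.
    print([len(g["params"]) for g in optimizer.param_groups])
    # In the script this optimizer is passed as optimizers=(optimizer, None);
    # the None leaves learning-rate scheduler creation to the trainer.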