diff --git a/ds_config_zero3.json b/ds_config_zero3.json
deleted file mode 100644
index 07b2d36..0000000
--- a/ds_config_zero3.json
+++ /dev/null
@@ -1,25 +0,0 @@
-{
-  "train_micro_batch_size_per_gpu": 1,
-  "gradient_accumulation_steps": 1,
-
-  "zero_optimization": {
-    "stage": 3,
-    "overlap_comm": true,
-    "contiguous_gradients": true,
-
-    "reduce_bucket_size": 80000000,
-    "stage3_prefetch_bucket_size": 40000000,
-    "stage3_param_persistence_threshold": 0,
-
-    "offload_optimizer": { "device": "none" },
-    "offload_param": { "device": "none" },
-
-    "stage3_gather_16bit_weights_on_model_save": false
-  },
-
-  "bf16": { "enabled": true },
-  "fp16": { "enabled": false },
-
-  "gradient_clipping": 1.0,
-  "wall_clock_breakdown": false
-}
diff --git a/ds_config_zero3_cpu_offload.json b/ds_config_zero3_cpu.json
similarity index 100%
rename from ds_config_zero3_cpu_offload.json
rename to ds_config_zero3_cpu.json
diff --git a/train_sft_ds.py b/train_sft_ds.py
index 62e5368..905abc4 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -847,7 +847,6 @@ def main():
         logging_steps=args.log_interval,
         save_steps=args.save_steps,
         save_total_limit=2,
-        optim="adamw_torch",
         # deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
         deepspeed=(args.deepspeed if use_ds else None),
         dataloader_drop_last=False,
@@ -884,29 +883,6 @@ def main():
         trainer_kwargs["tokenizer"] = tokenizer
 
 
-
-    decay_params, no_decay_params = [], []
-    for n, p in model.named_parameters():
-        if not p.requires_grad:
-            continue
-        if any(nd in n for nd in ["bias", "LayerNorm.weight", "layer_norm.weight", "norm.weight", "ln_f.weight"]):
-            no_decay_params.append(p)
-        else:
-            decay_params.append(p)
-
-    optimizer_grouped_parameters = [
-        {"params": decay_params, "weight_decay": args.weight_decay},
-        {"params": no_decay_params, "weight_decay": 0.0},
-    ]
-
-    optimizer = TorchAdamW(
-        optimizer_grouped_parameters,
-        lr=args.learning_rate,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-    )
-
-
     trainer = DebugTrainer(
         model=model,
         args=training_args,
@@ -915,7 +891,6 @@ def main():
         #tokenizer=tokenizer,
         #processing_class=tokenizer,
         data_collator=data_collator,
-        optimizers=(optimizer, None),
         **trainer_kwargs,
     )
 
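
For reference, a minimal sketch (not part of the patch) of the wiring this change leaves in place: no hand-built optimizer is passed to the Trainer, so Hugging Face Transformers either creates its default AdamW or, when a DeepSpeed config is supplied, lets DeepSpeed construct the optimizer described there. The names model, train_dataset, and data_collator below are hypothetical placeholders for objects the script builds elsewhere.

    from transformers import Trainer, TrainingArguments

    training_args = TrainingArguments(
        output_dir="out",
        per_device_train_batch_size=1,
        bf16=True,
        # the kept (renamed) ZeRO-3 CPU-offload config from this diff
        deepspeed="ds_config_zero3_cpu.json",
    )

    trainer = Trainer(
        model=model,                  # assumed to be defined by the caller
        args=training_args,
        train_dataset=train_dataset,  # assumed to be defined by the caller
        data_collator=data_collator,  # assumed to be defined by the caller
    )
    trainer.train()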