From d8c6d00d0de48e559555cb051bf110904205fcf1 Mon Sep 17 00:00:00 2001
From: hailin
Date: Mon, 22 Sep 2025 16:55:19 +0800
Subject: [PATCH] Add eval dataset, periodic checkpoints, and early stopping to LoRA SFT

---
 train_mm_zero3_lora.sh |  7 +++++--
 train_sft_lora.py      | 12 +++++++++++-
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/train_mm_zero3_lora.sh b/train_mm_zero3_lora.sh
index 513fd38..b2c5783 100755
--- a/train_mm_zero3_lora.sh
+++ b/train_mm_zero3_lora.sh
@@ -3,6 +3,7 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
   train_sft_lora.py \
   --model_name_or_path /home/test/Qwen3-32B \
   --data_glob "/home/test/datasets/my_corpus/train*.jsonl" \
+  --eval_data_glob "/home/test/jd_train/datasets/test/*.jsonl" \
   --output_dir /home/test/checkpoints/q3-32b-lora \
   --seq_len 512 \
   --bf16 \
@@ -12,9 +13,11 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
   --warmup_ratio 0.03 \
   --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
   --lora_exclude lm_head \
-  --max_steps 3000 \
+  --max_steps 300 \
   --log_interval 10 \
-  --eval_steps 200 \
+  --eval_steps 50 \
+  --save_steps 50 \
+  --save_total_limit 4 \
   --gradient_checkpointing \
   --deepspeed /home/test/jd_train/ds_config_zero3_lora.json \
   --report_to wandb --wandb_project ds-qwen3-lora
diff --git a/train_sft_lora.py b/train_sft_lora.py
index b933ebb..0be24a5 100644
--- a/train_sft_lora.py
+++ b/train_sft_lora.py
@@ -29,6 +29,7 @@ from transformers import (
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
 from torch.optim import AdamW as TorchAdamW
+from transformers import EarlyStoppingCallback
 
 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import site, shutil
@@ -368,6 +369,7 @@ def parse_args():
     ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
     ap.add_argument("--deepspeed", type=str, default=None)
     ap.add_argument("--eval_steps", type=int, default=10)
+    ap.add_argument("--save_total_limit", type=int, default=2)
 
     # ===== LoRA-related =====
     ap.add_argument("--lora_r", type=int, default=16)
@@ -808,7 +810,7 @@ def main():
         lr_scheduler_type="cosine",
         logging_steps=args.log_interval,
         save_steps=args.save_steps,
-        save_total_limit=2,
+        save_total_limit=args.save_total_limit,
         deepspeed=(args.deepspeed if use_ds else None),
         dataloader_drop_last=False,
         dataloader_num_workers=0,
@@ -827,6 +829,12 @@ def main():
         ta_kwargs2["torch_compile"] = False
     ta_kwargs2.update({"bf16": (dtype==torch.bfloat16), "fp16": (dtype==torch.float16)})
 
+    ta_kwargs2.update(dict(
+        load_best_model_at_end=True,
+        metric_for_best_model="eval_loss",
+        greater_is_better=False,
+    ))
+
     training_args = TrainingArguments(**ta_kwargs2)
 
     trainer_kwargs = {}
@@ -893,6 +901,8 @@ def main():
             print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
             resume_flag = None
 
+    trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3))
+
     print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
     print_once("***** Starting LoRA training *****")
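
The added options only work together if the eval and save schedules line up: transformers' Trainer requires save_steps to be a round multiple of eval_steps when load_best_model_at_end=True, which is why the launch script now sets both --eval_steps and --save_steps to 50. Below is a minimal, self-contained sketch of the same wiring with the Qwen3-32B / DeepSpeed / LoRA specifics stripped away; the tiny model name ("sshleifer/tiny-gpt2") and the inline toy texts are placeholders chosen only so the sketch runs, not anything used in this patch, and eval_strategy assumes a recent transformers release (older releases spell it evaluation_strategy).

# Minimal sketch of the eval/save/early-stopping wiring introduced by this patch.
# Placeholders (not part of the real setup): "sshleifer/tiny-gpt2", the toy texts,
# and the "out" output directory.
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

model_name = "sshleifer/tiny-gpt2"          # placeholder model, runs on CPU
tok = AutoTokenizer.from_pretrained(model_name)
tok.pad_token = tok.eos_token               # GPT-2 tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained(model_name)

def encode(batch):
    return tok(batch["text"], truncation=True, max_length=64)

train_ds = Dataset.from_dict({"text": ["hello world"] * 64}).map(encode, batched=True)
eval_ds  = Dataset.from_dict({"text": ["hello world"] * 8}).map(encode, batched=True)

args = TrainingArguments(
    output_dir="out",
    max_steps=300,
    logging_steps=10,
    eval_strategy="steps",             # evaluate every eval_steps
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,                     # must be a multiple of eval_steps here
    save_total_limit=4,                # rotate old checkpoints, keep at most 4
    load_best_model_at_end=True,       # reload the best checkpoint after training
    metric_for_best_model="eval_loss",
    greater_is_better=False,           # lower eval_loss is better
    report_to="none",                  # keep the sketch offline (no wandb)
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=DataCollatorForLanguageModeling(tok, mlm=False),
    # Stop if eval_loss fails to improve by more than 1e-3 for 3 evaluations in a row.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3,
                                     early_stopping_threshold=1e-3)],
)
trainer.train()

One interaction worth noting: with load_best_model_at_end=True, the checkpoint rotation driven by save_total_limit retains the best checkpoint in addition to the most recent ones, so --save_total_limit 4 does not risk deleting the checkpoint that best-model selection will restore at the end of training.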