This commit is contained in:
hailin 2025-09-22 16:55:19 +08:00
parent 8248562758
commit d8c6d00d0d
2 changed files with 16 additions and 3 deletions

View File

@ -3,6 +3,7 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
train_sft_lora.py \ train_sft_lora.py \
--model_name_or_path /home/test/Qwen3-32B \ --model_name_or_path /home/test/Qwen3-32B \
--data_glob "/home/test/datasets/my_corpus/train*.jsonl" \ --data_glob "/home/test/datasets/my_corpus/train*.jsonl" \
--eval_data_glob "/home/test/jd_train/datasets/test/*.jsonl" \
--output_dir /home/test/checkpoints/q3-32b-lora \ --output_dir /home/test/checkpoints/q3-32b-lora \
--seq_len 512 \ --seq_len 512 \
--bf16 \ --bf16 \
@ -12,9 +13,11 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
--warmup_ratio 0.03 \ --warmup_ratio 0.03 \
--lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \ --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
--lora_exclude lm_head \ --lora_exclude lm_head \
--max_steps 3000 \ --max_steps 300 \
--log_interval 10 \ --log_interval 10 \
--eval_steps 200 \ --eval_steps 50 \
--save_steps 50 \
--save_total_limit 4 \
--gradient_checkpointing \ --gradient_checkpointing \
--deepspeed /home/test/jd_train/ds_config_zero3_lora.json \ --deepspeed /home/test/jd_train/ds_config_zero3_lora.json \
--report_to wandb --wandb_project ds-qwen3-lora --report_to wandb --wandb_project ds-qwen3-lora

View File

@ -29,6 +29,7 @@ from transformers import (
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from torch.optim import AdamW as TorchAdamW from torch.optim import AdamW as TorchAdamW
from transformers import EarlyStoppingCallback
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ==== # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
import site, shutil import site, shutil
@ -368,6 +369,7 @@ def parse_args():
ap.add_argument("--per_device_eval_batch_size", type=int, default=1) ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
ap.add_argument("--deepspeed", type=str, default=None) ap.add_argument("--deepspeed", type=str, default=None)
ap.add_argument("--eval_steps", type=int, default=10) ap.add_argument("--eval_steps", type=int, default=10)
ap.add_argument("--save_total_limit", type=int, default=2)
# ===== LoRA 相关 ===== # ===== LoRA 相关 =====
ap.add_argument("--lora_r", type=int, default=16) ap.add_argument("--lora_r", type=int, default=16)
@ -808,7 +810,7 @@ def main():
lr_scheduler_type="cosine", lr_scheduler_type="cosine",
logging_steps=args.log_interval, logging_steps=args.log_interval,
save_steps=args.save_steps, save_steps=args.save_steps,
save_total_limit=2, save_total_limit=args.save_total_limit,
deepspeed=(args.deepspeed if use_ds else None), deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False, dataloader_drop_last=False,
dataloader_num_workers=0, dataloader_num_workers=0,
@ -827,6 +829,12 @@ def main():
ta_kwargs2["torch_compile"] = False ta_kwargs2["torch_compile"] = False
ta_kwargs2.update({"bf16": (dtype==torch.bfloat16), "fp16": (dtype==torch.float16)}) ta_kwargs2.update({"bf16": (dtype==torch.bfloat16), "fp16": (dtype==torch.float16)})
ta_kwargs2.update(dict(
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
))
training_args = TrainingArguments(**ta_kwargs2) training_args = TrainingArguments(**ta_kwargs2)
trainer_kwargs = {} trainer_kwargs = {}
@ -893,6 +901,8 @@ def main():
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True) print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
resume_flag = None resume_flag = None
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3))
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}") print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
print_once("***** Starting LoRA training *****") print_once("***** Starting LoRA training *****")