parent 8248562758
commit d8c6d00d0d
@@ -3,6 +3,7 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
     train_sft_lora.py \
     --model_name_or_path /home/test/Qwen3-32B \
     --data_glob "/home/test/datasets/my_corpus/train*.jsonl" \
+    --eval_data_glob "/home/test/jd_train/datasets/test/*.jsonl" \
     --output_dir /home/test/checkpoints/q3-32b-lora \
     --seq_len 512 \
     --bf16 \
@@ -12,9 +13,11 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
     --warmup_ratio 0.03 \
     --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
     --lora_exclude lm_head \
-    --max_steps 3000 \
+    --max_steps 300 \
     --log_interval 10 \
-    --eval_steps 200 \
+    --eval_steps 50 \
+    --save_steps 50 \
+    --save_total_limit 4 \
     --gradient_checkpointing \
     --deepspeed /home/test/jd_train/ds_config_zero3_lora.json \
     --report_to wandb --wandb_project ds-qwen3-lora
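
Note on the new schedule: eval_steps and save_steps are both 50, so evaluation and checkpointing coincide, which the best-model/early-stopping changes further down rely on (transformers requires save_steps to be a round multiple of eval_steps once load_best_model_at_end is enabled). A quick arithmetic sanity check, assuming nothing beyond the flag values above:

# Sanity check of the schedule implied by the launch flags above (not repo code).
max_steps, eval_steps, save_steps = 300, 50, 50
assert save_steps % eval_steps == 0   # required once load_best_model_at_end=True
num_evals = max_steps // eval_steps   # 6 evaluation/checkpoint points: step 50 .. 300
print(num_evals)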
@@ -29,6 +29,7 @@ from transformers import (
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
 from torch.optim import AdamW as TorchAdamW
+from transformers import EarlyStoppingCallback

 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import site, shutil
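
For context on why this import is paired with the TrainingArguments changes below: EarlyStoppingCallback asserts at train start that best-model tracking is configured. A minimal sketch of the required combination (standard transformers field names; output_dir and the strategy value are placeholders, not taken from this repo):

# Sketch only: the fields EarlyStoppingCallback checks for at train start.
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

args = TrainingArguments(
    output_dir="/tmp/demo",             # placeholder
    eval_strategy="steps",              # "evaluation_strategy" on older transformers
    eval_steps=50,
    save_steps=50,
    load_best_model_at_end=True,        # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",  # required by EarlyStoppingCallback
    greater_is_better=False,
)
# trainer = Trainer(model=..., args=args, ...)
# trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))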
@@ -368,6 +369,7 @@ def parse_args():
     ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
     ap.add_argument("--deepspeed", type=str, default=None)
     ap.add_argument("--eval_steps", type=int, default=10)
+    ap.add_argument("--save_total_limit", type=int, default=2)

     # ===== LoRA-related =====
     ap.add_argument("--lora_r", type=int, default=16)
@@ -808,7 +810,7 @@ def main():
         lr_scheduler_type="cosine",
         logging_steps=args.log_interval,
         save_steps=args.save_steps,
-        save_total_limit=2,
+        save_total_limit=args.save_total_limit,
         deepspeed=(args.deepspeed if use_ds else None),
         dataloader_drop_last=False,
         dataloader_num_workers=0,
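
With the launch flags above (save_steps=50, max_steps=300, save_total_limit=4), only the most recent four checkpoints survive rotation, except that the Trainer additionally protects the best checkpoint (by eval_loss) once load_best_model_at_end is enabled. A small illustration of the plain rotation arithmetic, not the Trainer's actual code:

# Illustration of checkpoint rotation under the new flags (not Trainer internals).
save_steps, max_steps, save_total_limit = 50, 300, 4
saves = list(range(save_steps, max_steps + 1, save_steps))  # [50, 100, 150, 200, 250, 300]
kept = saves[-save_total_limit:]                            # most recent 4: [150, 200, 250, 300]
print(kept)
# With load_best_model_at_end=True, the best checkpoint is kept as well,
# even when it falls outside this window.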
@@ -827,6 +829,12 @@ def main():
     ta_kwargs2["torch_compile"] = False
     ta_kwargs2.update({"bf16": (dtype==torch.bfloat16), "fp16": (dtype==torch.float16)})

+    ta_kwargs2.update(dict(
+        load_best_model_at_end=True,
+        metric_for_best_model="eval_loss",
+        greater_is_better=False,
+    ))
+
     training_args = TrainingArguments(**ta_kwargs2)

     trainer_kwargs = {}
@@ -893,6 +901,8 @@ def main():
         print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
         resume_flag = None

+    trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3))
+
     print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
     print_once("***** Starting LoRA training *****")