commit d8c6d00d0d (parent 8248562758)
@@ -3,6 +3,7 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
     train_sft_lora.py \
     --model_name_or_path /home/test/Qwen3-32B \
     --data_glob "/home/test/datasets/my_corpus/train*.jsonl" \
+    --eval_data_glob "/home/test/jd_train/datasets/test/*.jsonl" \
     --output_dir /home/test/checkpoints/q3-32b-lora \
     --seq_len 512 \
     --bf16 \
@@ -12,9 +13,11 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
     --warmup_ratio 0.03 \
     --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
     --lora_exclude lm_head \
-    --max_steps 3000 \
+    --max_steps 300 \
     --log_interval 10 \
-    --eval_steps 200 \
+    --eval_steps 50 \
+    --save_steps 50 \
+    --save_total_limit 4 \
     --gradient_checkpointing \
     --deepspeed /home/test/jd_train/ds_config_zero3_lora.json \
     --report_to wandb --wandb_project ds-qwen3-lora
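Why --eval_steps and --save_steps both land on 50: the Python changes below enable load_best_model_at_end, and TrainingArguments rejects step-based checkpointing unless save_steps is a round multiple of eval_steps in that case; --save_total_limit 4 simply caps how many checkpoints stay on disk. A minimal sketch of a compatible configuration follows (the variable name and output_dir are placeholders, and eval_strategy is the field name in recent transformers; older releases spell it evaluation_strategy):

# Minimal sketch, not the training script: an eval/save cadence that is valid
# together with load_best_model_at_end. output_dir is a placeholder path.
from transformers import TrainingArguments

demo_args = TrainingArguments(
    output_dir="/tmp/lora-cadence-demo",  # placeholder
    max_steps=300,
    eval_strategy="steps",                # "evaluation_strategy" on older transformers
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,                        # must be a round multiple of eval_steps here
    save_total_limit=4,                   # keep at most 4 checkpoints on disk
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)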
@@ -29,6 +29,7 @@ from transformers import (
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
 from torch.optim import AdamW as TorchAdamW
+from transformers import EarlyStoppingCallback

 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import site, shutil
@@ -368,6 +369,7 @@ def parse_args():
     ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
     ap.add_argument("--deepspeed", type=str, default=None)
     ap.add_argument("--eval_steps", type=int, default=10)
+    ap.add_argument("--save_total_limit", type=int, default=2)

     # ===== LoRA settings =====
     ap.add_argument("--lora_r", type=int, default=16)
@@ -808,7 +810,7 @@ def main():
         lr_scheduler_type="cosine",
         logging_steps=args.log_interval,
         save_steps=args.save_steps,
-        save_total_limit=2,
+        save_total_limit=args.save_total_limit,
         deepspeed=(args.deepspeed if use_ds else None),
         dataloader_drop_last=False,
         dataloader_num_workers=0,
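The two hunks above are one change: parse_args() now exposes --save_total_limit (defaulting to 2, the value that used to be hard-coded), and the TrainingArguments kwargs read args.save_total_limit, so the launch script's --save_total_limit 4 takes effect. With load_best_model_at_end enabled, the Trainer is documented to keep the best checkpoint in addition to the most recent ones when it rotates old checkpoints. A small self-contained sketch of that wiring, with a placeholder argv list and output_dir:

# Sketch of the CLI-to-TrainingArguments wiring for checkpoint retention.
# The flag name matches the diff; the argv list and output_dir are placeholders.
import argparse
from transformers import TrainingArguments

ap = argparse.ArgumentParser()
ap.add_argument("--save_total_limit", type=int, default=2)
args = ap.parse_args(["--save_total_limit", "4"])  # what the launch script now passes

demo_args = TrainingArguments(
    output_dir="/tmp/lora-retention-demo",      # placeholder
    save_steps=50,
    save_total_limit=args.save_total_limit,     # older checkpoints beyond this count are deleted
)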
@@ -827,6 +829,12 @@ def main():
     ta_kwargs2["torch_compile"] = False
     ta_kwargs2.update({"bf16": (dtype==torch.bfloat16), "fp16": (dtype==torch.float16)})

+    ta_kwargs2.update(dict(
+        load_best_model_at_end=True,
+        metric_for_best_model="eval_loss",
+        greater_is_better=False,
+    ))
+
     training_args = TrainingArguments(**ta_kwargs2)

     trainer_kwargs = {}
@@ -893,6 +901,8 @@ def main():
         print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
         resume_flag = None

+    trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3))
+
     print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
     print_once("***** Starting LoRA training *****")

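The final hunk registers the imported EarlyStoppingCallback on the trainer. Combined with evaluation every 50 steps and the eval_loss best-metric settings added earlier, training stops once three consecutive evaluations fail to improve eval_loss by more than 1e-3. The snippet below restates that stopping rule in plain Python as an approximation for illustration; it is not transformers' internal code, and the sample loss values are made up:

# Plain-Python illustration of the early-stopping rule configured above
# (patience=3, threshold=1e-3). Not transformers' internal implementation.
def stops_at(eval_losses, patience=3, threshold=1e-3):
    best = float("inf")
    evals_without_improvement = 0
    for i, loss in enumerate(eval_losses, start=1):
        if best - loss > threshold:        # improved by more than the threshold
            best = loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
        if evals_without_improvement >= patience:
            return i                        # evaluation at which training would stop
    return None                             # ran to completion

# Made-up eval_loss trace: improves for three evaluations, then plateaus.
print(stops_at([2.10, 1.80, 1.70, 1.70, 1.70, 1.70]))  # -> 6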