hailin 2025-09-24 17:30:31 +08:00
parent 3a8ead6971
commit fd9971f1c8
3 changed files with 43 additions and 4 deletions

View File

@@ -2,8 +2,8 @@ WANDB_BASE_URL=https://wandb.szaiai.com
 WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
 WANDB_PROJECT=ds-qwen3
 WANDB_ENTITY=hailin
-WANDB_GROUP=q3-32b-ds4-2025-09-05
-WANDB_NAME=q3-32b-lr2e-5-train3
+WANDB_GROUP=q3-32b-ds4-2025-09-24
+WANDB_NAME=q3-32b-lr2e-5-train1
 WANDB_RESUME=allow
 WANDB_INIT_TIMEOUT=300
 WANDB_DIR=/tmp/$USER/wandb

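The WANDB_GROUP/WANDB_NAME bumps above take effect because the HF Trainer's wandb integration calls wandb.init() without overriding these values, so they are read from the environment at run start. A minimal sketch of that pickup, with the values copied from the env file above (in the real job they are exported by the shell, not set in Python):

import os
import wandb

# Mirror the env file above; in the training job these come from the launcher.
os.environ["WANDB_PROJECT"] = "ds-qwen3"
os.environ["WANDB_GROUP"] = "q3-32b-ds4-2025-09-24"
os.environ["WANDB_NAME"] = "q3-32b-lr2e-5-train1"

# wandb.init() with no arguments picks project/group/name up from WANDB_*,
# so every rank launched with this env file lands in the same run group.
run = wandb.init(mode="offline")  # offline here only to avoid needing the private server
print(run.project, run.group, run.name)
run.finish()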
View File

@@ -35,7 +35,7 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
   --per_device_train_batch_size 1 \
   --gradient_accumulation_steps 1 \
   --learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \
-  --max_steps 20 \
+  --max_steps 300 \
   --log_interval 1 \
   --gradient_checkpointing \
   --bf16 \
@@ -43,5 +43,10 @@
   --report_to wandb \
   --wandb_project ds-qwen3 \
   --eval_steps 10 \
+  --save_steps 10 \
+  --load_best_model_at_end \
+  --early_stopping_patience 5 \
+  --early_stopping_threshold 0.0 \
+  --metric_for_best_model eval_loss \
   --eval_data_glob "/home/test/datasets/my_corpus/test.jsonl"

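Note that the new --save_steps 10 matches --eval_steps 10 deliberately: with load_best_model_at_end, transformers requires the save interval to be a round multiple of the eval interval so that every evaluated candidate actually has a checkpoint on disk. A minimal sketch of the constraint (the output path is a placeholder):

from transformers import TrainingArguments

# With load_best_model_at_end, save_steps must be a round multiple of
# eval_steps; 10/10 as in the launch script above satisfies the check.
args = TrainingArguments(
    output_dir="/tmp/ckpts",            # placeholder
    eval_strategy="steps",              # spelled evaluation_strategy on older releases
    eval_steps=10,
    save_strategy="steps",
    save_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,            # lower eval_loss is better
)
# save_steps=15 with eval_steps=10 would raise a ValueError at construction.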
View File

@@ -28,6 +28,7 @@ from transformers import (
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
 from torch.optim import AdamW as TorchAdamW
+from transformers import EarlyStoppingCallback

 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import site, shutil
@@ -444,6 +445,17 @@ def parse_args():
     ap.add_argument("--eval_steps", type=int, default=10,
                     help="Evaluate every N optimizer steps when eval_dataset is provided")
+    ap.add_argument("--load_best_model_at_end", action="store_true",
+                    help="Automatically load the best checkpoint at the end of training")
+    ap.add_argument("--metric_for_best_model", type=str, default="eval_loss",
+                    help="Metric used to pick the best model; defaults to eval_loss")
+    ap.add_argument("--greater_is_better", action="store_true",
+                    help="Set if a larger metric value is better; for eval_loss leave it unset (False)")
+    ap.add_argument("--early_stopping_patience", type=int, default=0,
+                    help=">0 enables early stopping; counted in evaluation rounds, not steps")
+    ap.add_argument("--early_stopping_threshold", type=float, default=0.0,
+                    help="Improvement threshold; 0 means only a strictly better metric counts")
     return ap.parse_args()
@@ -921,7 +933,7 @@ def main():
         lr_scheduler_type="cosine",
         logging_steps=args.log_interval,
         save_steps=args.save_steps,
-        save_total_limit=2,
+        save_total_limit=None,
         deepspeed=(args.deepspeed if use_ds else None),
         dataloader_drop_last=False,
         dataloader_num_workers=0,
@@ -948,6 +960,21 @@
         "fp16": (torch.cuda.is_available() and not use_bf16),
     })

+    ta_sig = inspect.signature(TrainingArguments.__init__).parameters
+    if "save_strategy" in ta_sig:
+        ta_kwargs2["save_strategy"] = "steps"
+    ta_kwargs2.update({
+        "load_best_model_at_end": args.load_best_model_at_end,
+        "metric_for_best_model": args.metric_for_best_model,
+        "greater_is_better": args.greater_is_better,  # keep False (the default) for eval_loss
+        # "save_strategy": "steps",  # keep aligned with eval_steps
+    })
+    if args.early_stopping_patience > 0 and eval_dataset is None:
+        print("[warn] early_stopping_patience>0 but no eval dataset was provided; early stopping will never trigger.", flush=True)
+
     training_args = TrainingArguments(**ta_kwargs2)

     trainer_kwargs = {}
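
The ta_sig probe in this hunk is a version-gating pattern: inspect the installed TrainingArguments signature before passing a kwarg, so the same script runs on older transformers releases that lack it. A standalone sketch of the same idea (the output path is a placeholder):

import inspect
from transformers import TrainingArguments

# Build kwargs incrementally, adding an argument only if the installed
# transformers version actually accepts it.
ta_params = inspect.signature(TrainingArguments.__init__).parameters
kwargs = {"output_dir": "/tmp/ckpts"}  # placeholder
if "save_strategy" in ta_params:
    kwargs["save_strategy"] = "steps"
    kwargs["save_steps"] = 10
args = TrainingArguments(**kwargs)
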
@@ -966,6 +993,13 @@
         **trainer_kwargs,
     )

+    if args.early_stopping_patience and args.early_stopping_patience > 0:
+        trainer.add_callback(EarlyStoppingCallback(
+            early_stopping_patience=args.early_stopping_patience,
+            early_stopping_threshold=args.early_stopping_threshold
+        ))
+
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
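
For reference, EarlyStoppingCallback counts evaluation rounds, not optimizer steps: each evaluation in which metric_for_best_model fails to improve by more than early_stopping_threshold consumes one unit of patience, and the callback asserts at train start that load_best_model_at_end and metric_for_best_model are set, which is why the TrainingArguments changes above accompany it. A minimal sketch of the wiring (`trainer` stands in for the Trainer constructed in this script):

from transformers import EarlyStoppingCallback

# Patience is consumed per evaluation round: with eval_steps=10 and
# patience=5, training stops after 5 consecutive evaluations (50 optimizer
# steps) without eval_loss improving by more than the threshold.
callback = EarlyStoppingCallback(
    early_stopping_patience=5,
    early_stopping_threshold=0.0,  # 0.0: only a strict improvement resets patience
)
# trainer.add_callback(callback)  # `trainer` is the Trainer built above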