This commit is contained in:
parent
3a8ead6971
commit
fd9971f1c8
|
|
@ -2,8 +2,8 @@ WANDB_BASE_URL=https://wandb.szaiai.com
|
|||
WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
|
||||
WANDB_PROJECT=ds-qwen3
|
||||
WANDB_ENTITY=hailin
|
||||
WANDB_GROUP=q3-32b-ds4-2025-09-05
|
||||
WANDB_NAME=q3-32b-lr2e-5-train3
|
||||
WANDB_GROUP=q3-32b-ds4-2025-09-24
|
||||
WANDB_NAME=q3-32b-lr2e-5-train1
|
||||
WANDB_RESUME=allow
|
||||
WANDB_INIT_TIMEOUT=300
|
||||
WANDB_DIR=/tmp/$USER/wandb
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
|
|||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \
|
||||
--max_steps 20 \
|
||||
--max_steps 300 \
|
||||
--log_interval 1 \
|
||||
--gradient_checkpointing \
|
||||
--bf16 \
|
||||
|
|
@ -43,5 +43,10 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
|
|||
--report_to wandb \
|
||||
--wandb_project ds-qwen3 \
|
||||
--eval_steps 10 \
|
||||
--save_steps 10 \
|
||||
--load_best_model_at_end \
|
||||
--early_stopping_patience 5 \
|
||||
--early_stopping_threshold 0.0 \
|
||||
--metric_for_best_model eval_loss \
|
||||
--eval_data_glob "/home/test/datasets/my_corpus/test.jsonl"
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from transformers import (
|
|||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from torch.optim import AdamW as TorchAdamW
|
||||
from transformers import EarlyStoppingCallback
|
||||
|
||||
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
|
||||
import site, shutil
|
||||
|
|
@ -444,6 +445,17 @@ def parse_args():
|
|||
ap.add_argument("--eval_steps", type=int, default=10,
|
||||
help="Evaluate every N optimizer steps when eval_dataset is provided")
|
||||
|
||||
ap.add_argument("--load_best_model_at_end", action="store_true",
|
||||
help="训练结束时自动加载最优 checkpoint")
|
||||
ap.add_argument("--metric_for_best_model", type=str, default="eval_loss",
|
||||
help="用哪个指标选最优,默认 eval_loss")
|
||||
ap.add_argument("--greater_is_better", action="store_true",
|
||||
help="是否指标越大越好;eval_loss 用 False(默认不传即可)")
|
||||
ap.add_argument("--early_stopping_patience", type=int, default=0,
|
||||
help=">0 启用早停;单位是 eval 轮次数(非 step 数)")
|
||||
ap.add_argument("--early_stopping_threshold", type=float, default=0.0,
|
||||
help="改进阈值,0 表示严格变好才算改进")
|
||||
|
||||
return ap.parse_args()
|
||||
|
||||
|
||||
|
|
@ -921,7 +933,7 @@ def main():
|
|||
lr_scheduler_type="cosine",
|
||||
logging_steps=args.log_interval,
|
||||
save_steps=args.save_steps,
|
||||
save_total_limit=2,
|
||||
save_total_limit=None,
|
||||
deepspeed=(args.deepspeed if use_ds else None),
|
||||
dataloader_drop_last=False,
|
||||
dataloader_num_workers=0,
|
||||
|
|
@ -948,6 +960,21 @@ def main():
|
|||
"fp16": (torch.cuda.is_available() and not use_bf16),
|
||||
})
|
||||
|
||||
ta_sig = inspect.signature(TrainingArguments.__init__).parameters
|
||||
if "save_strategy" in ta_sig:
|
||||
ta_kwargs2["save_strategy"] = "steps"
|
||||
|
||||
ta_kwargs2.update({
|
||||
"load_best_model_at_end": args.load_best_model_at_end,
|
||||
"metric_for_best_model": args.metric_for_best_model,
|
||||
"greater_is_better": args.greater_is_better, # 对 eval_loss 保持 False(默认)
|
||||
# "save_strategy": "steps", # 与 eval_steps 对齐
|
||||
})
|
||||
|
||||
|
||||
if args.early_stopping_patience > 0 and eval_dataset is None:
|
||||
print("[warn] early_stopping_patience>0 但未提供 eval 数据集;早停将不会触发。", flush=True)
|
||||
|
||||
training_args = TrainingArguments(**ta_kwargs2)
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
|
@ -966,6 +993,13 @@ def main():
|
|||
**trainer_kwargs,
|
||||
)
|
||||
|
||||
|
||||
if args.early_stopping_patience and args.early_stopping_patience > 0:
|
||||
trainer.add_callback(EarlyStoppingCallback(
|
||||
early_stopping_patience=args.early_stopping_patience,
|
||||
early_stopping_threshold=args.early_stopping_threshold
|
||||
))
|
||||
|
||||
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue