commit fd9971f1c8 (parent 3a8ead6971)
@@ -2,8 +2,8 @@ WANDB_BASE_URL=https://wandb.szaiai.com
 WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
 WANDB_PROJECT=ds-qwen3
 WANDB_ENTITY=hailin
-WANDB_GROUP=q3-32b-ds4-2025-09-05
-WANDB_NAME=q3-32b-lr2e-5-train3
+WANDB_GROUP=q3-32b-ds4-2025-09-24
+WANDB_NAME=q3-32b-lr2e-5-train1
 WANDB_RESUME=allow
 WANDB_INIT_TIMEOUT=300
 WANDB_DIR=/tmp/$USER/wandb
@@ -35,7 +35,7 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
   --per_device_train_batch_size 1 \
   --gradient_accumulation_steps 1 \
   --learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \
-  --max_steps 20 \
+  --max_steps 300 \
   --log_interval 1 \
   --gradient_checkpointing \
   --bf16 \
@@ -43,5 +43,10 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
   --report_to wandb \
   --wandb_project ds-qwen3 \
   --eval_steps 10 \
+  --save_steps 10 \
+  --load_best_model_at_end \
+  --early_stopping_patience 5 \
+  --early_stopping_threshold 0.0 \
+  --metric_for_best_model eval_loss \
   --eval_data_glob "/home/test/datasets/my_corpus/test.jsonl"
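For reference, a minimal sketch (not part of this commit) of the transformers TrainingArguments combination these flags drive; output_dir is a placeholder and the step counts mirror the script above. When load_best_model_at_end is enabled, evaluation and checkpointing must both run on a "steps" schedule and save_steps must be a round multiple of eval_steps, which is why save_steps 10 is added alongside eval_steps 10:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                   # placeholder path
    max_steps=300,
    eval_strategy="steps",              # called evaluation_strategy on older transformers
    eval_steps=10,
    save_strategy="steps",
    save_steps=10,                      # must be a round multiple of eval_steps
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,            # lower eval_loss is better
)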
@@ -28,6 +28,7 @@ from transformers import (
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import get_last_checkpoint
 from torch.optim import AdamW as TorchAdamW
+from transformers import EarlyStoppingCallback
 
 # ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
 import site, shutil
@@ -444,6 +445,17 @@ def parse_args():
     ap.add_argument("--eval_steps", type=int, default=10,
                     help="Evaluate every N optimizer steps when eval_dataset is provided")
 
+    ap.add_argument("--load_best_model_at_end", action="store_true",
+                    help="Automatically load the best checkpoint at the end of training")
+    ap.add_argument("--metric_for_best_model", type=str, default="eval_loss",
+                    help="Metric used to pick the best model; defaults to eval_loss")
+    ap.add_argument("--greater_is_better", action="store_true",
+                    help="Whether larger metric values are better; leave unset (False) for eval_loss")
+    ap.add_argument("--early_stopping_patience", type=int, default=0,
+                    help=">0 enables early stopping; measured in eval rounds, not steps")
+    ap.add_argument("--early_stopping_threshold", type=float, default=0.0,
+                    help="Improvement threshold; 0 means only a strict improvement counts")
+
     return ap.parse_args()
@@ -921,7 +933,7 @@ def main():
         lr_scheduler_type="cosine",
         logging_steps=args.log_interval,
         save_steps=args.save_steps,
-        save_total_limit=2,
+        save_total_limit=None,
         deepspeed=(args.deepspeed if use_ds else None),
         dataloader_drop_last=False,
         dataloader_num_workers=0,
@@ -948,6 +960,21 @@ def main():
         "fp16": (torch.cuda.is_available() and not use_bf16),
     })
 
+    ta_sig = inspect.signature(TrainingArguments.__init__).parameters
+    if "save_strategy" in ta_sig:
+        ta_kwargs2["save_strategy"] = "steps"
+
+    ta_kwargs2.update({
+        "load_best_model_at_end": args.load_best_model_at_end,
+        "metric_for_best_model": args.metric_for_best_model,
+        "greater_is_better": args.greater_is_better,  # keep False (the default) for eval_loss
+        # "save_strategy": "steps",  # align with eval_steps
+    })
+
+    if args.early_stopping_patience > 0 and eval_dataset is None:
+        print("[warn] early_stopping_patience>0 but no eval dataset was provided; early stopping will not trigger.", flush=True)
+
     training_args = TrainingArguments(**ta_kwargs2)
 
     trainer_kwargs = {}
@@ -966,6 +993,13 @@ def main():
         **trainer_kwargs,
     )
 
+    if args.early_stopping_patience and args.early_stopping_patience > 0:
+        trainer.add_callback(EarlyStoppingCallback(
+            early_stopping_patience=args.early_stopping_patience,
+            early_stopping_threshold=args.early_stopping_threshold
+        ))
+
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
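For context, a minimal sketch (again not taken from this commit) of how EarlyStoppingCallback plugs into a transformers Trainer; model, train_ds and eval_ds are placeholders, and training_args is assumed to be configured as in the earlier sketch. Patience is counted in evaluation rounds (one per eval_steps optimizer steps), not training steps, and the callback requires load_best_model_at_end plus a metric_for_best_model that is actually produced at evaluation time:

from transformers import Trainer, EarlyStoppingCallback

trainer = Trainer(
    model=model,                     # placeholder model
    args=training_args,              # as configured in the earlier sketch
    train_dataset=train_ds,          # placeholder datasets
    eval_dataset=eval_ds,            # without an eval dataset, early stopping never fires
)
trainer.add_callback(EarlyStoppingCallback(
    early_stopping_patience=5,       # stop after 5 consecutive evals without improvement
    early_stopping_threshold=0.0,    # 0.0: any strict improvement resets the counter
))
trainer.train()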