This commit is contained in:
hailin 2025-08-28 21:47:43 +08:00
parent ff39f31718
commit 4ad2be2d34
3 changed files with 7 additions and 40 deletions

View File

@@ -11,9 +11,9 @@
"overlap_comm": true, "overlap_comm": true,
"contiguous_gradients": true, "contiguous_gradients": true,
"reduce_scatter": true, "reduce_scatter": true,
"reduce_bucket_size": 50000000, "reduce_bucket_size": 2e8,
"stage3_prefetch_bucket_size": 50000000, "stage3_prefetch_bucket_size": 2e8,
"stage3_param_persistence_threshold": 100000, "stage3_param_persistence_threshold": 1e6,
"stage3_gather_16bit_weights_on_model_save": true "stage3_gather_16bit_weights_on_model_save": true
}, },
"wall_clock_breakdown": false "wall_clock_breakdown": false

View File

@@ -11,6 +11,7 @@ torchrun --nproc_per_node 4 /home/test/jd_train/train_sft_ds.py \
--gradient_accumulation_steps 1 \ --gradient_accumulation_steps 1 \
--learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \ --learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \
--max_steps 375 --log_interval 1 \ --max_steps 375 --log_interval 1 \
--gradient_checkpointing \
--bf16 \ --bf16 \
--deepspeed /home/test/jd_train/ds_config_zero3.json \ --deepspeed /home/test/jd_train/ds_config_zero3.json \
--report_to none \ --report_to none \

View File

@@ -484,12 +484,13 @@ def main():
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path, args.model_name_or_path,
# torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
torch_dtype=dtype, torch_dtype=dtype,
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
trust_remote_code=True trust_remote_code=True,
attn_implementation="flash_attention_2"
) )
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
dbg(f"model loaded: dtype={next(model.parameters()).dtype} " dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
f"use_cache={getattr(model.config,'use_cache',None)} " f"use_cache={getattr(model.config,'use_cache',None)} "
f"pad_token_id={getattr(model.config,'pad_token_id',None)}") f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
@@ -696,41 +697,6 @@ def main():
elif "evaluation_strategy" in sig: elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "no" ta_kwargs["evaluation_strategy"] = "no"
# training_args = TrainingArguments(
# output_dir=args.output_dir,
# logging_dir=logging_dir,
# do_train=True,
# do_eval=(eval_dataset is not None),
# eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
# per_device_train_batch_size=args.per_device_train_batch_size,
# gradient_accumulation_steps=args.gradient_accumulation_steps,
# learning_rate=args.learning_rate,
# weight_decay=args.weight_decay,
# warmup_ratio=args.warmup_ratio,
# num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
# max_steps=args.max_steps if args.max_steps > 0 else -1,
# lr_scheduler_type="cosine",
# logging_steps=args.log_interval,
# save_steps=args.save_steps,
# save_total_limit=2,
# deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
# dataloader_drop_last=False, # 关键:别丢尾,避免空 batch
# dataloader_num_workers=0,
# dataloader_prefetch_factor=None,
# dataloader_pin_memory=False,
# per_device_eval_batch_size=args.per_device_eval_batch_size,
# report_to=([] if args.report_to == "none" else [args.report_to]),
# bf16=args.bf16,
# fp16=(not args.bf16),
# gradient_checkpointing=args.gradient_checkpointing,
# remove_unused_columns=False,
# torch_compile=False,
# save_on_each_node=True,
# logging_first_step=True,
# **ta_kwargs,
# )
ta_sig = inspect.signature(TrainingArguments.__init__).parameters ta_sig = inspect.signature(TrainingArguments.__init__).parameters
ta_kwargs2 = dict( ta_kwargs2 = dict(
output_dir=args.output_dir, output_dir=args.output_dir,