diff --git a/ds_config_zero3.json b/ds_config_zero3.json
index 5961ae9..da426b5 100644
--- a/ds_config_zero3.json
+++ b/ds_config_zero3.json
@@ -11,9 +11,9 @@
         "overlap_comm": true,
         "contiguous_gradients": true,
         "reduce_scatter": true,
-        "reduce_bucket_size": 50000000,
-        "stage3_prefetch_bucket_size": 50000000,
-        "stage3_param_persistence_threshold": 100000,
+        "reduce_bucket_size": 2e8,
+        "stage3_prefetch_bucket_size": 2e8,
+        "stage3_param_persistence_threshold": 1e6,
         "stage3_gather_16bit_weights_on_model_save": true
     },
     "wall_clock_breakdown": false
diff --git a/ss-zero3.sh b/ss-zero3.sh
index 4bb4782..81aefbe 100755
--- a/ss-zero3.sh
+++ b/ss-zero3.sh
@@ -11,6 +11,7 @@ torchrun --nproc_per_node 4 /home/test/jd_train/train_sft_ds.py \
   --gradient_accumulation_steps 1 \
   --learning_rate 2e-5 --weight_decay 0.1 --warmup_ratio 0.02 \
   --max_steps 375 --log_interval 1 \
+  --gradient_checkpointing \
   --bf16 \
   --deepspeed /home/test/jd_train/ds_config_zero3.json \
   --report_to none \
diff --git a/train_sft_ds.py b/train_sft_ds.py
index a478012..c255748 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -484,12 +484,13 @@ def main():

     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
-        # torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
         torch_dtype=dtype,
         low_cpu_mem_usage=True,
-        trust_remote_code=True
+        trust_remote_code=True,
+        attn_implementation="flash_attention_2"
     )
+    print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)

     dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
         f"use_cache={getattr(model.config,'use_cache',None)} "
         f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
@@ -696,41 +697,6 @@ def main():
     elif "evaluation_strategy" in sig:
         ta_kwargs["evaluation_strategy"] = "no"

-    # training_args = TrainingArguments(
-    #     output_dir=args.output_dir,
-    #     logging_dir=logging_dir,
-    #     do_train=True,
-    #     do_eval=(eval_dataset is not None),
-    #     eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
-    #     per_device_train_batch_size=args.per_device_train_batch_size,
-    #     gradient_accumulation_steps=args.gradient_accumulation_steps,
-    #     learning_rate=args.learning_rate,
-    #     weight_decay=args.weight_decay,
-    #     warmup_ratio=args.warmup_ratio,
-    #     num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
-    #     max_steps=args.max_steps if args.max_steps > 0 else -1,
-    #     lr_scheduler_type="cosine",
-    #     logging_steps=args.log_interval,
-    #     save_steps=args.save_steps,
-    #     save_total_limit=2,
-    #     deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
-    #     dataloader_drop_last=False,  # key: don't drop the trailing partial batch, to avoid empty batches
-    #     dataloader_num_workers=0,
-    #     dataloader_prefetch_factor=None,
-    #     dataloader_pin_memory=False,
-    #     per_device_eval_batch_size=args.per_device_eval_batch_size,
-    #     report_to=([] if args.report_to == "none" else [args.report_to]),
-    #     bf16=args.bf16,
-    #     fp16=(not args.bf16),
-    #     gradient_checkpointing=args.gradient_checkpointing,
-    #     remove_unused_columns=False,
-    #     torch_compile=False,
-    #     save_on_each_node=True,
-    #     logging_first_step=True,
-    #     **ta_kwargs,
-    # )
-
-
     ta_sig = inspect.signature(TrainingArguments.__init__).parameters
     ta_kwargs2 = dict(
         output_dir=args.output_dir,
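
For reference, a minimal sketch of how the pieces changed above fit together in plain Transformers code: bf16 plus FlashAttention-2 at model load, gradient checkpointing requested through `TrainingArguments`, and the ZeRO-3 JSON passed via `deepspeed=`. This is not the repo's actual code; the paths are placeholders, and `gradient_checkpointing_kwargs` assumes a reasonably recent transformers release.

```python
# Sketch only -- placeholder paths, values taken from ss-zero3.sh above.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/model",                          # placeholder path
    torch_dtype=torch.bfloat16,                # matches --bf16
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",   # requires the flash-attn package
)
model.config.use_cache = False                 # KV cache conflicts with gradient checkpointing

training_args = TrainingArguments(
    output_dir="/path/to/out",                 # placeholder path
    bf16=True,
    learning_rate=2e-5,
    weight_decay=0.1,
    warmup_ratio=0.02,
    max_steps=375,
    gradient_checkpointing=True,               # what --gradient_checkpointing turns on
    gradient_checkpointing_kwargs={"use_reentrant": False},  # assumes newer transformers
    deepspeed="/path/to/ds_config_zero3.json", # the ZeRO-3 config edited above
    report_to="none",
)

# Note: right after from_pretrained(), model.is_gradient_checkpointing is still False;
# the Trainer calls model.gradient_checkpointing_enable() only once training is set up,
# so a print placed at load time reports the pre-Trainer state.
```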