From 591b6ceb00b11eff5d54035b703c37ecb744fc81 Mon Sep 17 00:00:00 2001
From: hailin
Date: Tue, 9 Sep 2025 22:43:41 +0800
Subject: [PATCH] .

---
 .deepspeed_env            |  2 +-
 ds_config_zero3_lora.json | 14 +++++++++-----
 train_mm_zero3_lora.sh    | 15 +++++++++------
 train_sft_lora.py         |  9 ++++++---
 4 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/.deepspeed_env b/.deepspeed_env
index 84b51fe..29b45df 100644
--- a/.deepspeed_env
+++ b/.deepspeed_env
@@ -3,7 +3,7 @@ WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
 WANDB_PROJECT=ds-qwen3
 WANDB_ENTITY=hailin
 WANDB_GROUP=q3-32b-ds4-2025-09-05
-WANDB_NAME=q3-32b-lr2e-5-train2
+WANDB_NAME=q3-32b-lr2e-5-train3
 WANDB_RESUME=allow
 WANDB_INIT_TIMEOUT=300
 WANDB_DIR=/tmp/$USER/wandb
diff --git a/ds_config_zero3_lora.json b/ds_config_zero3_lora.json
index 2757e03..0e53a51 100644
--- a/ds_config_zero3_lora.json
+++ b/ds_config_zero3_lora.json
@@ -1,6 +1,6 @@
 {
   "train_micro_batch_size_per_gpu": 1,
-  "gradient_accumulation_steps": 1,
+  "gradient_accumulation_steps": 4,
   "bf16": { "enabled": true },
   "fp16": { "enabled": false },
 
@@ -9,13 +9,17 @@
     "stage": 3,
     "overlap_comm": true,
     "contiguous_gradients": true,
-    "reduce_bucket_size": 500000000,
-    "stage3_prefetch_bucket_size": 200000000,
+    "allgather_partitions": true,
+    "reduce_scatter": true,
+    "round_robin_gradients": true,
+    "reduce_bucket_size": 150000000,
+    "stage3_prefetch_bucket_size": 100000000,
     "stage3_param_persistence_threshold": 1000000,
     "offload_optimizer": { "device": "none" },
-    "offload_param": { "device": "none" }
+    "offload_param": { "device": "none" }
   },
-
+
+  "stage3_gather_16bit_weights_on_model_save": false,
   "gradient_clipping": 1.0,
   "wall_clock_breakdown": false
 }
diff --git a/train_mm_zero3_lora.sh b/train_mm_zero3_lora.sh
index 2770e0c..4775743 100755
--- a/train_mm_zero3_lora.sh
+++ b/train_mm_zero3_lora.sh
@@ -1,4 +1,4 @@
-FORCE_COLOR=1 deepspeed --hostfile hostfile \
+FORCE_COLOR=1 deepspeed --hostfile hostfile \
   --num_nodes 6 --num_gpus 4 \
   train_sft_lora.py \
   --model_name_or_path /home/test/Qwen3-32B \
@@ -6,12 +6,15 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
   --output_dir /home/test/checkpoints/q3-32b-lora \
   --seq_len 1024 \
   --bf16 \
-  --gradient_accumulation_steps 1 \
   --per_device_train_batch_size 1 \
-  --learning_rate 1e-4 \
+  --gradient_accumulation_steps 4 \
+  --learning_rate 2e-4 \
   --warmup_ratio 0.03 \
-  --lora_r 16 --lora_alpha 32 --lora_dropout 0.05 \
-  --max_steps 62 \
+  --lora_r 16 --lora_alpha 64 --lora_dropout 0.05 \
+  --lora_exclude lm_head \
+  --max_steps 3000 \
+  --log_interval 10 \
+  --eval_steps 50 \
   --gradient_checkpointing \
-  --deepspeed /home/test/jd_train/ds_config_zero3_lora.json \
+  --deepspeed /home/test/jd_train/ds_config_zero3_lora_gpu.json \
   --report_to wandb --wandb_project ds-qwen3-lora
diff --git a/train_sft_lora.py b/train_sft_lora.py
index 3232513..528d94f 100644
--- a/train_sft_lora.py
+++ b/train_sft_lora.py
@@ -201,7 +201,9 @@ class QwenChatSFTDataset(IterableDataset):
         rank = int(os.environ.get("RANK", "0"))
         lrank = int(os.environ.get("LOCAL_RANK", "-1"))
 
-        for ex in self.ex_iter:
+        it = self.ex_iter() if callable(self.ex_iter) else iter(self.ex_iter)
+        for ex in it:
+        # for ex in self.ex_iter:
             msgs = ex.get("messages")
             if not msgs or not isinstance(msgs, list):
                 continue
@@ -657,7 +659,8 @@ def main():
     assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"
     ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000,
         seed=args.seed)
-    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+    # train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+    train_stream = QwenChatSFTDataset(ds_stream2, tokenizer, seq_len=args.seq_len)
 
     ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
     probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
@@ -781,7 +784,7 @@ def main():
         learning_rate=args.learning_rate,
         weight_decay=args.weight_decay,
         warmup_ratio=args.warmup_ratio,
-        num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
+        num_train_epochs = args.num_train_epochs,
         max_steps=args.max_steps if args.max_steps > 0 else -1,
         lr_scheduler_type="cosine",
         logging_steps=args.log_interval,
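
Note on the train_sft_lora.py hunks above: the dataset's __iter__ now normalizes its input, accepting either a ready-made iterable (the streaming HF dataset passed directly at the patched call site) or a zero-argument callable that returns one, instead of a single-use generator expression. Below is a minimal, self-contained sketch of that pattern, not part of the patch; the class name ExampleStream is hypothetical and omits the tokenization, label masking, and packing the real QwenChatSFTDataset performs.

# Illustrative sketch of the iterator-normalization pattern (assumes PyTorch).
from torch.utils.data import IterableDataset


class ExampleStream(IterableDataset):
    def __init__(self, ex_iter):
        # ex_iter: an iterable of examples, or a zero-arg callable returning one
        self.ex_iter = ex_iter

    def __iter__(self):
        # Same normalization as the patched loop: call it if callable,
        # otherwise take an iterator over it lazily at iteration time.
        it = self.ex_iter() if callable(self.ex_iter) else iter(self.ex_iter)
        for ex in it:
            msgs = ex.get("messages")
            if not msgs or not isinstance(msgs, list):
                continue  # skip malformed records, as the original loop does
            yield msgs


# Either form works:
#   ExampleStream(streaming_dataset)               # iterable, as at the patched call site
#   ExampleStream(lambda: iter(list_of_examples))  # callable, rebuilt on each __iter__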