This commit is contained in:
hailin 2025-09-09 22:43:41 +08:00
parent d129562a43
commit 591b6ceb00
4 changed files with 25 additions and 15 deletions

View File

@ -3,7 +3,7 @@ WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
WANDB_PROJECT=ds-qwen3 WANDB_PROJECT=ds-qwen3
WANDB_ENTITY=hailin WANDB_ENTITY=hailin
WANDB_GROUP=q3-32b-ds4-2025-09-05 WANDB_GROUP=q3-32b-ds4-2025-09-05
WANDB_NAME=q3-32b-lr2e-5-train2 WANDB_NAME=q3-32b-lr2e-5-train3
WANDB_RESUME=allow WANDB_RESUME=allow
WANDB_INIT_TIMEOUT=300 WANDB_INIT_TIMEOUT=300
WANDB_DIR=/tmp/$USER/wandb WANDB_DIR=/tmp/$USER/wandb

View File

@ -1,6 +1,6 @@
{ {
"train_micro_batch_size_per_gpu": 1, "train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 4,
"bf16": { "enabled": true }, "bf16": { "enabled": true },
"fp16": { "enabled": false }, "fp16": { "enabled": false },
@ -9,13 +9,17 @@
"stage": 3, "stage": 3,
"overlap_comm": true, "overlap_comm": true,
"contiguous_gradients": true, "contiguous_gradients": true,
"reduce_bucket_size": 500000000, "allgather_partitions": true,
"stage3_prefetch_bucket_size": 200000000, "reduce_scatter": true,
"round_robin_gradients": true,
"reduce_bucket_size": 150000000,
"stage3_prefetch_bucket_size": 100000000,
"stage3_param_persistence_threshold": 1000000, "stage3_param_persistence_threshold": 1000000,
"offload_optimizer": { "device": "none" }, "offload_optimizer": { "device": "none" },
"offload_param": { "device": "none" } "offload_param": { "device": "none" }
}, },
"stage3_gather_16bit_weights_on_model_save": false,
"gradient_clipping": 1.0, "gradient_clipping": 1.0,
"wall_clock_breakdown": false "wall_clock_breakdown": false
} }

View File

@ -1,4 +1,4 @@
FORCE_COLOR=1 deepspeed --hostfile hostfile \ FORCE_COLOR=1 deepspeed --hostfile hostfile \
--num_nodes 6 --num_gpus 4 \ --num_nodes 6 --num_gpus 4 \
train_sft_lora.py \ train_sft_lora.py \
--model_name_or_path /home/test/Qwen3-32B \ --model_name_or_path /home/test/Qwen3-32B \
@ -6,12 +6,15 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
--output_dir /home/test/checkpoints/q3-32b-lora \ --output_dir /home/test/checkpoints/q3-32b-lora \
--seq_len 1024 \ --seq_len 1024 \
--bf16 \ --bf16 \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 1 \ --per_device_train_batch_size 1 \
--learning_rate 1e-4 \ --gradient_accumulation_steps 4 \
--learning_rate 2e-4 \
--warmup_ratio 0.03 \ --warmup_ratio 0.03 \
--lora_r 16 --lora_alpha 32 --lora_dropout 0.05 \ --lora_r 16 --lora_alpha 64 --lora_dropout 0.05 \
--max_steps 62 \ --lora_exclude lm_head \
--max_steps 3000 \
--log_interval 10 \
--eval_steps 50 \
--gradient_checkpointing \ --gradient_checkpointing \
--deepspeed /home/test/jd_train/ds_config_zero3_lora.json \ --deepspeed /home/test/jd_train/ds_config_zero3_lora_gpu.json \
--report_to wandb --wandb_project ds-qwen3-lora --report_to wandb --wandb_project ds-qwen3-lora

View File

@ -201,7 +201,9 @@ class QwenChatSFTDataset(IterableDataset):
rank = int(os.environ.get("RANK", "0")) rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1")) lrank = int(os.environ.get("LOCAL_RANK", "-1"))
for ex in self.ex_iter: it = self.ex_iter() if callable(self.ex_iter) else iter(self.ex_iter)
for ex in it:
# for ex in self.ex_iter:
msgs = ex.get("messages") msgs = ex.get("messages")
if not msgs or not isinstance(msgs, list): if not msgs or not isinstance(msgs, list):
continue continue
@ -657,7 +659,8 @@ def main():
assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100" assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed) ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) # train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
train_stream = QwenChatSFTDataset(ds_stream2, tokenizer, seq_len=args.seq_len)
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len) probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
@ -781,7 +784,7 @@ def main():
learning_rate=args.learning_rate, learning_rate=args.learning_rate,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio, warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0, num_train_epochs = args.num_train_epochs,
max_steps=args.max_steps if args.max_steps > 0 else -1, max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine", lr_scheduler_type="cosine",
logging_steps=args.log_interval, logging_steps=args.log_interval,