This commit is contained in:
hailin 2025-09-09 22:43:41 +08:00
parent d129562a43
commit 591b6ceb00
4 changed files with 25 additions and 15 deletions

View File

@ -3,7 +3,7 @@ WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
WANDB_PROJECT=ds-qwen3
WANDB_ENTITY=hailin
WANDB_GROUP=q3-32b-ds4-2025-09-05
WANDB_NAME=q3-32b-lr2e-5-train2
WANDB_NAME=q3-32b-lr2e-5-train3
WANDB_RESUME=allow
WANDB_INIT_TIMEOUT=300
WANDB_DIR=/tmp/$USER/wandb

View File

@ -1,6 +1,6 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"gradient_accumulation_steps": 4,
"bf16": { "enabled": true },
"fp16": { "enabled": false },
@ -9,13 +9,17 @@
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 500000000,
"stage3_prefetch_bucket_size": 200000000,
"allgather_partitions": true,
"reduce_scatter": true,
"round_robin_gradients": true,
"reduce_bucket_size": 150000000,
"stage3_prefetch_bucket_size": 100000000,
"stage3_param_persistence_threshold": 1000000,
"offload_optimizer": { "device": "none" },
"offload_param": { "device": "none" }
"offload_param": { "device": "none" }
},
"stage3_gather_16bit_weights_on_model_save": false,
"gradient_clipping": 1.0,
"wall_clock_breakdown": false
}

View File

@ -1,4 +1,4 @@
FORCE_COLOR=1 deepspeed --hostfile hostfile \
FORCE_COLOR=1 deepspeed --hostfile hostfile \
--num_nodes 6 --num_gpus 4 \
train_sft_lora.py \
--model_name_or_path /home/test/Qwen3-32B \
@ -6,12 +6,15 @@ FORCE_COLOR=1 deepspeed --hostfile hostfile \
--output_dir /home/test/checkpoints/q3-32b-lora \
--seq_len 1024 \
--bf16 \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 1 \
--learning_rate 1e-4 \
--gradient_accumulation_steps 4 \
--learning_rate 2e-4 \
--warmup_ratio 0.03 \
--lora_r 16 --lora_alpha 32 --lora_dropout 0.05 \
--max_steps 62 \
--lora_r 16 --lora_alpha 64 --lora_dropout 0.05 \
--lora_exclude lm_head \
--max_steps 3000 \
--log_interval 10 \
--eval_steps 50 \
--gradient_checkpointing \
--deepspeed /home/test/jd_train/ds_config_zero3_lora.json \
--deepspeed /home/test/jd_train/ds_config_zero3_lora_gpu.json \
--report_to wandb --wandb_project ds-qwen3-lora

View File

@ -201,7 +201,9 @@ class QwenChatSFTDataset(IterableDataset):
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
for ex in self.ex_iter:
it = self.ex_iter() if callable(self.ex_iter) else iter(self.ex_iter)
for ex in it:
# for ex in self.ex_iter:
msgs = ex.get("messages")
if not msgs or not isinstance(msgs, list):
continue
@ -657,7 +659,8 @@ def main():
assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
# train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
train_stream = QwenChatSFTDataset(ds_stream2, tokenizer, seq_len=args.seq_len)
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
@ -781,7 +784,7 @@ def main():
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
num_train_epochs = args.num_train_epochs,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,