diff --git a/ds_config_zero3.json b/ds_config_zero3.json index 86d9d3b..f128df6 100644 --- a/ds_config_zero3.json +++ b/ds_config_zero3.json @@ -7,9 +7,9 @@ "overlap_comm": true, "contiguous_gradients": true, - "reduce_bucket_size": 150000000, - "stage3_prefetch_bucket_size": 75000000, - "stage3_param_persistence_threshold": 1000000, + "reduce_bucket_size": 100000000, + "stage3_prefetch_bucket_size": 50000000, + "stage3_param_persistence_threshold": 0, "offload_optimizer": { "device": "none" }, "offload_param": { "device": "none" }, diff --git a/mm-zero3.sh b/mm-zero3.sh index 7c05e40..82c008e 100755 --- a/mm-zero3.sh +++ b/mm-zero3.sh @@ -4,6 +4,8 @@ export TORCH_EXTENSIONS_DIR=/tmp/$USER/torch_ext export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.9" +export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:64" + deepspeed --hostfile hostfile \ --num_nodes 6 --num_gpus 4 \ /home/test/jd_train/train_sft_ds.py \