From 36044de8dfa0276d7d4e365693148a16e260da44 Mon Sep 17 00:00:00 2001
From: hailin
Date: Fri, 29 Aug 2025 15:41:28 +0800
Subject: [PATCH] =?UTF-8?q?.=E2=80=9C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ds_config_zero3.json | 12 ++++++++----
 mm-zero3.sh          |  2 ++
 train_sft_ds.py      | 41 ++++++++++++++++++++++++++++++++++-------
 3 files changed, 44 insertions(+), 11 deletions(-)

diff --git a/ds_config_zero3.json b/ds_config_zero3.json
index 54d81fc..c725416 100644
--- a/ds_config_zero3.json
+++ b/ds_config_zero3.json
@@ -11,10 +11,14 @@
         "overlap_comm": true,
         "contiguous_gradients": true,
         "reduce_scatter": true,
-        "reduce_bucket_size": 2e8,
-        "stage3_prefetch_bucket_size": 2e8,
-        "stage3_param_persistence_threshold": 1e6,
-        "stage3_gather_16bit_weights_on_model_save": false
+        "reduce_bucket_size": 50000000,
+        "stage3_prefetch_bucket_size": 50000000,
+        "stage3_param_persistence_threshold": 0,
+        "stage3_gather_16bit_weights_on_model_save": false,
+
+        "offload_param": { "device": "cpu", "pin_memory": true },
+        "offload_optimizer": { "device": "cpu", "pin_memory": true }
+
     },
     "wall_clock_breakdown": false
 }
diff --git a/mm-zero3.sh b/mm-zero3.sh
index 77b1858..6319333 100755
--- a/mm-zero3.sh
+++ b/mm-zero3.sh
@@ -1,3 +1,5 @@
+export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.9"
+
 deepspeed --hostfile hostfile \
   --num_nodes 6 --num_gpus 4 \
   /home/test/jd_train/train_sft_ds.py \
diff --git a/train_sft_ds.py b/train_sft_ds.py
index e6bd4cb..9114270 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -11,6 +11,8 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset, Dataset
+from contextlib import nullcontext
+
 
 from datasets import load_dataset
 from transformers import (
@@ -482,13 +484,38 @@ def main():
             (torch.float16 if torch.cuda.is_available() else torch.float32))
 
 
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_name_or_path,
-        torch_dtype=dtype,
-        low_cpu_mem_usage=True,
-        trust_remote_code=True,
-        attn_implementation="sdpa"
-    )
+
+
+    try:
+        import deepspeed
+        zero_init_ctx = deepspeed.zero.Init(
+            remote_device="cpu",  # params initially reside on CPU; safe
+            pin_memory=True,
+            dtype=dtype           # consistent with the dtype chosen above (bf16)
+        )
+    except Exception:
+        zero_init_ctx = nullcontext()  # still runs single-node when DeepSpeed is not installed
+
+    with zero_init_ctx:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path,
+            torch_dtype=dtype,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            attn_implementation="sdpa"
+        )
+
+
+
+
+
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     args.model_name_or_path,
+    #     torch_dtype=dtype,
+    #     low_cpu_mem_usage=True,
+    #     trust_remote_code=True,
+    #     attn_implementation="sdpa"
+    # )
 
     print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
     dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
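
Note on the ZeRO-3 init pattern introduced above: wrapping from_pretrained()
in deepspeed.zero.Init() shards parameters across ranks as each submodule is
constructed, and remote_device="cpu" with pin_memory=True parks the shards in
pinned host RAM (matching the offload_param block added to
ds_config_zero3.json), so no rank ever materializes the full model on a single
GPU. A minimal standalone sketch of the same fallback pattern; the "gpt2"
checkpoint and the bf16/fp32 choice are placeholders, not from the patch, and
the DeepSpeed branch assumes a deepspeed-launched distributed run:

    from contextlib import nullcontext

    import torch
    from transformers import AutoModelForCausalLM

    # Placeholder dtype selection for illustration.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    try:
        import deepspeed
        # Shard parameters at construction time and keep them in pinned CPU RAM.
        init_ctx = deepspeed.zero.Init(remote_device="cpu", pin_memory=True, dtype=dtype)
    except Exception:
        init_ctx = nullcontext()  # no DeepSpeed: ordinary single-process load

    with init_ctx:
        model = AutoModelForCausalLM.from_pretrained(
            "gpt2",  # placeholder checkpoint, not from the patch
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )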