hailin 2025-08-29 15:41:28 +08:00
parent 890506c505
commit 36044de8df
3 changed files with 44 additions and 11 deletions

View File (DeepSpeed ZeRO-3 config, JSON)

@@ -11,10 +11,14 @@
     "overlap_comm": true,
     "contiguous_gradients": true,
     "reduce_scatter": true,
-    "reduce_bucket_size": 2e8,
-    "stage3_prefetch_bucket_size": 2e8,
-    "stage3_param_persistence_threshold": 1e6,
-    "stage3_gather_16bit_weights_on_model_save": false
+    "reduce_bucket_size": 50000000,
+    "stage3_prefetch_bucket_size": 50000000,
+    "stage3_param_persistence_threshold": 0,
+    "stage3_gather_16bit_weights_on_model_save": false,
+    "offload_param": { "device": "cpu", "pin_memory": true },
+    "offload_optimizer": { "device": "cpu", "pin_memory": true }
   },
   "wall_clock_breakdown": false
 }

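Net effect of the config change: both ZeRO-3 communication buckets drop from 2e8 to 5e7 elements, stage3_param_persistence_threshold: 0 means no parameter stays resident on GPU between uses, and parameter and optimizer partitions are offloaded to pinned CPU memory. DeepSpeed bucket sizes are element counts rather than bytes, so in bf16 (2 bytes per element) a rough sketch of the per-bucket buffer savings:

# Rough arithmetic for the bucket change (a sketch; DeepSpeed bucket sizes
# are element counts, and bf16 uses 2 bytes per element).
BYTES_PER_ELEM = 2

def bucket_mib(num_elements):
    return num_elements * BYTES_PER_ELEM / 2**20

for key, old, new in [
    ("reduce_bucket_size", 200_000_000, 50_000_000),
    ("stage3_prefetch_bucket_size", 200_000_000, 50_000_000),
]:
    print(f"{key}: {bucket_mib(old):.0f} MiB -> {bucket_mib(new):.0f} MiB")
# each bucket: ~381 MiB -> ~95 MiB of GPU-resident buffer, per GPU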
View File (launch script)

@@ -1,3 +1,5 @@
+export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.9"
 deepspeed --hostfile hostfile \
   --num_nodes 6 --num_gpus 4 \
   /home/test/jd_train/train_sft_ds.py \

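All three knobs in PYTORCH_CUDA_ALLOC_CONF target allocator fragmentation: max_split_size_mb:128 stops the caching allocator from splitting blocks above 128 MB, expandable_segments:True lets CUDA memory segments grow in place instead of allocating new ones, and garbage_collection_threshold:0.9 reclaims unused cached blocks once usage crosses ~90%. The variable has to be in the environment before CUDA initializes, which is why it is exported in the launch script; a minimal standalone self-check sketch:

import os

# Must be set before torch initializes CUDA; the caching allocator reads it once.
os.environ.setdefault(
    "PYTORCH_CUDA_ALLOC_CONF",
    "max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.9",
)

import torch

if torch.cuda.is_available():
    torch.ones(1, device="cuda")  # force allocator initialization
    # num_alloc_retries counts cudaMalloc retries after cache flushes --
    # a rough indicator of the fragmentation these settings target.
    print("alloc retries:", torch.cuda.memory_stats().get("num_alloc_retries", 0))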
View File (/home/test/jd_train/train_sft_ds.py, training script)

@@ -11,6 +11,8 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset, Dataset
+from contextlib import nullcontext
 from datasets import load_dataset
 from transformers import (
@@ -482,6 +484,19 @@ def main():
              (torch.float16 if torch.cuda.is_available() else torch.float32))
+    try:
+        import deepspeed
+        zero_init_ctx = deepspeed.zero.Init(
+            remote_device="cpu",  # parameters initially reside on CPU (safe)
+            pin_memory=True,
+            dtype=dtype           # consistent with the dtype chosen above (bf16)
+        )
+    except Exception:
+        zero_init_ctx = nullcontext()  # still runs single-node when DeepSpeed isn't installed
+    with zero_init_ctx:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path,
+            torch_dtype=dtype,
@@ -490,6 +505,18 @@
             attn_implementation="sdpa"
         )
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     args.model_name_or_path,
+    #     torch_dtype=dtype,
+    #     low_cpu_mem_usage=True,
+    #     trust_remote_code=True,
+    #     attn_implementation="sdpa"
+    # )
+    print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
     dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
         f"use_cache={getattr(model.config,'use_cache',None)} "