This commit is contained in:
hailin 2025-08-29 15:41:28 +08:00
parent 890506c505
commit 36044de8df
3 changed files with 44 additions and 11 deletions

View File

@@ -11,10 +11,14 @@
"overlap_comm": true, "overlap_comm": true,
"contiguous_gradients": true, "contiguous_gradients": true,
"reduce_scatter": true, "reduce_scatter": true,
"reduce_bucket_size": 2e8, "reduce_bucket_size": 50000000,
"stage3_prefetch_bucket_size": 2e8, "stage3_prefetch_bucket_size": 50000000,
"stage3_param_persistence_threshold": 1e6, "stage3_param_persistence_threshold": 0,
"stage3_gather_16bit_weights_on_model_save": false "stage3_gather_16bit_weights_on_model_save": false,
"offload_param": { "device": "cpu", "pin_memory": true },
"offload_optimizer": { "device": "cpu", "pin_memory": true }
}, },
"wall_clock_breakdown": false "wall_clock_breakdown": false
} }

View File

@@ -1,3 +1,5 @@
export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.9"
deepspeed --hostfile hostfile \ deepspeed --hostfile hostfile \
--num_nodes 6 --num_gpus 4 \ --num_nodes 6 --num_gpus 4 \
/home/test/jd_train/train_sft_ds.py \ /home/test/jd_train/train_sft_ds.py \

View File

@@ -11,6 +11,8 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset from torch.utils.data import IterableDataset, Dataset
from contextlib import nullcontext
from datasets import load_dataset from datasets import load_dataset
from transformers import ( from transformers import (
@@ -482,13 +484,38 @@ def main():
(torch.float16 if torch.cuda.is_available() else torch.float32)) (torch.float16 if torch.cuda.is_available() else torch.float32))
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=dtype, try:
low_cpu_mem_usage=True, import deepspeed
trust_remote_code=True, zero_init_ctx = deepspeed.zero.Init(
attn_implementation="sdpa" remote_device="cpu", # 参数初始驻留 CPU安全
) pin_memory=True,
dtype=dtype # 和你上面的 dtype 一致bf16
)
except Exception:
zero_init_ctx = nullcontext() # 没装 DS 时也能单机跑
with zero_init_ctx:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
attn_implementation="sdpa"
)
# model = AutoModelForCausalLM.from_pretrained(
# args.model_name_or_path,
# torch_dtype=dtype,
# low_cpu_mem_usage=True,
# trust_remote_code=True,
# attn_implementation="sdpa"
# )
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True) print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
dbg(f"model loaded: dtype={next(model.parameters()).dtype} " dbg(f"model loaded: dtype={next(model.parameters()).dtype} "