This commit is contained in:
parent
890506c505
commit
36044de8df
|
|
@@ -11,10 +11,14 @@
|
|||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"stage3_prefetch_bucket_size": 2e8,
|
||||
"stage3_param_persistence_threshold": 1e6,
|
||||
"stage3_gather_16bit_weights_on_model_save": false
|
||||
"reduce_bucket_size": 50000000,
|
||||
"stage3_prefetch_bucket_size": 50000000,
|
||||
"stage3_param_persistence_threshold": 0,
|
||||
"stage3_gather_16bit_weights_on_model_save": false,
|
||||
|
||||
"offload_param": { "device": "cpu", "pin_memory": true },
|
||||
"offload_optimizer": { "device": "cpu", "pin_memory": true }
|
||||
|
||||
},
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -1,3 +1,5 @@
|
|||
export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.9"
|
||||
|
||||
deepspeed --hostfile hostfile \
|
||||
--num_nodes 6 --num_gpus 4 \
|
||||
/home/test/jd_train/train_sft_ds.py \
|
||||
|
|
|
|||
|
|
@@ -11,6 +11,8 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset, Dataset
|
||||
from contextlib import nullcontext
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
|
|
@@ -482,13 +484,38 @@ def main():
|
|||
(torch.float16 if torch.cuda.is_available() else torch.float32))
|
||||
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
torch_dtype=dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import deepspeed
|
||||
zero_init_ctx = deepspeed.zero.Init(
|
||||
remote_device="cpu", # 参数初始驻留 CPU,安全
|
||||
pin_memory=True,
|
||||
dtype=dtype # 和你上面的 dtype 一致(bf16)
|
||||
)
|
||||
except Exception:
|
||||
zero_init_ctx = nullcontext() # 没装 DS 时也能单机跑
|
||||
|
||||
with zero_init_ctx:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
torch_dtype=dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# model = AutoModelForCausalLM.from_pretrained(
|
||||
# args.model_name_or_path,
|
||||
# torch_dtype=dtype,
|
||||
# low_cpu_mem_usage=True,
|
||||
# trust_remote_code=True,
|
||||
# attn_implementation="sdpa"
|
||||
# )
|
||||
|
||||
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
|
||||
dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
|
||||
|
|
|
|||
Loading…
Reference in New Issue