Patch summary (free text before the first "diff --git" line is ignored by git apply / patch):
  * Pass the DeepSpeed config to deepspeed.zero.Init via config_dict_or_path when args.deepspeed
    points to an existing file, and keep remote_device="cpu".
  * Do NOT pass device="cpu": zero.Init has no such parameter, so the call raised TypeError, the
    broad except swallowed it, and ZeRO init was silently replaced by nullcontext().
  * Keep the best-effort fallback to nullcontext(), but print the reason for the fallback.
  * Load the checkpoint with low_cpu_mem_usage=False.
NOTE(review): this patch was recovered from a copy whose newlines and indentation had been
collapsed. Leading whitespace and the spacing before trailing comments on context/removed lines
are reconstructed; verify against train_sft_ds.py or apply with `git apply --ignore-whitespace`.
The "index" line is carried over from the original patch and no longer matches the new post-image.

diff --git a/train_sft_ds.py b/train_sft_ds.py
index 9114270..ee6a30c 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -489,9 +489,13 @@ def main():
     try:
         import deepspeed
         zero_init_ctx = deepspeed.zero.Init(
-            remote_device="cpu", # 参数初始驻留 CPU,安全
+            remote_device="cpu",  # parameters are hosted on CPU (can be combined with offload)
             pin_memory=True,
-            dtype=dtype # 和你上面的 dtype 一致(bf16)
+            dtype=dtype,
+            config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
         )
-    except Exception:
-        zero_init_ctx = nullcontext() # 没装 DS 时也能单机跑
+    except Exception as exc:
+        # Best-effort fallback (e.g. DeepSpeed is not installed) so a plain single-machine run
+        # still works, but report why instead of silently skipping ZeRO init.
+        print(f"[warn] deepspeed.zero.Init unavailable ({exc!r}); using regular model init", flush=True)
+        zero_init_ctx = nullcontext()
@@ -500,7 +504,7 @@ def main():
         model = AutoModelForCausalLM.from_pretrained(
             args.model_name_or_path,
             torch_dtype=dtype,
-            low_cpu_mem_usage=True,
+            low_cpu_mem_usage=False,
             trust_remote_code=True,
             attn_implementation="sdpa"
         )