This commit is contained in:
parent
36044de8df
commit
8d81af86f4
|
|
@ -489,9 +489,11 @@ def main():
|
||||||
try:
|
try:
|
||||||
import deepspeed
|
import deepspeed
|
||||||
zero_init_ctx = deepspeed.zero.Init(
|
zero_init_ctx = deepspeed.zero.Init(
|
||||||
remote_device="cpu", # 参数初始驻留 CPU,安全
|
remote_device="cpu", # 参数最终托管在 CPU(可结合 offload)
|
||||||
|
device="cpu", # ← 关键:不要用 meta
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
dtype=dtype # 和你上面的 dtype 一致(bf16)
|
dtype=dtype,
|
||||||
|
config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
zero_init_ctx = nullcontext() # 没装 DS 时也能单机跑
|
zero_init_ctx = nullcontext() # 没装 DS 时也能单机跑
|
||||||
|
|
@ -500,7 +502,7 @@ def main():
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=False,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
attn_implementation="sdpa"
|
attn_implementation="sdpa"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue