This commit is contained in:
parent
36044de8df
commit
8d81af86f4
|
|
@ -489,9 +489,11 @@ def main():
|
|||
try:
|
||||
import deepspeed
|
||||
zero_init_ctx = deepspeed.zero.Init(
|
||||
remote_device="cpu", # 参数初始驻留 CPU,安全
|
||||
remote_device="cpu", # 参数最终托管在 CPU(可结合 offload)
|
||||
device="cpu", # ← 关键:不要用 meta
|
||||
pin_memory=True,
|
||||
dtype=dtype # 和你上面的 dtype 一致(bf16)
|
||||
dtype=dtype,
|
||||
config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
|
||||
)
|
||||
except Exception:
|
||||
zero_init_ctx = nullcontext() # 没装 DS 时也能单机跑
|
||||
|
|
@ -500,7 +502,7 @@ def main():
|
|||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
torch_dtype=dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
low_cpu_mem_usage=False,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue