jd_train/convertfp32_to_bf16.py

import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

SRC = "/home/test/checkpoints/q3-32b-ds4/merged-global_step62"        # FP32 merged checkpoint dir
DST = "/home/test/checkpoints/q3-32b-ds4/merged-global_step62-bf16"   # output dir for the bf16 copy

# 1) Load in bfloat16 (optimizer state is not loaded)
model = AutoModelForCausalLM.from_pretrained(
    SRC,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,  # stream weights in to keep host memory low
    device_map=None,         # do everything on the CPU
)
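
# Optional sanity check (a small sketch beyond the conversion itself): confirm that
# every parameter really came in as bf16 before saving.
non_bf16 = [n for n, p in model.named_parameters() if p.dtype is not torch.bfloat16]
if non_bf16:
    print("⚠️ parameters not in bf16:", non_bf16[:5])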
# 2) Optional: keeping some layers in FP32 (typically LayerNorm/Embedding) is more
#    numerically stable. If you want the checkpoint to match training exactly, skip
#    this step; if you prefer the extra stability, uncomment the block below:
# for name, module in model.named_modules():
#     if "norm" in name.lower():
#         module.to(dtype=None)  # no-op: keep it following the weight dtype (already bf16)
# transformers saves tensors in whatever dtype the weights hold; to really force FP32, call .float()
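# A concrete sketch of the "force FP32" variant mentioned above (leave it commented unless
# you want a mixed-dtype checkpoint; save_pretrained keeps each tensor's own dtype):
# for name, module in model.named_modules():
#     if "norm" in name.lower() or isinstance(module, torch.nn.Embedding):
#         module.float()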
# 3) Save as bf16 safetensors shards (split at 5 GB each)
model.save_pretrained(DST, safe_serialization=True, max_shard_size="5GB")
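
# Quick look at what was written: the bf16 shards plus the model.safetensors.index.json index.
import os
print(sorted(os.listdir(DST)))
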
# 4) Sync the tokenizer/config (in case the step above did not bring them along)
tok = AutoTokenizer.from_pretrained(SRC, use_fast=True)
tok.save_pretrained(DST)
cfg = AutoConfig.from_pretrained(SRC)
cfg.save_pretrained(DST)
print("✅ Saved to:", DST)
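
# Optional spot-check, assuming the `safetensors` package is installed (transformers'
# safe serialization depends on it): open the first shard and print a few tensor
# dtypes without loading the whole model back.
import glob
from safetensors import safe_open

shards = sorted(glob.glob(f"{DST}/*.safetensors"))
if shards:
    with safe_open(shards[0], framework="pt") as f:
        for key in list(f.keys())[:5]:  # a few tensors are enough for a spot check
            print(key, f.get_tensor(key).dtype)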