This commit is contained in:
hailin 2025-09-04 11:50:57 +08:00
parent bd9a294e42
commit edcc0b7b09
1 changed files with 87 additions and 0 deletions

View File

@ -0,0 +1,87 @@
#!/usr/bin/env bash
# merge_zero3_safetensors.sh
set -euo pipefail
# ======= 你可以改的变量 =======
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 你的checkpoint根目录
TAG="global_step62" # 要合并的tag目录名
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06) # 参与训练的节点列表
OUT_DIR="${CKPT_ROOT}/merged-${TAG}" # 输出目录
MAX_SHARD_SIZE="5GB" # safetensors每片大小
# =================================
echo "==> 1/4 同步各节点的分片到本机: ${CKPT_ROOT}"
mkdir -p "${CKPT_ROOT}"
LOCAL_HOST="$(hostname -s || hostname)"
for h in "${HOSTS[@]}"; do
if [[ "${h}" != "${LOCAL_HOST}" ]]; then
echo " - rsync from ${h}:${CKPT_ROOT}/ -> ${CKPT_ROOT}/"
rsync -a --delete --inplace --partial "${h}:${CKPT_ROOT}/" "${CKPT_ROOT}/"
else
echo " - 跳过本机 ${h}"
fi
done
echo "==> 2/4 基本校验"
if [[ ! -d "${CKPT_ROOT}/${TAG}" ]]; then
echo "!! 未找到 ${CKPT_ROOT}/${TAG}请确认TAG与目录名一致" >&2
exit 1
fi
MP_CNT=$(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type d -name "mp_rank_*" | wc -l | tr -d ' ')
if [[ "${MP_CNT}" -eq 0 ]]; then
echo "!! ${CKPT_ROOT}/${TAG} 下未发现 mp_rank_* 分片目录" >&2
exit 1
fi
echo " - 分片目录数: ${MP_CNT}"
echo "==> 3/4 合并 ZeRO-3 分片为 safetensors 到: ${OUT_DIR}"
python - <<PY
import os
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
ckpt_dir = r"${CKPT_ROOT}"
out_dir = r"${OUT_DIR}"
tag = r"${TAG}"
os.makedirs(out_dir, exist_ok=True)
convert_zero_checkpoint_to_fp32_state_dict(
checkpoint_dir=ckpt_dir,
output_dir=out_dir,
tag=tag, # None=使用latest这里显式指定
safe_serialization=True, # 写出safetensors
max_shard_size="${MAX_SHARD_SIZE}",
# exclude_frozen_parameters=False # 需要时可开启
)
print("合并完成:", out_dir)
PY
echo "==> 3.1 拷贝 config / tokenizer 工件(如存在)"
pushd "${CKPT_ROOT}" >/dev/null
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
if [[ -f "$f" ]]; then
cp -n "$f" "${OUT_DIR}/"
fi
done
popd >/dev/null
echo "==> 4/4 运行快速加载自检仅CPU加载meta不占大内存"
python - <<'PY'
import os, json, sys
out_dir = os.environ.get("OUT_DIR")
from transformers import AutoConfig
try:
cfg = AutoConfig.from_pretrained(out_dir)
print("模型config", cfg.model_type, "hidden:", getattr(cfg,"hidden_size",None), "layers:", getattr(cfg,"num_hidden_layers",None))
except Exception as e:
print("读取config失败可忽略如无config.json", e, file=sys.stderr)
# 校验 safetensors 索引存在
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
with open(idx) as f:
j = json.load(f)
nfiles = len(j.get("weight_map", {}))
print(f"safetensors 索引存在:{idx} | 参数条目:{nfiles}")
else:
print("未发现 model.safetensors.index.json检查上一步是否成功", file=sys.stderr)
PY