jd_train/merge_zero3_safetensors.sh

#!/usr/bin/env bash
# merge_zero3_safetensors.sh
set -euo pipefail
# ======= Variables you may change =======
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4"   # root directory of your checkpoints
TAG="global_step62"                             # checkpoint tag (directory name) to merge
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)           # nodes that took part in training
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"            # output directory
MAX_SHARD_SIZE="5GB"                            # size of each safetensors shard
# ========================================
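# For reference, a ZeRO-3 checkpoint root typically looks roughly like this
# (a sketch only; exact names depend on the DeepSpeed/Megatron version and world size):
#
#   ${CKPT_ROOT}/
#     latest            # plain-text file holding the newest tag, e.g. "global_step62"
#     zero_to_fp32.py   # helper script DeepSpeed saves next to the checkpoint
#     ${TAG}/           # per-rank model/optimizer shards, e.g. mp_rank_* entries
#
# Each training node holds only its own ranks' shards, which is why step 1/4
# gathers them onto this machine first.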
echo "==> 1/4 同步各节点的分片到本机: ${CKPT_ROOT}"
mkdir -p "${CKPT_ROOT}"
LOCAL_HOST="$(hostname -s || hostname)"
for h in "${HOSTS[@]}"; do
if [[ "${h}" != "${LOCAL_HOST}" ]]; then
echo " - rsync from ${h}:${CKPT_ROOT}/ -> ${CKPT_ROOT}/"
rsync -a --delete --inplace --partial "${h}:${CKPT_ROOT}/" "${CKPT_ROOT}/"
else
echo " - 跳过本机 ${h}"
fi
done
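# The rsync step assumes passwordless SSH from this machine to every host in
# HOSTS. An optional pre-flight check (sketch, not part of the original flow):
#   for h in "${HOSTS[@]}"; do ssh -o BatchMode=yes "$h" true || echo "!! cannot reach $h"; done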
echo "==> 2/4 基本校验"
if [[ ! -d "${CKPT_ROOT}/${TAG}" ]]; then
echo "!! 未找到 ${CKPT_ROOT}/${TAG}请确认TAG与目录名一致" >&2
exit 1
fi
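# If you are not sure which tag to use: DeepSpeed records the most recent tag in
# the plain-text "latest" file at the checkpoint root, so (sketch) you could
# derive it instead of hardcoding:
#   TAG="$(cat "${CKPT_ROOT}/latest")"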
MP_CNT=$(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type d -name "mp_rank_*" | wc -l | tr -d ' ')
if [[ "${MP_CNT}" -eq 0 ]]; then
  echo "!! no mp_rank_* shard directories found under ${CKPT_ROOT}/${TAG}" >&2
  exit 1
fi
echo " - shard directories: ${MP_CNT}"
echo "==> 3/4 合并 ZeRO-3 分片为 safetensors 到: ${OUT_DIR}"
python - <<PY
import os
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
ckpt_dir = r"${CKPT_ROOT}"
out_dir = r"${OUT_DIR}"
tag = r"${TAG}"
os.makedirs(out_dir, exist_ok=True)
convert_zero_checkpoint_to_fp32_state_dict(
checkpoint_dir=ckpt_dir,
output_dir=out_dir,
tag=tag, # None=使用latest这里显式指定
safe_serialization=True, # 写出safetensors
max_shard_size="${MAX_SHARD_SIZE}",
# exclude_frozen_parameters=False # 需要时可开启
)
print("合并完成:", out_dir)
PY
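# Alternative (sketch): recent DeepSpeed releases drop a zero_to_fp32.py script
# next to the checkpoint; on those versions roughly the same merge can be run
# from the shell:
#   python "${CKPT_ROOT}/zero_to_fp32.py" "${CKPT_ROOT}" "${OUT_DIR}" \
#     -t "${TAG}" --safe_serialization --max_shard_size "${MAX_SHARD_SIZE}"
# Flag names vary across versions, so check --help first.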
echo "==> 3.1 拷贝 config / tokenizer 工件(如存在)"
pushd "${CKPT_ROOT}" >/dev/null
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
if [[ -f "$f" ]]; then
cp -n "$f" "${OUT_DIR}/"
fi
done
popd >/dev/null
echo "==> 4/4 运行快速加载自检仅CPU加载meta不占大内存"
python - <<'PY'
import os, json, sys
out_dir = os.environ.get("OUT_DIR")
from transformers import AutoConfig
try:
cfg = AutoConfig.from_pretrained(out_dir)
print("模型config", cfg.model_type, "hidden:", getattr(cfg,"hidden_size",None), "layers:", getattr(cfg,"num_hidden_layers",None))
except Exception as e:
print("读取config失败可忽略如无config.json", e, file=sys.stderr)
# 校验 safetensors 索引存在
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
with open(idx) as f:
j = json.load(f)
nfiles = len(j.get("weight_map", {}))
print(f"safetensors 索引存在:{idx} | 参数条目:{nfiles}")
else:
print("未发现 model.safetensors.index.json检查上一步是否成功", file=sys.stderr)
PY
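# Example follow-up (not part of this script): load the merged directory with
# transformers. A minimal sketch; dtype and device_map are assumptions, adjust
# them to your hardware (device_map="auto" needs accelerate installed).
#
#   python - <<'PY'
#   import torch
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   out_dir = "/home/test/checkpoints/q3-32b-ds4/merged-global_step62"
#   tok = AutoTokenizer.from_pretrained(out_dir)
#   model = AutoModelForCausalLM.from_pretrained(out_dir, torch_dtype=torch.bfloat16, device_map="auto")
#   print(model.config.model_type)
#   PY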