#!/usr/bin/env bash
# merge_zero3_safetensors.sh
#
# Gathers the DeepSpeed ZeRO-3 checkpoint shards written by every training
# node onto this machine, merges them into sharded safetensors weights,
# copies config/tokenizer artifacts next to the merged weights, and runs a
# lightweight sanity check on the result.
#
# Requirements: passwordless ssh/rsync to every host in HOSTS, and a
# `python` on PATH that can import deepspeed (transformers is optional and
# only used by the final sanity check).
set -euo pipefail

# ======= Variables you may edit =======
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4"   # checkpoint root directory
TAG="global_step62"                             # tag (directory name) to merge
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)           # nodes that took part in training
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"            # output directory
MAX_SHARD_SIZE="5GB"                            # size of each safetensors shard
# ======================================

# The Python heredocs below use a quoted delimiter (<<'PY'), so the shell does
# not interpolate into them; they read their settings from the environment.
# Without this export, os.environ has no OUT_DIR and the checks would crash.
export CKPT_ROOT TAG OUT_DIR MAX_SHARD_SIZE

echo "==> 1/4 同步各节点的分片到本机: ${CKPT_ROOT}"
mkdir -p "${CKPT_ROOT}"
LOCAL_HOST="$(hostname -s || hostname)"
for h in "${HOSTS[@]}"; do
  if [[ "${h}" != "${LOCAL_HOST}" ]]; then
    echo " - rsync from ${h}:${CKPT_ROOT}/ -> ${CKPT_ROOT}/"
    # Deliberately NO --delete: every host is merged into the same local
    # directory, so --delete would remove the local shards and the shards
    # already fetched from previous hosts (anything the current host lacks).
    rsync -a --inplace --partial "${h}:${CKPT_ROOT}/" "${CKPT_ROOT}/"
  else
    echo " - 跳过本机 ${h}"
  fi
done

echo "==> 2/4 基本校验"
if [[ ! -d "${CKPT_ROOT}/${TAG}" ]]; then
  echo "!! 未找到 ${CKPT_ROOT}/${TAG},请确认TAG与目录名一致" >&2
  exit 1
fi

# Accept either shard layout directly under the tag directory:
#   - mp_rank_* sub-directories (what the original check looked for), or
#   - flat *_optim_states.pt files, which is what DeepSpeed's zero_to_fp32
#     converter globs for when reading a ZeRO checkpoint.
MP_CNT="$(
  find "${CKPT_ROOT}/${TAG}" -maxdepth 1 \
    \( -type d -name 'mp_rank_*' -o -type f -name '*_optim_states.pt' \) \
    | wc -l | tr -d ' '
)"
if [[ "${MP_CNT}" -eq 0 ]]; then
  echo "!! ${CKPT_ROOT}/${TAG} 下未发现 mp_rank_* 分片目录或 *_optim_states.pt 分片文件" >&2
  exit 1
fi
echo " - 分片数: ${MP_CNT}"

echo "==> 3/4 合并 ZeRO-3 分片为 safetensors 到: ${OUT_DIR}"
mkdir -p "${OUT_DIR}"
# NOTE(review): the body of this step was lost in the copy of the script that
# was reviewed; it is reconstructed here on top of DeepSpeed's own converter.
# Requires a DeepSpeed release whose convert_zero_checkpoint_to_fp32_state_dict
# accepts output_dir / max_shard_size / safe_serialization -- confirm against
# the installed version.
python - <<'PY'
import os

from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    os.environ["CKPT_ROOT"],
    os.environ["OUT_DIR"],
    max_shard_size=os.environ["MAX_SHARD_SIZE"],
    safe_serialization=True,
    tag=os.environ["TAG"],
)
PY

echo "==> 3.1 拷贝 config / tokenizer 工件(如存在)"
for f in config.json generation_config.json tokenizer_config.json \
         tokenizer.json merges.txt vocab.json special_tokens_map.json \
         added_tokens.json; do
  if [[ -f "${CKPT_ROOT}/${f}" ]]; then
    # -n: never overwrite a file that is already present in OUT_DIR.
    cp -n -- "${CKPT_ROOT}/${f}" "${OUT_DIR}/"
  fi
done

echo "==> 4/4 运行快速加载自检(仅CPU加载meta,不占大内存)"
python - <<'PY'
import json
import os
import sys

out_dir = os.environ["OUT_DIR"]

# Config check is best-effort: a missing config.json or a missing transformers
# install is reported on stderr but does not fail the script.
try:
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(out_dir)
    print("模型config:", cfg.model_type, "hidden:", getattr(cfg, "hidden_size", None), "layers:", getattr(cfg, "num_hidden_layers", None))
except Exception as e:
    print("读取config失败(可忽略,如无config.json):", e, file=sys.stderr)

# Verify that the safetensors index exists and report how many entries its
# weight_map holds.
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    with open(idx) as f:
        j = json.load(f)
    nfiles = len(j.get("weight_map", {}))
    print(f"safetensors 索引存在:{idx} | 参数条目:{nfiles}")
else:
    print("未发现 model.safetensors.index.json(检查上一步是否成功)", file=sys.stderr)
PY