From f115389dd0e896bccf51f7d379846814250a4eed Mon Sep 17 00:00:00 2001 From: hailin Date: Thu, 4 Sep 2025 12:00:27 +0800 Subject: [PATCH] . --- merge_zero3_safetensors.sh | 96 ++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/merge_zero3_safetensors.sh b/merge_zero3_safetensors.sh index fc99a23..3eaaa37 100755 --- a/merge_zero3_safetensors.sh +++ b/merge_zero3_safetensors.sh @@ -1,87 +1,91 @@ #!/usr/bin/env bash -# merge_zero3_safetensors.sh set -euo pipefail -# ======= 你可以改的变量 ======= -CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 你的checkpoint根目录 -TAG="global_step62" # 要合并的tag(目录名) -HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06) # 参与训练的节点列表 -OUT_DIR="${CKPT_ROOT}/merged-${TAG}" # 输出目录 -MAX_SHARD_SIZE="5GB" # safetensors每片大小 -# ================================= +# ===== 可调参数 ===== +CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" +TAG="global_step62" +HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06) +OUT_DIR="${CKPT_ROOT}/merged-${TAG}" +MAX_SHARD_SIZE="5GB" +SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8" +RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace" +# ==================== -echo "==> 1/4 同步各节点的分片到本机: ${CKPT_ROOT}" -mkdir -p "${CKPT_ROOT}" -LOCAL_HOST="$(hostname -s || hostname)" +echo "== 预检查 SSH 与远端目录 ==" for h in "${HOSTS[@]}"; do - if [[ "${h}" != "${LOCAL_HOST}" ]]; then - echo " - rsync from ${h}:${CKPT_ROOT}/ -> ${CKPT_ROOT}/" - rsync -a --delete --inplace --partial "${h}:${CKPT_ROOT}/" "${CKPT_ROOT}/" - else - echo " - 跳过本机 ${h}" + if ! ssh ${SSH_OPTS} "$h" "true" >/dev/null 2>&1; then + echo "!! 无法免密 SSH 到 $h(检查 ~/.ssh/config/authorized_keys/防火墙)" >&2 + exit 1 + fi + if ! ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then + echo "!! $h 上缺少目录 ${CKPT_ROOT}/${TAG},确认训练是否在该机产生了分片" >&2 fi done -echo "==> 2/4 基本校验" +echo "== 1/4 开始按节点同步分片(仅 ${TAG},带进度)==" +mkdir -p "${CKPT_ROOT}" +LOCAL_HOST="$(hostname -s || hostname)" +for h in "${HOSTS[@]}"; do + [[ "$h" == "$LOCAL_HOST" ]] && { echo " - 跳过本机 $h"; continue; } + echo " - 从 $h 拉取 ${CKPT_ROOT}/${TAG}/mp_rank_*/" + # 只拉取该 step 下的 mp_rank_* 目录,避免无关文件 + rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \ + --include="${TAG}/" --include="${TAG}/mp_rank_*/" --include="${TAG}/mp_rank_*/**" \ + --exclude="*" \ + "${h}:${CKPT_ROOT}/" "${CKPT_ROOT}/" +done + +echo "== 2/4 校验是否凑齐分片目录 ==" if [[ ! -d "${CKPT_ROOT}/${TAG}" ]]; then - echo "!! 未找到 ${CKPT_ROOT}/${TAG},请确认TAG与目录名一致" >&2 - exit 1 + echo "!! 未发现 ${CKPT_ROOT}/${TAG}" >&2; exit 1 fi MP_CNT=$(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type d -name "mp_rank_*" | wc -l | tr -d ' ') +echo " - 已发现 mp_rank_* 目录数:${MP_CNT}" if [[ "${MP_CNT}" -eq 0 ]]; then - echo "!! ${CKPT_ROOT}/${TAG} 下未发现 mp_rank_* 分片目录" >&2 - exit 1 + echo "!! 没有任何 mp_rank_* 分片,请检查同步" >&2; exit 1 fi -echo " - 分片目录数: ${MP_CNT}" -echo "==> 3/4 合并 ZeRO-3 分片为 safetensors 到: ${OUT_DIR}" +echo "== 3/4 合并为 safetensors 输出到:${OUT_DIR} ==" python - < 3.1 拷贝 config / tokenizer 工件(如存在)" +echo "== 3.1 拷贝 config/tokenizer 工件(如存在)==" pushd "${CKPT_ROOT}" >/dev/null for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do - if [[ -f "$f" ]]; then - cp -n "$f" "${OUT_DIR}/" - fi + [[ -f "$f" ]] && cp -n "$f" "${OUT_DIR}/" done popd >/dev/null -echo "==> 4/4 运行快速加载自检(仅CPU加载meta,不占大内存)" +echo "== 4/4 自检(索引与config)==" python - <<'PY' import os, json, sys out_dir = os.environ.get("OUT_DIR") -from transformers import AutoConfig -try: - cfg = AutoConfig.from_pretrained(out_dir) - print("模型config:", cfg.model_type, "hidden:", getattr(cfg,"hidden_size",None), "layers:", getattr(cfg,"num_hidden_layers",None)) -except Exception as e: - print("读取config失败(可忽略,如无config.json):", e, file=sys.stderr) - -# 校验 safetensors 索引存在 idx = os.path.join(out_dir, "model.safetensors.index.json") if os.path.exists(idx): - with open(idx) as f: - j = json.load(f) - nfiles = len(j.get("weight_map", {})) - print(f"safetensors 索引存在:{idx} | 参数条目:{nfiles}") + with open(idx) as f: j = json.load(f) + print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(j.get('weight_map', {}))})") else: - print("未发现 model.safetensors.index.json(检查上一步是否成功)", file=sys.stderr) + print("WARN: 未找到 model.safetensors.index.json", file=sys.stderr) +try: + from transformers import AutoConfig + cfg = AutoConfig.from_pretrained(out_dir) + print("OK: 读取到 config:", cfg.model_type, "hidden:", getattr(cfg,'hidden_size',None), "layers:", getattr(cfg,'num_hidden_layers',None)) +except Exception as e: + print("WARN: 读取 config 失败(若无 config.json 可忽略):", e, file=sys.stderr) PY + +echo "== 完成:${OUT_DIR} =="