From 2d2e42c4dd2bb348da1858f0c13c179de3cc4426 Mon Sep 17 00:00:00 2001 From: hailin Date: Thu, 4 Sep 2025 16:42:02 +0800 Subject: [PATCH] . --- merge_zero3_safetensors.sh | 82 ++++++++++++++++++++++++++++++-------- 1 file changed, 65 insertions(+), 17 deletions(-) diff --git a/merge_zero3_safetensors.sh b/merge_zero3_safetensors.sh index 7844f60..f644666 100755 --- a/merge_zero3_safetensors.sh +++ b/merge_zero3_safetensors.sh @@ -2,29 +2,66 @@ set -euo pipefail # ===== 可调参数 ===== -CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 如果分片实际在 checkpoint-62/global_step62 下,就把这里改成 .../checkpoint-62 +CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 若实际是 .../checkpoint-62/global_step62,请把 CKPT_ROOT 改成 .../checkpoint-62 TAG="global_step62" HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06) OUT_DIR="${CKPT_ROOT}/merged-${TAG}" MAX_SHARD_SIZE="5GB" + +# 预检控制:总分片数(通常=总写出 rank 数,如 6 节点×4 GPU=24),最小每机分片数(按你布局调) +EXPECTED_TOTAL_SHARDS=24 +MIN_SHARDS_PER_HOST=1 +STRICT_PRECHECK=true # true: 若总分片≠期望则直接退出;false: 仅告警但继续 + SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8" RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace" # ==================== -echo "== 预检查 SSH 与(非必需)远端目录存在 ==" +export OUT_DIR # 让后面的 Python 自检拿得到 + +echo "== 预检查 SSH ==" for h in "${HOSTS[@]}"; do ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! 无法免密 SSH 到 $h"; exit 1; } - # 目录不存在也不致命,后面会跳过 done -echo "== 1/4 同步各节点的 ${TAG} 整个目录(带进度)==" +echo "== 0/4 逐节点分片预检(只统计 ${CKPT_ROOT}/${TAG} 下的 *model_states.pt 文件) ==" +total=0 +declare -A host_cnt +for h in "${HOSTS[@]}"; do + # 只认文件,不认目录;限制在 TAG 这一层 + c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0) + c=$(echo "$c" | tr -d ' ') + host_cnt["$h"]=$c + total=$(( total + c )) + printf " - %-8s: %s 分片\n" "$h" "$c" + if (( c < MIN_SHARDS_PER_HOST )); then + echo "!! 预警:$h 分片仅 $c 个 (< ${MIN_SHARDS_PER_HOST}),该节点可能未写出或路径不同" >&2 + fi +done +echo " - 汇总分片数(未同步前):$total" + +if [[ -n "${EXPECTED_TOTAL_SHARDS:-}" ]]; then + if (( total != EXPECTED_TOTAL_SHARDS )); then + echo "!! 分片总数($total) ≠ 期望(${EXPECTED_TOTAL_SHARDS}),很可能缺片或路径不一致" >&2 + if [[ "${STRICT_PRECHECK}" == "true" ]]; then + echo "!! STRICT_PRECHECK 开启:中止合并,请先排查缺片节点" >&2 + exit 2 + else + echo ">> 严格校验关闭:继续执行(可能在合并/加载时失败)" >&2 + fi + fi +fi + +echo "== 1/4 同步各节点的 ${TAG} 目录(带进度)==" mkdir -p "${CKPT_ROOT}/${TAG}" LOCAL_HOST="$(hostname -s || hostname)" for h in "${HOSTS[@]}"; do - [[ "$h" == "$LOCAL_HOST" ]] && { echo " - 跳过本机 $h"; continue; } + if [[ "$h" == "$LOCAL_HOST" ]]; then + echo " - 跳过本机 $h" + continue + fi if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then echo " - 从 $h 拉取 ${CKPT_ROOT}/${TAG}/" - # 不做 include/exclude 过滤,避免漏掉不同命名风格的分片文件 rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \ "${h}:${CKPT_ROOT}/${TAG}/" "${CKPT_ROOT}/${TAG}/" || true else @@ -32,21 +69,32 @@ for h in "${HOSTS[@]}"; do fi done -echo "== 2/4 校验是否有分片“文件”(不是目录)==" -# 兼容两种常见命名:mp_rank_*_model_states.pt 与 *mp_rank*model_states.pt(含 pp 维度) -CNT_A=$(ls -1 "${CKPT_ROOT}/${TAG}"/mp_rank_*_model_states.pt 2>/dev/null | wc -l | tr -d ' ' || true) -CNT_B=$(ls -1 "${CKPT_ROOT}/${TAG}"/*mp_rank*model_states.pt 2>/dev/null | wc -l | tr -d ' ' || true) -CNT=$(( CNT_A + CNT_B )) +echo "== 2/4 统计与校验分片文件(本机聚合后) ==" +# 只认文件,不认目录;用一次 find 去重,避免重复计数 +mapfile -t SHARDS < <(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u) +CNT=${#SHARDS[@]} echo " - 发现 model_states 分片文件数:${CNT}" if [[ "${CNT}" -eq 0 ]]; then - echo "!! 未检测到任何 *_model_states.pt;请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名,再调整匹配规则" >&2 + echo "!! 未检测到任何 *model_states.pt;请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名" >&2 exit 1 fi +# 简单健壮性检查:分片数至少不低于主机数(经验性检查) +if [[ "${CNT}" -lt "${#HOSTS[@]}" ]]; then + echo "!! 分片数(${CNT}) < 主机数(${#HOSTS[@]}),可能有节点没同步到分片,继续可能失败" >&2 +fi -echo "== 3/4 合并为 safetensors 输出到:${OUT_DIR} ==" +echo "== 3/4 合并为 safetensors 到:${OUT_DIR} ==" mkdir -p "${OUT_DIR}" + +# 先探测 zero_to_fp32.py 是否支持新参数;不支持就走 API +USE_Z2FP32_SCRIPT=false if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then - # 优先使用与该 checkpoint 同版本的官方脚本 + if python "${CKPT_ROOT}/zero_to_fp32.py" --help 2>&1 | grep -q -- "--safe_serialization"; then + USE_Z2FP32_SCRIPT=true + fi +fi + +if $USE_Z2FP32_SCRIPT; then python "${CKPT_ROOT}/zero_to_fp32.py" \ "${CKPT_ROOT}" \ "${OUT_DIR}" \ @@ -54,7 +102,6 @@ if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then --safe_serialization \ --max_shard_size "${MAX_SHARD_SIZE}" else - # 退回 DeepSpeed API python - <