This commit is contained in:
hailin 2025-09-04 16:49:02 +08:00
parent 2d2e42c4dd
commit a19d5b5a72
1 changed files with 59 additions and 56 deletions

View File

@ -5,88 +5,92 @@ set -euo pipefail
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 若实际是 .../checkpoint-62/global_step62请把 CKPT_ROOT 改成 .../checkpoint-62 CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 若实际是 .../checkpoint-62/global_step62请把 CKPT_ROOT 改成 .../checkpoint-62
TAG="global_step62" TAG="global_step62"
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06) HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
OUT_DIR="${CKPT_ROOT}/merged-${TAG}" AGGREGATOR_HOST="tn06" # 本脚本运行/汇总所在机器
EXPECTED_SHARDS_PER_HOST=4 # 每机应写出分片数(按你的并行布局)
MAX_SHARD_SIZE="5GB" MAX_SHARD_SIZE="5GB"
# 预检控制:总分片数(通常=总写出 rank 数,如 6 节点×4 GPU=24最小每机分片数按你布局调 STRICT_PRECHECK=true # true: 预检不通过就退出false: 仅告警
EXPECTED_TOTAL_SHARDS=24
MIN_SHARDS_PER_HOST=1
STRICT_PRECHECK=true # true: 若总分片≠期望则直接退出false: 仅告警但继续
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8" SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace" RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
# ==================== # ====================
export OUT_DIR # 让后面的 Python 自检拿得到 # ===== 派生参数(一般不用改) =====
EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} ))
STAGING_BASE="${CKPT_ROOT}/_staging"
STAGING_TAG_DIR="${STAGING_BASE}/${TAG}"
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
export OUT_DIR
# =================================
echo "== 预检查 SSH ==" echo "== 预检查 SSH =="
for h in "${HOSTS[@]}"; do for h in "${HOSTS[@]}"; do
ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! 无法免密 SSH 到 $h"; exit 1; } ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! 无法免密 SSH 到 $h"; exit 1; }
done done
echo "== 0/4 逐节点分片预检(只统计 ${CKPT_ROOT}/${TAG} 下的 *model_states.pt 文件 ==" echo "== 0/5 逐节点分片预检(统计各机 ${CKPT_ROOT}/${TAG} 下的 *model_states.pt=="
total=0 remote_total=0
declare -A host_cnt agg_cnt=0
for h in "${HOSTS[@]}"; do for h in "${HOSTS[@]}"; do
# 只认文件,不认目录;限制在 TAG 这一层
c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0) c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0)
c=$(echo "$c" | tr -d ' ') c=$(echo "$c" | tr -d ' ')
host_cnt["$h"]=$c
total=$(( total + c ))
printf " - %-8s: %s 分片\n" "$h" "$c" printf " - %-8s: %s 分片\n" "$h" "$c"
if (( c < MIN_SHARDS_PER_HOST )); then if [[ "$h" == "$AGGREGATOR_HOST" ]]; then
echo "!! 预警:$h 分片仅 $c 个 (< ${MIN_SHARDS_PER_HOST}),该节点可能未写出或路径不同" >&2 agg_cnt=$c
fi else
done remote_total=$(( remote_total + c ))
echo " - 汇总分片数(未同步前):$total" # 每台机最简单的 sanity至少应有 EXPECTED_SHARDS_PER_HOST 个
if (( c < EXPECTED_SHARDS_PER_HOST )); then
if [[ -n "${EXPECTED_TOTAL_SHARDS:-}" ]]; then echo "!! 预警:$h 分片仅 $c 个(期望 ${EXPECTED_SHARDS_PER_HOST}" >&2
if (( total != EXPECTED_TOTAL_SHARDS )); then
echo "!! 分片总数($total) ≠ 期望(${EXPECTED_TOTAL_SHARDS}),很可能缺片或路径不一致" >&2
if [[ "${STRICT_PRECHECK}" == "true" ]]; then
echo "!! STRICT_PRECHECK 开启:中止合并,请先排查缺片节点" >&2
exit 2
else
echo ">> 严格校验关闭:继续执行(可能在合并/加载时失败)" >&2
fi fi
fi fi
fi done
expected_remote_total=$(( EXPECTED_TOTAL_SHARDS - EXPECTED_SHARDS_PER_HOST ))
echo " - 远端合计(不含 ${AGGREGATOR_HOST}$remote_total(期望 ${expected_remote_total}"
echo " - ${AGGREGATOR_HOST} 自身:$agg_cnt(期望 ${EXPECTED_SHARDS_PER_HOST}"
echo "== 1/4 同步各节点的 ${TAG} 目录(带进度)==" precheck_ok=true
mkdir -p "${CKPT_ROOT}/${TAG}" if (( remote_total != expected_remote_total )); then
LOCAL_HOST="$(hostname -s || hostname)" echo "!! 远端总分片不等:实际 ${remote_total} / 期望 ${expected_remote_total}" >&2
precheck_ok=false
fi
if (( agg_cnt < EXPECTED_SHARDS_PER_HOST )); then
echo "!! ${AGGREGATOR_HOST} 本机分片不足:实际 ${agg_cnt} / 期望 ${EXPECTED_SHARDS_PER_HOST}" >&2
precheck_ok=false
fi
if [[ "${STRICT_PRECHECK}" == "true" && "${precheck_ok}" == "false" ]]; then
echo "!! STRICT_PRECHECK 开启:预检不通过,停止执行" >&2
exit 2
fi
[[ "${precheck_ok}" == "true" ]] && echo "OK: 预检通过(远端=$(remote_total)、本机=${agg_cnt},总计期望=${EXPECTED_TOTAL_SHARDS}" || echo "WARN: 预检存在告警,继续执行..."
echo "== 1/5 准备 staging 目录(干净环境)=="
rm -rf "${STAGING_TAG_DIR}"
mkdir -p "${STAGING_TAG_DIR}"
echo "== 2/5 收集分片到 staging =="
for h in "${HOSTS[@]}"; do for h in "${HOSTS[@]}"; do
if [[ "$h" == "$LOCAL_HOST" ]]; then
echo " - 跳过本机 $h"
continue
fi
if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
echo " - 从 $h 拉取 ${CKPT_ROOT}/${TAG}/" echo " - 收集 ${h}:${CKPT_ROOT}/${TAG}/ -> ${STAGING_TAG_DIR}/"
rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \ rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
"${h}:${CKPT_ROOT}/${TAG}/" "${CKPT_ROOT}/${TAG}/" || true "${h}:${CKPT_ROOT}/${TAG}/" "${STAGING_TAG_DIR}/" || true
else else
echo " - $h${CKPT_ROOT}/${TAG},跳过" echo " - ${h}${CKPT_ROOT}/${TAG},跳过"
fi fi
done done
echo "== 2/4 统计与校验分片文件(本机聚合后) ==" echo "== 3/5 在 staging 校验总分片数(应为 ${EXPECTED_TOTAL_SHARDS}=="
# 只认文件,不认目录;用一次 find 去重,避免重复计数 mapfile -t SHARDS < <(find "${STAGING_TAG_DIR}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
mapfile -t SHARDS < <(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
CNT=${#SHARDS[@]} CNT=${#SHARDS[@]}
echo " - 发现 model_states 分片文件数:${CNT}" echo " - staging 中发现分片数:${CNT}"
if [[ "${CNT}" -eq 0 ]]; then if (( CNT != EXPECTED_TOTAL_SHARDS )); then
echo "!! 未检测到任何 *model_states.pt请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名" >&2 echo "!! 分片总数不等staging 实际 ${CNT} / 期望 ${EXPECTED_TOTAL_SHARDS}。请检查是否缺片或命名不一致。" >&2
exit 1 exit 3
fi
# 简单健壮性检查:分片数至少不低于主机数(经验性检查)
if [[ "${CNT}" -lt "${#HOSTS[@]}" ]]; then
echo "!! 分片数(${CNT}) < 主机数(${#HOSTS[@]}),可能有节点没同步到分片,继续可能失败" >&2
fi fi
echo "== 3/4 合并为 safetensors 到:${OUT_DIR} ==" echo "== 4/5 合并为 safetensors 到:${OUT_DIR} =="
mkdir -p "${OUT_DIR}" mkdir -p "${OUT_DIR}"
# 探测 zero_to_fp32.py 是否支持新参数;不支持就 API # 探测 zero_to_fp32.py 是否支持新参数;不支持就 API
USE_Z2FP32_SCRIPT=false USE_Z2FP32_SCRIPT=false
if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then
if python "${CKPT_ROOT}/zero_to_fp32.py" --help 2>&1 | grep -q -- "--safe_serialization"; then if python "${CKPT_ROOT}/zero_to_fp32.py" --help 2>&1 | grep -q -- "--safe_serialization"; then
@ -96,7 +100,7 @@ fi
if $USE_Z2FP32_SCRIPT; then if $USE_Z2FP32_SCRIPT; then
python "${CKPT_ROOT}/zero_to_fp32.py" \ python "${CKPT_ROOT}/zero_to_fp32.py" \
"${CKPT_ROOT}" \ "${STAGING_BASE}" \
"${OUT_DIR}" \ "${OUT_DIR}" \
--tag "${TAG}" \ --tag "${TAG}" \
--safe_serialization \ --safe_serialization \
@ -105,7 +109,7 @@ else
python - <<PY python - <<PY
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
convert_zero_checkpoint_to_fp32_state_dict( convert_zero_checkpoint_to_fp32_state_dict(
checkpoint_dir=r"${CKPT_ROOT}", checkpoint_dir=r"${STAGING_BASE}",
output_dir=r"${OUT_DIR}", output_dir=r"${OUT_DIR}",
tag=r"${TAG}", tag=r"${TAG}",
safe_serialization=True, safe_serialization=True,
@ -115,12 +119,12 @@ print("合并完成:", r"${OUT_DIR}")
PY PY
fi fi
echo "== 3.1 拷贝 config/tokenizer 工件(如存在)==" echo "== 4.1 拷贝 config/tokenizer 工件(如存在)=="
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
[[ -f "${CKPT_ROOT}/${f}" ]] && cp -n "${CKPT_ROOT}/${f}" "${OUT_DIR}/" [[ -f "${CKPT_ROOT}/${f}" ]] && cp -n "${CKPT_ROOT}/${f}" "${OUT_DIR}/"
done done
echo "== 4/4 自检(索引与 config==" echo "== 5/5 自检(索引与 config=="
python - <<'PY' python - <<'PY'
import os, json, sys import os, json, sys
out_dir = os.environ.get("OUT_DIR") out_dir = os.environ.get("OUT_DIR")
@ -129,7 +133,6 @@ if os.path.exists(idx):
with open(idx) as f: j = json.load(f) with open(idx) as f: j = json.load(f)
print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(j.get('weight_map', {}))}") print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(j.get('weight_map', {}))}")
else: else:
# 单分片时没有 index.json属于正常情况
print("NOTE: 未找到 model.safetensors.index.json可能是单分片") print("NOTE: 未找到 model.safetensors.index.json可能是单分片")
try: try:
from transformers import AutoConfig from transformers import AutoConfig