#!/usr/bin/env bash
# Gather the DeepSpeed ZeRO checkpoint directory for one tag from every node in
# HOSTS onto this machine, verify that model_states shard files are present,
# then merge the shards into safetensors files under OUT_DIR.
#
# Requirements: non-interactive (BatchMode) SSH to every host in HOSTS, rsync on
# all hosts, and python on this host (with DeepSpeed for the API fallback).
#
# Environment overrides:
#   STRICT_SYNC=1 (default)  abort if any rsync pull fails
#   STRICT_SYNC=0            warn about failed pulls and keep going
set -euo pipefail

# ===== Tunable parameters =====
# If the shards actually live under checkpoint-62/global_step62, change this to .../checkpoint-62.
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4"
TAG="global_step62"
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
MAX_SHARD_SIZE="5GB"
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
STRICT_SYNC="${STRICT_SYNC:-1}"
# ==============================

# Split the option strings into arrays once, so every call passes each option as
# its own argument without relying on unquoted word splitting. The string forms
# are kept because rsync's -e takes a single command string.
read -r -a ssh_args <<<"${SSH_OPTS}"
read -r -a rsync_args <<<"${RSYNC_OPTS}"

echo "== 预检查 SSH 与(非必需)远端目录存在 =="
for h in "${HOSTS[@]}"; do
  # -n: keep ssh from swallowing this script's stdin (e.g. when run as `bash < file`).
  if ! ssh -n "${ssh_args[@]}" "$h" "true" >/dev/null; then
    echo "!! 无法免密 SSH 到 $h" >&2
    exit 1
  fi
  # A missing remote directory is not fatal here; step 1 skips that host.
done

echo "== 1/4 同步各节点的 ${TAG} 整个目录(带进度)=="
mkdir -p "${CKPT_ROOT}/${TAG}"
LOCAL_HOST="$(hostname -s || hostname)"
failed_hosts=()
for h in "${HOSTS[@]}"; do
  if [[ "$h" == "$LOCAL_HOST" ]]; then
    echo " - 跳过本机 $h"
    continue
  fi
  if ssh -n "${ssh_args[@]}" "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
    echo " - 从 $h 拉取 ${CKPT_ROOT}/${TAG}/"
    # No include/exclude filtering, so shard files with other naming styles are not missed.
    rc=0
    rsync "${rsync_args[@]}" -e "ssh ${SSH_OPTS}" \
      "${h}:${CKPT_ROOT}/${TAG}/" "${CKPT_ROOT}/${TAG}/" || rc=$?
    if (( rc != 0 )); then
      # Record instead of silently ignoring: a failed pull means missing shards.
      echo "!! 从 $h 同步失败(rsync 退出码 ${rc})" >&2
      failed_hosts+=("$h")
    fi
  else
    echo " - $h 无 ${CKPT_ROOT}/${TAG},跳过"
  fi
done

if (( ${#failed_hosts[@]} > 0 )); then
  if [[ "${STRICT_SYNC}" == "1" ]]; then
    echo "!! 以下节点同步失败:${failed_hosts[*]};分片可能不完整,已中止(设 STRICT_SYNC=0 可忽略)" >&2
    exit 1
  fi
  echo "!! 警告:以下节点同步失败:${failed_hosts[*]},继续执行" >&2
fi

echo "== 2/4 校验是否有分片“文件”(不是目录)=="
# Two common naming styles exist: mp_rank_*_model_states.pt and
# *mp_rank*model_states.pt (the latter also covers names with a pp dimension).
# The second pattern already matches everything the first does, so only it is
# used. Summing both counted every plain mp_rank_* file twice.
# -type f enforces "files, not directories". Printing one byte per match and
# taking the length avoids parsing ls and is unaffected by odd filenames.
shard_marks=$(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type f \
  -name '*mp_rank*model_states.pt' -printf 'x')
CNT=${#shard_marks}
echo " - 发现 model_states 分片文件数:${CNT}"
if (( CNT == 0 )); then
  echo "!! 未检测到任何 *_model_states.pt;请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名,再调整匹配规则" >&2
  exit 1
fi

echo "== 3/4 合并为 safetensors 输出到:${OUT_DIR} =="
mkdir -p "${OUT_DIR}"
if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then
  # Prefer the official script saved alongside the checkpoint (same DeepSpeed version).
  python "${CKPT_ROOT}/zero_to_fp32.py" \
    "${CKPT_ROOT}" \
    "${OUT_DIR}" \
    --tag "${TAG}" \
    --safe_serialization \
    --max_shard_size "${MAX_SHARD_SIZE}"
else
  # Fall back to the DeepSpeed API.
  # NOTE(review): the original fallback heredoc was cut off in the copy this was
  # reviewed from. This body is a reconstruction against the documented
  # deepspeed.utils.zero_to_fp32.convert_zero_checkpoint_to_fp32_state_dict API.
  # Confirm it against the original and the installed DeepSpeed version.
  # Keyword arguments make older DeepSpeed versions (whose second parameter was
  # output_file) fail loudly instead of misinterpreting OUT_DIR.
  # Values are passed via the environment rather than interpolated into Python source.
  CKPT_ROOT="${CKPT_ROOT}" OUT_DIR="${OUT_DIR}" TAG="${TAG}" \
    MAX_SHARD_SIZE="${MAX_SHARD_SIZE}" python - <<'PY'
import os

from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir=os.environ["CKPT_ROOT"],
    output_dir=os.environ["OUT_DIR"],
    tag=os.environ["TAG"],
    safe_serialization=True,
    max_shard_size=os.environ["MAX_SHARD_SIZE"],
)
PY
fi

# NOTE(review): the original script announces four steps, but everything after
# the step-3 fallback (including step "4/4") was cut off in the reviewed copy.
# Restore it here from the original source.