jd_train/merge_zero3_safetensors.sh

146 lines
5.7 KiB
Bash
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env bash
# Gather DeepSpeed ZeRO checkpoint shards from several hosts into one staging
# directory on AGGREGATOR_HOST, verify the shard count, merge them into
# safetensors via zero_to_fp32, copy config/tokenizer files, then self-check.
set -euo pipefail
# ===== Tunable parameters =====
# Checkpoint root on every host; shards are read from ${CKPT_ROOT}/${TAG}.
# If the real layout is .../checkpoint-62/global_step62, set CKPT_ROOT to .../checkpoint-62.
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4"
TAG="global_step62"
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
AGGREGATOR_HOST="tn06" # host this script runs on; shards are gathered and merged here
EXPECTED_SHARDS_PER_HOST=4 # *model_states.pt files each host should hold (depends on your parallel layout)
MAX_SHARD_SIZE="5GB" # passed to the merge as the maximum output shard size
STRICT_PRECHECK=true # true: abort when the step-0 precheck fails; false: only warn
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
# ====================
# ===== Parameter sanity checks =====
# The step-0 precheck counts AGGREGATOR_HOST separately from the other HOSTS
# and expects exactly one host's worth of shards there, so it must be in HOSTS.
if [[ " ${HOSTS[*]} " != *" ${AGGREGATOR_HOST} "* ]]; then
  echo "!! AGGREGATOR_HOST=${AGGREGATOR_HOST} 不在 HOSTS 中" >&2
  exit 1
fi
if ! [[ "${EXPECTED_SHARDS_PER_HOST}" =~ ^[1-9][0-9]*$ ]]; then
  echo "!! EXPECTED_SHARDS_PER_HOST 必须是正整数,当前为 ${EXPECTED_SHARDS_PER_HOST}" >&2
  exit 1
fi
# ===== Derived parameters (normally left alone) =====
EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} ))
STAGING_BASE="${CKPT_ROOT}/_staging"
STAGING_TAG_DIR="${STAGING_BASE}/${TAG}"
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
# Exported so the Python self-check can read it from the environment.
export OUT_DIR
# =================================
echo "== 预检查 SSH =="
# Every host must accept non-interactive, key-based SSH (BatchMode=yes in
# SSH_OPTS makes ssh fail instead of prompting). All hosts are probed first so
# a single run reports every unreachable host; then the script aborts.
ssh_unreachable=()
for h in "${HOSTS[@]}"; do
  # -n: keep ssh from reading this script's stdin.
  # SSH_OPTS is intentionally unquoted so it splits into separate options.
  # shellcheck disable=SC2086
  if ! ssh -n ${SSH_OPTS} "$h" "true" >/dev/null; then
    echo "!! 无法免密 SSH 到 $h" >&2
    ssh_unreachable+=("$h")
  fi
done
if (( ${#ssh_unreachable[@]} > 0 )); then
  exit 1
fi
echo "== 0/5 逐节点分片预检(统计各机 ${CKPT_ROOT}/${TAG} 下的 *model_states.pt)=="
# Before copying anything, count over SSH the *model_states.pt files each host
# holds directly under ${CKPT_ROOT}/${TAG}. The aggregator's own count goes into
# agg_cnt; every other host is summed into remote_total.
remote_total=0
agg_cnt=0
for h in "${HOSTS[@]}"; do
  # A missing directory yields 0 (find's errors are discarded remotely); an
  # ssh failure also falls back to 0 via `|| echo 0`.
  # shellcheck disable=SC2086
  c=$(ssh -n ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0)
  # wc's count is the last output line; keep only that line, strip whitespace,
  # and refuse anything non-numeric so stray remote output cannot reach the
  # arithmetic below (which would abort under set -u).
  c=$(printf '%s\n' "$c" | tail -n 1 | tr -d '[:space:]')
  if [[ ! "$c" =~ ^[0-9]+$ ]]; then
    echo "!! 预警:无法解析 $h 的分片计数,按 0 处理" >&2
    c=0
  fi
  printf " - %-8s: %s 分片\n" "$h" "$c"
  if [[ "$h" == "$AGGREGATOR_HOST" ]]; then
    agg_cnt=$c
  else
    remote_total=$(( remote_total + c ))
    # Simplest per-host sanity check: each host should hold at least
    # EXPECTED_SHARDS_PER_HOST shards.
    if (( c < EXPECTED_SHARDS_PER_HOST )); then
      echo "!! 预警:$h 分片仅 $c 个(期望 ${EXPECTED_SHARDS_PER_HOST})" >&2
    fi
  fi
done
# Every host except the aggregator is expected to contribute one host's worth.
expected_remote_total=$(( EXPECTED_TOTAL_SHARDS - EXPECTED_SHARDS_PER_HOST ))
echo " - 远端合计(不含 ${AGGREGATOR_HOST}): ${remote_total}(期望 ${expected_remote_total})"
echo " - ${AGGREGATOR_HOST} 自身:${agg_cnt}(期望 ${EXPECTED_SHARDS_PER_HOST})"
precheck_ok=true
if (( remote_total != expected_remote_total )); then
  echo "!! 远端总分片不等:实际 ${remote_total} / 期望 ${expected_remote_total}" >&2
  precheck_ok=false
fi
if (( agg_cnt < EXPECTED_SHARDS_PER_HOST )); then
  echo "!! ${AGGREGATOR_HOST} 本机分片不足:实际 ${agg_cnt} / 期望 ${EXPECTED_SHARDS_PER_HOST}" >&2
  precheck_ok=false
fi
if [[ "${STRICT_PRECHECK}" == "true" && "${precheck_ok}" == "false" ]]; then
  echo "!! STRICT_PRECHECK 开启:预检不通过,停止执行" >&2
  exit 2
fi
if [[ "${precheck_ok}" == "true" ]]; then
  echo "OK: 预检通过(远端=${remote_total}、本机=${agg_cnt},总计期望=${EXPECTED_TOTAL_SHARDS})"
else
  echo "WARN: 预检未通过(分片数量与期望不符),已启用宽松模式,继续执行..."
fi
echo "== 1/5 准备 staging 目录(干净环境)=="
# Recreate the per-tag staging dir so files from earlier runs cannot be
# counted in step 3. ${VAR:?} aborts instead of running rm -rf on an empty path.
rm -rf -- "${STAGING_TAG_DIR:?}"
mkdir -p -- "${STAGING_TAG_DIR}"
echo "== 2/5 收集分片到 staging =="
# Pull each host's whole ${CKPT_ROOT}/${TAG}/ directory into the single staging
# dir. Step 3 counts the combined result, so per-host file names are assumed
# not to collide.
# Any failed transfer is fatal: with --partial/--inplace in RSYNC_OPTS, rsync
# can leave a truncated file under its final name, and step 3's count would
# still accept that file.
collect_failed=()
for h in "${HOSTS[@]}"; do
  # ssh exits 255 on its own (connection) errors; any other non-zero status
  # means the remote `test -d` found no directory.
  dir_rc=0
  # shellcheck disable=SC2086
  ssh -n ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'" || dir_rc=$?
  if (( dir_rc == 0 )); then
    echo " - 收集 ${h}:${CKPT_ROOT}/${TAG}/ -> ${STAGING_TAG_DIR}/"
    rsync_rc=0
    # RSYNC_OPTS is intentionally unquoted so it splits into separate options.
    # shellcheck disable=SC2086
    rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
      "${h}:${CKPT_ROOT}/${TAG}/" "${STAGING_TAG_DIR}/" || rsync_rc=$?
    if (( rsync_rc == 24 )); then
      # 24 = some source files vanished during the transfer; only warn, since
      # step 3 still verifies the shard count.
      echo "!! 预警:${h} 部分源文件在传输过程中消失(rsync 退出码 24)" >&2
    elif (( rsync_rc != 0 )); then
      echo "!! 从 ${h} 收集失败(rsync 退出码 ${rsync_rc})" >&2
      collect_failed+=("$h")
    fi
  elif (( dir_rc == 255 )); then
    echo "!! SSH 到 ${h} 失败,无法确认 ${CKPT_ROOT}/${TAG} 是否存在" >&2
    collect_failed+=("$h")
  else
    echo " - ${h}: 无 ${CKPT_ROOT}/${TAG},跳过"
  fi
done
if (( ${#collect_failed[@]} > 0 )); then
  echo "!! 收集失败的节点:${collect_failed[*]};失败的传输可能留下截断文件,停止执行" >&2
  exit 4
fi
echo "== 3/5 在 staging 校验总分片数(应为 ${EXPECTED_TOTAL_SHARDS})=="
# After collection, the number of unique *model_states.pt files in staging must
# equal the configured total exactly.
mapfile -t SHARDS < <(find "${STAGING_TAG_DIR}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
CNT=${#SHARDS[@]}
echo " - staging 中发现分片数:${CNT}"
if (( CNT != EXPECTED_TOTAL_SHARDS )); then
  echo "!! 分片总数不等:staging 实际 ${CNT} / 期望 ${EXPECTED_TOTAL_SHARDS}。请检查是否缺片或命名不一致。" >&2
  exit 3
fi
# A zero-byte shard passes the count above but is almost certainly the result
# of a failed transfer; stop before handing it to the merge.
mapfile -t EMPTY_SHARDS < <(find "${STAGING_TAG_DIR}" -maxdepth 1 -type f -name "*model_states.pt" -size 0)
if (( ${#EMPTY_SHARDS[@]} > 0 )); then
  echo "!! staging 中存在空分片文件:${EMPTY_SHARDS[*]}" >&2
  exit 3
fi
echo "== 4/5 合并为 safetensors 到:${OUT_DIR} =="
mkdir -p "${OUT_DIR}"
# Use the zero_to_fp32.py saved next to the checkpoint when its CLI advertises
# --safe_serialization; otherwise call the DeepSpeed API directly.
Z2FP32="${CKPT_ROOT}/zero_to_fp32.py"
USE_Z2FP32_SCRIPT=false
if [[ -f "${Z2FP32}" ]]; then
  # Capture --help first instead of piping into `grep -q`: under pipefail an
  # early grep exit can SIGPIPE python and turn a match into a failure. A
  # failing --help just means "not supported" (|| true keeps set -e quiet).
  z2fp32_help=$(python "${Z2FP32}" --help 2>&1) || true
  if [[ "${z2fp32_help}" == *--safe_serialization* ]]; then
    USE_Z2FP32_SCRIPT=true
  fi
fi
if [[ "${USE_Z2FP32_SCRIPT}" == "true" ]]; then
  python "${Z2FP32}" \
    "${STAGING_BASE}" \
    "${OUT_DIR}" \
    --tag "${TAG}" \
    --safe_serialization \
    --max_shard_size "${MAX_SHARD_SIZE}"
else
  # Values reach Python through the environment and the heredoc is quoted, so
  # quotes or backslashes in paths cannot break the generated source.
  Z2FP32_CKPT_DIR="${STAGING_BASE}" Z2FP32_OUT_DIR="${OUT_DIR}" \
  Z2FP32_TAG="${TAG}" Z2FP32_MAX_SHARD_SIZE="${MAX_SHARD_SIZE}" \
    python - <<'PY'
import os
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

out_dir = os.environ["Z2FP32_OUT_DIR"]
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir=os.environ["Z2FP32_CKPT_DIR"],
    output_dir=out_dir,
    tag=os.environ["Z2FP32_TAG"],
    safe_serialization=True,
    max_shard_size=os.environ["Z2FP32_MAX_SHARD_SIZE"],
)
print("合并完成:", out_dir)
PY
fi
echo "== 4.1 拷贝 config/tokenizer 工件(如存在)=="
# Copy the HF config/tokenizer files found in CKPT_ROOT next to the merged
# weights. Files already present in OUT_DIR (e.g. from an earlier run) are
# left untouched.
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
  src="${CKPT_ROOT}/${f}"
  dst="${OUT_DIR}/${f}"
  # Skip existing destinations explicitly instead of relying on `cp -n`: some
  # GNU coreutils releases (9.2+) make `cp -n` exit non-zero when it skips a
  # file, which would abort the whole script under set -e.
  if [[ -f "${src}" && ! -e "${dst}" ]]; then
    cp -- "${src}" "${dst}"
  fi
done
echo "== 5/5 自检(索引与 config)=="
# Check that merged weights exist in OUT_DIR. That is either a
# model.safetensors.index.json whose referenced shard files are all present,
# or a single model.safetensors. Then try to load the config with transformers
# (a config failure only warns). Missing weights exit with status 5.
OUT_DIR="${OUT_DIR}" python - <<'PY'
import json
import os
import sys

out_dir = os.environ["OUT_DIR"]
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    with open(idx) as f:
        j = json.load(f)
    weight_map = j.get("weight_map", {})
    print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(weight_map)})")
    # weight_map values are taken to be shard file names relative to out_dir
    # (HF index convention); every one of them must have been written.
    missing = sorted({
        name for name in weight_map.values()
        if not os.path.exists(os.path.join(out_dir, name))
    })
    if missing:
        print("ERROR: 索引引用的分片文件缺失:", ", ".join(missing), file=sys.stderr)
        sys.exit(5)
else:
    single = os.path.join(out_dir, "model.safetensors")
    if not os.path.exists(single):
        print(f"ERROR: 未找到 model.safetensors.index.json,也未找到 {single}", file=sys.stderr)
        sys.exit(5)
    print("NOTE: 未找到 model.safetensors.index.json,可能是单分片")
try:
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained(out_dir)
    print("OK: 读取到 config", cfg.model_type, "hidden:", getattr(cfg, 'hidden_size', None), "layers:", getattr(cfg, 'num_hidden_layers', None))
except Exception as e:
    print("WARN: 读取 config 失败(若无 config.json 可忽略):", e, file=sys.stderr)
PY
echo "== 完成:${OUT_DIR} =="