diff --git a/lora_merge_zero3_safetensors.sh b/lora_merge_zero3_safetensors.sh
new file mode 100644
index 0000000..f57a68f
--- /dev/null
+++ b/lora_merge_zero3_safetensors.sh
@@ -0,0 +1,247 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# ===== Tunable parameters =====
+CKPT_ROOT="/home/test/checkpoints/q3-32b-lora"  # if the layout is actually .../checkpoint-62/global_step62, point CKPT_ROOT at .../checkpoint-62
+TAG="global_step30"
+HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
+AGGREGATOR_HOST="tn06"        # machine this script runs on / aggregates to
+EXPECTED_SHARDS_PER_HOST=4    # shards each host should have written (per your parallelism layout)
+MAX_SHARD_SIZE="5GB"
+
+# ★★★ New: reference model directory (the Qwen3-32B, or its Instruct variant, that the LoRA was trained on) ★★★
+REF_MODEL_DIR="/home/test/Qwen3-32B"
+
+STRICT_PRECHECK=true          # true: abort if the precheck fails; false: warn only
+SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
+RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
+# ====================
+
+# ===== Derived parameters (usually no need to change) =====
+EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} ))
+STAGING_BASE="${CKPT_ROOT}/_staging"
+STAGING_TAG_DIR="${STAGING_BASE}/${TAG}"
+OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
+TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}"   # temporary FP32 (pytorch_model.bin) directory
+export OUT_DIR TMP_PT_DIR MAX_SHARD_SIZE REF_MODEL_DIR
+# =================================
+
+echo "== Precheck: SSH =="
+for h in "${HOSTS[@]}"; do
+  ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! passwordless SSH to $h failed"; exit 1; }
+done
+
+echo "== 0/7 Per-node shard precheck (count *model_states.pt under ${CKPT_ROOT}/${TAG} on each host) =="
+remote_total=0
+agg_cnt=0
+for h in "${HOSTS[@]}"; do
+  c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0)
+  c=$(echo "$c" | tr -d ' ')
+  printf "  - %-8s: %s shard(s)\n" "$h" "$c"
+  if [[ "$h" == "$AGGREGATOR_HOST" ]]; then
+    agg_cnt=$c
+  else
+    remote_total=$(( remote_total + c ))
+    if (( c < EXPECTED_SHARDS_PER_HOST )); then
+      echo "!! warning: $h has only $c shard(s) (expected ${EXPECTED_SHARDS_PER_HOST})" >&2
+    fi
+  fi
+done
+expected_remote_total=$(( EXPECTED_TOTAL_SHARDS - EXPECTED_SHARDS_PER_HOST ))
+echo "  - remote total (excluding ${AGGREGATOR_HOST}): $remote_total (expected ${expected_remote_total})"
+echo "  - ${AGGREGATOR_HOST} itself: $agg_cnt (expected ${EXPECTED_SHARDS_PER_HOST})"
+
+precheck_ok=true
+if (( remote_total != expected_remote_total )); then
+  echo "!! remote shard total mismatch: actual ${remote_total} / expected ${expected_remote_total}" >&2
+  precheck_ok=false
+fi
+if (( agg_cnt < EXPECTED_SHARDS_PER_HOST )); then
+  echo "!! ${AGGREGATOR_HOST} is short on local shards: actual ${agg_cnt} / expected ${EXPECTED_SHARDS_PER_HOST}" >&2
+  precheck_ok=false
+fi
+if [[ "${STRICT_PRECHECK}" == "true" && "${precheck_ok}" == "false" ]]; then
+  echo "!! STRICT_PRECHECK is on: precheck failed, aborting" >&2
+  exit 2
+fi
+[[ "${precheck_ok}" == "true" ]] && echo "OK: precheck passed (remote=${remote_total}, local=${agg_cnt}, expected total=${EXPECTED_TOTAL_SHARDS})" || echo "WARN: precheck failed (shard counts differ from expected); lenient mode is on, continuing..."
+
+echo "== 1/7 Prepare staging directory (clean slate) =="
+rm -rf "${STAGING_TAG_DIR}"
+mkdir -p "${STAGING_TAG_DIR}"
+
+echo "== 2/7 Collect shards into staging =="
+for h in "${HOSTS[@]}"; do
+  if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
+    echo "  - collecting ${h}:${CKPT_ROOT}/${TAG}/ -> ${STAGING_TAG_DIR}/"
+    rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
+      "${h}:${CKPT_ROOT}/${TAG}/" "${STAGING_TAG_DIR}/" || true
+  else
+    echo "  - ${h} has no ${CKPT_ROOT}/${TAG}, skipping"
+  fi
+done
+
+echo "== 3/7 Verify total shard count in staging (should be ${EXPECTED_TOTAL_SHARDS}) =="
+mapfile -t SHARDS < <(find "${STAGING_TAG_DIR}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
+CNT=${#SHARDS[@]}
+echo "  - shards found in staging: ${CNT}"
+if (( CNT != EXPECTED_TOTAL_SHARDS )); then
+  echo "!! shard total mismatch: staging has ${CNT} / expected ${EXPECTED_TOTAL_SHARDS}. Check for missing shards or inconsistent naming." >&2
+  exit 3
+fi
+
+echo "== 4/7 Merge shards -> temporary FP32 (PyTorch .bin) to avoid safetensors errors on shared weights =="
+rm -rf "${TMP_PT_DIR}"
+mkdir -p "${TMP_PT_DIR}"
+
+python - <
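
Note: the patch is truncated at "python - <" (the heredoc body for step 4/7, and steps 5/7 through 7/7 entirely, are not shown). Assuming key-based SSH from tn06 to every host in HOSTS is already configured, an illustrative invocation from the aggregator host, kept under tee so the per-host shard counts end up in a log, might look like this (not from the patch):

# Illustrative only; run on the aggregator host (tn06).
bash lora_merge_zero3_safetensors.sh 2>&1 | tee "merge-global_step30.$(date +%F).log"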
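
A minimal sketch of what the truncated step 4/7 heredoc typically contains, assuming a recent DeepSpeed whose convert_zero_checkpoint_to_fp32_state_dict accepts an output directory plus tag, max_shard_size, and safe_serialization arguments. The STAGING_BASE/TAG exports are hypothetical additions this sketch relies on; the script as shown only exports OUT_DIR, TMP_PT_DIR, MAX_SHARD_SIZE, and REF_MODEL_DIR:

# Sketch only -- the real heredoc body is not in the patch.
export STAGING_BASE TAG   # hypothetical extra exports needed by this sketch
python - <<'PY'
import os
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# Merge the ZeRO-3 *model_states.pt shards under $STAGING_BASE/$TAG into
# sharded FP32 pytorch_model*.bin files. safe_serialization stays off here:
# tied/shared weights make safetensors refuse to save, which is the script's
# stated reason for the intermediate .bin step.
convert_zero_checkpoint_to_fp32_state_dict(
    os.environ["STAGING_BASE"],        # directory containing the <tag> subdir
    os.environ["TMP_PT_DIR"],          # output directory for the FP32 dump
    tag=os.environ["TAG"],             # e.g. "global_step30"
    max_shard_size=os.environ["MAX_SHARD_SIZE"],
    safe_serialization=False,
)
PY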
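
On the "shared weights" remark in step 4/7: safetensors refuses to serialize a state dict in which two entries alias the same storage (tied input/output embeddings are the classic case), which is why the script detours through a plain .bin dump first. A self-contained demonstration of the failure and the usual clone-based fix (illustrative, independent of the patch):

python - <<'PY'
import torch
from safetensors.torch import save_file

# Two state-dict entries backed by the same storage, like tied embeddings.
state = {"embed.weight": torch.randn(4, 8)}
state["lm_head.weight"] = state["embed.weight"]

try:
    save_file(state, "/tmp/demo.safetensors")   # raises: tensors share memory
except RuntimeError as err:
    print("safetensors refused:", err)

# The usual fix: break the aliasing by cloning each tensor before saving.
state = {k: v.contiguous().clone() for k, v in state.items()}
save_file(state, "/tmp/demo.safetensors")
print("saved after de-duplication")
PY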