jd_train/merge_zero3_safetensors.sh

143 lines
5.4 KiB
Bash
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env bash
# Settings for gathering the per-host shards of one checkpoint tag and
# merging them into safetensors. Adjust the values below for each run.
set -o errexit -o nounset -o pipefail

# ---------- checkpoint location ----------
# CKPT_ROOT is the parent of the tag directory: if the shards actually live in
# .../checkpoint-62/global_step62, set CKPT_ROOT to .../checkpoint-62.
TAG='global_step62'
CKPT_ROOT='/home/test/checkpoints/q3-32b-ds4'
# Exported so the Python self-check at the end can read it from the environment.
export OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
MAX_SHARD_SIZE='5GB'

# ---------- cluster ----------
declare -a HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
SSH_OPTS='-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8'
RSYNC_OPTS='-a --info=progress2 --human-readable --partial --inplace'

# ---------- precheck controls ----------
# Expected total shard count; usually the number of ranks that write a shard
# (e.g. 6 nodes x 4 GPUs = 24). Tune the per-host minimum to your layout.
EXPECTED_TOTAL_SHARDS=24
MIN_SHARDS_PER_HOST=1
# true: abort when the total differs from the expectation; false: warn only.
STRICT_PRECHECK=true
echo "== 预检查 SSH =="

# Confirm that every entry of HOSTS accepts a non-interactive SSH login by
# running a remote no-op. On the first host that fails, a diagnostic is
# written to stderr (like the other diagnostics in this script) and the
# shell exits with status 1.
precheck_ssh() {
  local host
  for host in "${HOSTS[@]}"; do
    # SSH_OPTS is a space-separated option string; word splitting is intended.
    # shellcheck disable=SC2086
    if ! ssh ${SSH_OPTS} "$host" "true" >/dev/null; then
      echo "!! 无法免密 SSH 到 $host" >&2
      exit 1
    fi
  done
}

precheck_ssh
echo "== 0/4 逐节点分片预检(只统计 ${CKPT_ROOT}/${TAG} 下的 *model_states.pt 文件) =="

# Print how many regular files named *model_states.pt sit directly inside
# ${CKPT_ROOT}/${TAG} on host $1 (directories and deeper levels are ignored).
# Always prints a plain non-negative integer: a failed ssh call, or a reply
# that is not purely numeric once whitespace is removed, is reported as 0 so
# the arithmetic in the caller can never choke on the value.
count_remote_shards() {
  local host=$1 reply
  # SSH_OPTS is a space-separated option string; word splitting is intended.
  # shellcheck disable=SC2086
  if ! reply=$(ssh ${SSH_OPTS} "$host" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l"); then
    reply=0
  fi
  # Some wc implementations pad the number; drop every whitespace character.
  reply=${reply//[[:space:]]/}
  if [[ ! "$reply" =~ ^[0-9]+$ ]]; then
    reply=0
  fi
  # Force base 10 so a zero-padded value is not parsed as octal.
  printf '%d\n' "$((10#${reply}))"
}

total=0
declare -A host_cnt
for h in "${HOSTS[@]}"; do
  c=$(count_remote_shards "$h")
  host_cnt["$h"]=$c
  total=$(( total + c ))
  printf " - %-8s: %s 分片\n" "$h" "$c"
  if (( c < MIN_SHARDS_PER_HOST )); then
    echo "!! 预警:$h 分片仅 $c 个 (< ${MIN_SHARDS_PER_HOST}),该节点可能未写出或路径不同" >&2
  fi
done
echo " - 汇总分片数(未同步前):$total"

# Compare the cluster-wide total with the expectation, when one is configured.
# STRICT_PRECHECK decides whether a mismatch aborts (exit 2) or only warns.
if [[ -n "${EXPECTED_TOTAL_SHARDS:-}" ]] && (( total != EXPECTED_TOTAL_SHARDS )); then
  echo "!! 分片总数($total) ≠ 期望(${EXPECTED_TOTAL_SHARDS}),很可能缺片或路径不一致" >&2
  if [[ "${STRICT_PRECHECK}" == "true" ]]; then
    echo "!! STRICT_PRECHECK 开启:中止合并,请先排查缺片节点" >&2
    exit 2
  else
    echo ">> 严格校验关闭:继续执行(可能在合并/加载时失败)" >&2
  fi
fi
echo "== 1/4 同步各节点的 ${TAG} 目录(带进度)=="
mkdir -p "${CKPT_ROOT}/${TAG}"
LOCAL_HOST="$(hostname -s || hostname)"

# Pull ${CKPT_ROOT}/${TAG}/ from host $1 into the same path on this machine.
# - The local host (matched against LOCAL_HOST) is skipped.
# - A host on which the `test -d` probe fails is skipped with a notice.
# - The transfer stays best-effort: an rsync failure does not stop the
#   script, but it is reported on stderr together with the exit code
#   instead of being discarded silently.
# Always returns 0.
sync_from_host() {
  local host=$1 rc=0
  if [[ "$host" == "$LOCAL_HOST" ]]; then
    echo " - 跳过本机 $host"
    return 0
  fi
  # SSH_OPTS / RSYNC_OPTS are space-separated option strings; word splitting
  # is intended.
  # shellcheck disable=SC2086
  if ! ssh ${SSH_OPTS} "$host" "test -d '${CKPT_ROOT}/${TAG}'"; then
    echo " - $host 上没有 ${CKPT_ROOT}/${TAG},跳过"
    return 0
  fi
  echo " - 从 $host 拉取 ${CKPT_ROOT}/${TAG}/"
  # shellcheck disable=SC2086
  rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
    "${host}:${CKPT_ROOT}/${TAG}/" "${CKPT_ROOT}/${TAG}/" || rc=$?
  if (( rc != 0 )); then
    echo "!! 预警:从 $host 同步失败 (rsync 退出码 ${rc}),本机分片可能不完整" >&2
  fi
  return 0
}

for h in "${HOSTS[@]}"; do
  sync_from_host "$h"
done
echo "== 2/4 统计与校验分片文件(本机聚合后) =="
# Regular files only, restricted to the tag directory itself; a single find
# piped through `sort -u` so no path can be counted twice.
mapfile -t SHARDS < <(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
CNT=${#SHARDS[@]}
echo " - 发现 model_states 分片文件数:${CNT}"

# Nothing to merge at all: stop here.
if (( CNT == 0 )); then
  echo "!! 未检测到任何 *model_states.pt请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名" >&2
  exit 1
fi

# Rule-of-thumb sanity check: expect at least one shard per host.
if (( CNT < ${#HOSTS[@]} )); then
  echo "!! 分片数(${CNT}) < 主机数(${#HOSTS[@]}),可能有节点没同步到分片,继续可能失败" >&2
fi

# The sync step is best-effort, so re-check the aggregated count against the
# configured expectation. STRICT_PRECHECK decides between aborting (exit 2)
# and warning only.
if [[ -n "${EXPECTED_TOTAL_SHARDS:-}" ]] && (( CNT != EXPECTED_TOTAL_SHARDS )); then
  echo "!! 聚合后分片数(${CNT}) ≠ 期望(${EXPECTED_TOTAL_SHARDS}),可能有节点同步失败或缺片" >&2
  if [[ "${STRICT_PRECHECK:-false}" == "true" ]]; then
    echo "!! STRICT_PRECHECK 开启:中止合并,请先排查缺片节点" >&2
    exit 2
  else
    echo ">> 严格校验关闭:继续执行(可能在合并/加载时失败)" >&2
  fi
fi
echo "== 3/4 合并为 safetensors 到:${OUT_DIR} =="
mkdir -p "${OUT_DIR}"

# Return 0 when script $1 exists, its `--help` run succeeds, and the help text
# mentions --safe_serialization; return 1 otherwise.
# The help text is captured before it is searched. Piping python straight
# into `grep -q` under `set -o pipefail` can report failure even on a match,
# because grep exits at the first hit and python may then die on SIGPIPE.
z2fp32_supports_safetensors() {
  local script=$1 help_text
  [[ -f "$script" ]] || return 1
  help_text=$(python "$script" --help 2>&1) || return 1
  [[ "$help_text" == *--safe_serialization* ]]
}

# Prefer the zero_to_fp32.py stored next to the checkpoint when it understands
# the safetensors options; otherwise fall back to the DeepSpeed Python API.
USE_Z2FP32_SCRIPT=false
if z2fp32_supports_safetensors "${CKPT_ROOT}/zero_to_fp32.py"; then
  USE_Z2FP32_SCRIPT=true
fi

if [[ "${USE_Z2FP32_SCRIPT}" == "true" ]]; then
  python "${CKPT_ROOT}/zero_to_fp32.py" \
    "${CKPT_ROOT}" \
    "${OUT_DIR}" \
    --tag "${TAG}" \
    --safe_serialization \
    --max_shard_size "${MAX_SHARD_SIZE}"
else
  # Values travel through argv and the heredoc is quoted, so no shell text is
  # ever spliced into Python source (quotes or backslashes in a path cannot
  # break the program).
  python - "${CKPT_ROOT}" "${OUT_DIR}" "${TAG}" "${MAX_SHARD_SIZE}" <<'PY'
import sys

from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

ckpt_root, out_dir, tag, max_shard_size = sys.argv[1:5]
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir=ckpt_root,
    output_dir=out_dir,
    tag=tag,
    safe_serialization=True,
    max_shard_size=max_shard_size,
)
print("合并完成:", out_dir)
PY
fi
echo "== 3.1 拷贝 config/tokenizer 工件(如存在)=="

# Copy each known config/tokenizer file that exists in directory $1 into
# directory $2 without ever overwriting a file already present in $2.
# Files missing from $1 are ignored.
copy_model_artifacts() {
  local src_dir=$1 dst_dir=$2 name
  for name in config.json generation_config.json tokenizer_config.json \
    tokenizer.json merges.txt vocab.json special_tokens_map.json \
    added_tokens.json; do
    if [[ ! -f "${src_dir}/${name}" ]]; then
      continue
    fi
    # Explicit existence test rather than `cp -n`: the exit status of
    # `cp -n` on an existing destination has differed between coreutils
    # releases, and a non-zero status would abort a rerun under `set -e`.
    if [[ -e "${dst_dir}/${name}" || -L "${dst_dir}/${name}" ]]; then
      continue
    fi
    cp -- "${src_dir}/${name}" "${dst_dir}/"
  done
}

copy_model_artifacts "${CKPT_ROOT}" "${OUT_DIR}"
echo "== 4/4 自检(索引与 config=="

# Inspect the merged output directory named by the OUT_DIR environment
# variable: report the safetensors index when present, then try to load the
# model config through transformers. A missing index or an unreadable config
# only produces a note/warning; a missing OUT_DIR or an unparsable index makes
# the Python process exit non-zero.
run_self_check() {
  python - <<'PY'
import json
import os
import sys

out_dir = os.environ.get("OUT_DIR")
if not out_dir:
    # Fail with a clear message instead of a TypeError from os.path.join.
    sys.exit("ERROR: 环境变量 OUT_DIR 未设置")

idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    # JSON is read as UTF-8 regardless of the locale's default encoding.
    with open(idx, encoding="utf-8") as f:
        j = json.load(f)
    print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(j.get('weight_map', {}))}")
else:
    # A single-shard export has no index.json; that is a normal situation.
    print("NOTE: 未找到 model.safetensors.index.json可能是单分片")

try:
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained(out_dir)
    print("OK: 读取到 config", cfg.model_type, "hidden:", getattr(cfg, 'hidden_size', None), "layers:", getattr(cfg, 'num_hidden_layers', None))
except Exception as e:
    # Best-effort: a missing transformers install or config.json is not fatal.
    print("WARN: 读取 config 失败(若无 config.json 可忽略):", e, file=sys.stderr)
PY
}

run_self_check
echo "== 完成:${OUT_DIR} =="