This commit is contained in:
parent
b2075d945b
commit
2d2e42c4dd
|
|
@ -2,29 +2,66 @@
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
# ===== 可调参数 =====
|
# ===== 可调参数 =====
|
||||||
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 如果分片实际在 checkpoint-62/global_step62 下,就把这里改成 .../checkpoint-62
|
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 若实际是 .../checkpoint-62/global_step62,请把 CKPT_ROOT 改成 .../checkpoint-62
|
||||||
TAG="global_step62"
|
TAG="global_step62"
|
||||||
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
|
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
|
||||||
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
|
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
|
||||||
MAX_SHARD_SIZE="5GB"
|
MAX_SHARD_SIZE="5GB"
|
||||||
|
|
||||||
|
# 预检控制:总分片数(通常=总写出 rank 数,如 6 节点×4 GPU=24),最小每机分片数(按你布局调)
|
||||||
|
EXPECTED_TOTAL_SHARDS=24
|
||||||
|
MIN_SHARDS_PER_HOST=1
|
||||||
|
STRICT_PRECHECK=true # true: 若总分片≠期望则直接退出;false: 仅告警但继续
|
||||||
|
|
||||||
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
|
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
|
||||||
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
|
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
|
||||||
# ====================
|
# ====================
|
||||||
|
|
||||||
echo "== 预检查 SSH 与(非必需)远端目录存在 =="
|
export OUT_DIR # 让后面的 Python 自检拿得到
|
||||||
|
|
||||||
|
echo "== 预检查 SSH =="
|
||||||
for h in "${HOSTS[@]}"; do
|
for h in "${HOSTS[@]}"; do
|
||||||
ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! 无法免密 SSH 到 $h"; exit 1; }
|
ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! 无法免密 SSH 到 $h"; exit 1; }
|
||||||
# 目录不存在也不致命,后面会跳过
|
|
||||||
done
|
done
|
||||||
|
|
||||||
echo "== 1/4 同步各节点的 ${TAG} 整个目录(带进度)=="
|
echo "== 0/4 逐节点分片预检(只统计 ${CKPT_ROOT}/${TAG} 下的 *model_states.pt 文件) =="
|
||||||
|
total=0
|
||||||
|
declare -A host_cnt
|
||||||
|
for h in "${HOSTS[@]}"; do
|
||||||
|
# 只认文件,不认目录;限制在 TAG 这一层
|
||||||
|
c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0)
|
||||||
|
c=$(echo "$c" | tr -d ' ')
|
||||||
|
host_cnt["$h"]=$c
|
||||||
|
total=$(( total + c ))
|
||||||
|
printf " - %-8s: %s 分片\n" "$h" "$c"
|
||||||
|
if (( c < MIN_SHARDS_PER_HOST )); then
|
||||||
|
echo "!! 预警:$h 分片仅 $c 个 (< ${MIN_SHARDS_PER_HOST}),该节点可能未写出或路径不同" >&2
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo " - 汇总分片数(未同步前):$total"
|
||||||
|
|
||||||
|
if [[ -n "${EXPECTED_TOTAL_SHARDS:-}" ]]; then
|
||||||
|
if (( total != EXPECTED_TOTAL_SHARDS )); then
|
||||||
|
echo "!! 分片总数($total) ≠ 期望(${EXPECTED_TOTAL_SHARDS}),很可能缺片或路径不一致" >&2
|
||||||
|
if [[ "${STRICT_PRECHECK}" == "true" ]]; then
|
||||||
|
echo "!! STRICT_PRECHECK 开启:中止合并,请先排查缺片节点" >&2
|
||||||
|
exit 2
|
||||||
|
else
|
||||||
|
echo ">> 严格校验关闭:继续执行(可能在合并/加载时失败)" >&2
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "== 1/4 同步各节点的 ${TAG} 目录(带进度)=="
|
||||||
mkdir -p "${CKPT_ROOT}/${TAG}"
|
mkdir -p "${CKPT_ROOT}/${TAG}"
|
||||||
LOCAL_HOST="$(hostname -s || hostname)"
|
LOCAL_HOST="$(hostname -s || hostname)"
|
||||||
for h in "${HOSTS[@]}"; do
|
for h in "${HOSTS[@]}"; do
|
||||||
[[ "$h" == "$LOCAL_HOST" ]] && { echo " - 跳过本机 $h"; continue; }
|
if [[ "$h" == "$LOCAL_HOST" ]]; then
|
||||||
|
echo " - 跳过本机 $h"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
|
if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
|
||||||
echo " - 从 $h 拉取 ${CKPT_ROOT}/${TAG}/"
|
echo " - 从 $h 拉取 ${CKPT_ROOT}/${TAG}/"
|
||||||
# 不做 include/exclude 过滤,避免漏掉不同命名风格的分片文件
|
|
||||||
rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
|
rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
|
||||||
"${h}:${CKPT_ROOT}/${TAG}/" "${CKPT_ROOT}/${TAG}/" || true
|
"${h}:${CKPT_ROOT}/${TAG}/" "${CKPT_ROOT}/${TAG}/" || true
|
||||||
else
|
else
|
||||||
|
|
@ -32,21 +69,32 @@ for h in "${HOSTS[@]}"; do
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
echo "== 2/4 校验是否有分片“文件”(不是目录)=="
|
echo "== 2/4 统计与校验分片文件(本机聚合后) =="
|
||||||
# 兼容两种常见命名:mp_rank_*_model_states.pt 与 *mp_rank*model_states.pt(含 pp 维度)
|
# 只认文件,不认目录;用一次 find 去重,避免重复计数
|
||||||
CNT_A=$(ls -1 "${CKPT_ROOT}/${TAG}"/mp_rank_*_model_states.pt 2>/dev/null | wc -l | tr -d ' ' || true)
|
mapfile -t SHARDS < <(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
|
||||||
CNT_B=$(ls -1 "${CKPT_ROOT}/${TAG}"/*mp_rank*model_states.pt 2>/dev/null | wc -l | tr -d ' ' || true)
|
CNT=${#SHARDS[@]}
|
||||||
CNT=$(( CNT_A + CNT_B ))
|
|
||||||
echo " - 发现 model_states 分片文件数:${CNT}"
|
echo " - 发现 model_states 分片文件数:${CNT}"
|
||||||
if [[ "${CNT}" -eq 0 ]]; then
|
if [[ "${CNT}" -eq 0 ]]; then
|
||||||
echo "!! 未检测到任何 *_model_states.pt;请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名,再调整匹配规则" >&2
|
echo "!! 未检测到任何 *model_states.pt;请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
# 简单健壮性检查:分片数至少不低于主机数(经验性检查)
|
||||||
|
if [[ "${CNT}" -lt "${#HOSTS[@]}" ]]; then
|
||||||
|
echo "!! 分片数(${CNT}) < 主机数(${#HOSTS[@]}),可能有节点没同步到分片,继续可能失败" >&2
|
||||||
|
fi
|
||||||
|
|
||||||
echo "== 3/4 合并为 safetensors 输出到:${OUT_DIR} =="
|
echo "== 3/4 合并为 safetensors 到:${OUT_DIR} =="
|
||||||
mkdir -p "${OUT_DIR}"
|
mkdir -p "${OUT_DIR}"
|
||||||
|
|
||||||
|
# 先探测 zero_to_fp32.py 是否支持新参数;不支持就走 API
|
||||||
|
USE_Z2FP32_SCRIPT=false
|
||||||
if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then
|
if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then
|
||||||
# 优先使用与该 checkpoint 同版本的官方脚本
|
if python "${CKPT_ROOT}/zero_to_fp32.py" --help 2>&1 | grep -q -- "--safe_serialization"; then
|
||||||
|
USE_Z2FP32_SCRIPT=true
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if $USE_Z2FP32_SCRIPT; then
|
||||||
python "${CKPT_ROOT}/zero_to_fp32.py" \
|
python "${CKPT_ROOT}/zero_to_fp32.py" \
|
||||||
"${CKPT_ROOT}" \
|
"${CKPT_ROOT}" \
|
||||||
"${OUT_DIR}" \
|
"${OUT_DIR}" \
|
||||||
|
|
@ -54,7 +102,6 @@ if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then
|
||||||
--safe_serialization \
|
--safe_serialization \
|
||||||
--max_shard_size "${MAX_SHARD_SIZE}"
|
--max_shard_size "${MAX_SHARD_SIZE}"
|
||||||
else
|
else
|
||||||
# 退回 DeepSpeed API
|
|
||||||
python - <<PY
|
python - <<PY
|
||||||
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
||||||
convert_zero_checkpoint_to_fp32_state_dict(
|
convert_zero_checkpoint_to_fp32_state_dict(
|
||||||
|
|
@ -73,7 +120,7 @@ for f in config.json generation_config.json tokenizer_config.json tokenizer.json
|
||||||
[[ -f "${CKPT_ROOT}/${f}" ]] && cp -n "${CKPT_ROOT}/${f}" "${OUT_DIR}/"
|
[[ -f "${CKPT_ROOT}/${f}" ]] && cp -n "${CKPT_ROOT}/${f}" "${OUT_DIR}/"
|
||||||
done
|
done
|
||||||
|
|
||||||
echo "== 4/4 自检(索引与config)=="
|
echo "== 4/4 自检(索引与 config)=="
|
||||||
python - <<'PY'
|
python - <<'PY'
|
||||||
import os, json, sys
|
import os, json, sys
|
||||||
out_dir = os.environ.get("OUT_DIR")
|
out_dir = os.environ.get("OUT_DIR")
|
||||||
|
|
@ -82,7 +129,8 @@ if os.path.exists(idx):
|
||||||
with open(idx) as f: j = json.load(f)
|
with open(idx) as f: j = json.load(f)
|
||||||
print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(j.get('weight_map', {}))})")
|
print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(j.get('weight_map', {}))})")
|
||||||
else:
|
else:
|
||||||
print("WARN: 未找到 model.safetensors.index.json", file=sys.stderr)
|
# 单分片时没有 index.json,属于正常情况
|
||||||
|
print("NOTE: 未找到 model.safetensors.index.json(可能是单分片)")
|
||||||
try:
|
try:
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
cfg = AutoConfig.from_pretrained(out_dir)
|
cfg = AutoConfig.from_pretrained(out_dir)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue