95 lines
3.8 KiB
Bash
Executable File
95 lines
3.8 KiB
Bash
Executable File
#!/usr/bin/env bash
|
||
set -euo pipefail
|
||
|
||
# ===== 可调参数 =====
|
||
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4" # 如果分片实际在 checkpoint-62/global_step62 下,就把这里改成 .../checkpoint-62
|
||
TAG="global_step62"
|
||
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
|
||
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
|
||
MAX_SHARD_SIZE="5GB"
|
||
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
|
||
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
|
||
# ====================
|
||
|
||
echo "== 预检查 SSH 与(非必需)远端目录存在 =="
|
||
for h in "${HOSTS[@]}"; do
|
||
ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! 无法免密 SSH 到 $h"; exit 1; }
|
||
# 目录不存在也不致命,后面会跳过
|
||
done
|
||
|
||
echo "== 1/4 同步各节点的 ${TAG} 整个目录(带进度)=="
|
||
mkdir -p "${CKPT_ROOT}/${TAG}"
|
||
LOCAL_HOST="$(hostname -s || hostname)"
|
||
for h in "${HOSTS[@]}"; do
|
||
[[ "$h" == "$LOCAL_HOST" ]] && { echo " - 跳过本机 $h"; continue; }
|
||
if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
|
||
echo " - 从 $h 拉取 ${CKPT_ROOT}/${TAG}/"
|
||
# 不做 include/exclude 过滤,避免漏掉不同命名风格的分片文件
|
||
rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
|
||
"${h}:${CKPT_ROOT}/${TAG}/" "${CKPT_ROOT}/${TAG}/" || true
|
||
else
|
||
echo " - $h 无 ${CKPT_ROOT}/${TAG},跳过"
|
||
fi
|
||
done
|
||
|
||
echo "== 2/4 校验是否有分片“文件”(不是目录)=="
|
||
# 兼容两种常见命名:mp_rank_*_model_states.pt 与 *mp_rank*model_states.pt(含 pp 维度)
|
||
CNT_A=$(ls -1 "${CKPT_ROOT}/${TAG}"/mp_rank_*_model_states.pt 2>/dev/null | wc -l | tr -d ' ' || true)
|
||
CNT_B=$(ls -1 "${CKPT_ROOT}/${TAG}"/*mp_rank*model_states.pt 2>/dev/null | wc -l | tr -d ' ' || true)
|
||
CNT=$(( CNT_A + CNT_B ))
|
||
echo " - 发现 model_states 分片文件数:${CNT}"
|
||
if [[ "${CNT}" -eq 0 ]]; then
|
||
echo "!! 未检测到任何 *_model_states.pt;请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名,再调整匹配规则" >&2
|
||
exit 1
|
||
fi
|
||
|
||
echo "== 3/4 合并为 safetensors 输出到:${OUT_DIR} =="
|
||
mkdir -p "${OUT_DIR}"
|
||
if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then
|
||
# 优先使用与该 checkpoint 同版本的官方脚本
|
||
python "${CKPT_ROOT}/zero_to_fp32.py" \
|
||
"${CKPT_ROOT}" \
|
||
"${OUT_DIR}" \
|
||
--tag "${TAG}" \
|
||
--safe_serialization \
|
||
--max_shard_size "${MAX_SHARD_SIZE}"
|
||
else
|
||
# 退回 DeepSpeed API
|
||
python - <<PY
|
||
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
||
convert_zero_checkpoint_to_fp32_state_dict(
|
||
checkpoint_dir=r"${CKPT_ROOT}",
|
||
output_dir=r"${OUT_DIR}",
|
||
tag=r"${TAG}",
|
||
safe_serialization=True,
|
||
max_shard_size=r"${MAX_SHARD_SIZE}",
|
||
)
|
||
print("合并完成:", r"${OUT_DIR}")
|
||
PY
|
||
fi
|
||
|
||
echo "== 3.1 拷贝 config/tokenizer 工件(如存在)=="
|
||
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
|
||
[[ -f "${CKPT_ROOT}/${f}" ]] && cp -n "${CKPT_ROOT}/${f}" "${OUT_DIR}/"
|
||
done
|
||
|
||
echo "== 4/4 自检(索引与config)=="
|
||
python - <<'PY'
|
||
import os, json, sys
|
||
out_dir = os.environ.get("OUT_DIR")
|
||
idx = os.path.join(out_dir, "model.safetensors.index.json")
|
||
if os.path.exists(idx):
|
||
with open(idx) as f: j = json.load(f)
|
||
print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(j.get('weight_map', {}))})")
|
||
else:
|
||
print("WARN: 未找到 model.safetensors.index.json", file=sys.stderr)
|
||
try:
|
||
from transformers import AutoConfig
|
||
cfg = AutoConfig.from_pretrained(out_dir)
|
||
print("OK: 读取到 config:", cfg.model_type, "hidden:", getattr(cfg,'hidden_size',None), "layers:", getattr(cfg,'num_hidden_layers',None))
|
||
except Exception as e:
|
||
print("WARN: 读取 config 失败(若无 config.json 可忽略):", e, file=sys.stderr)
|
||
PY
|
||
|
||
echo "== 完成:${OUT_DIR} =="
|