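# NOTE: the following lines are a code-host file-viewer header that was pasted
# together with the script. They are kept as comments so bash does not try to
# execute them. Because the shebang below is not on line 1, run the script
# explicitly with `bash merge_zero3_safetensors.sh`.
#
# jd_train/merge_zero3_safetensors.sh
#
# 95 lines
# 3.8 KiB
# Bash
# Executable File
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.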

#!/usr/bin/env bash
#
# Gather DeepSpeed ZeRO-3 checkpoint shards that are spread over several nodes
# and merge them into a safetensors checkpoint:
#   pull ${TAG} from every host -> count shard files -> convert with
#   zero_to_fp32 -> copy config/tokenizer files -> sanity-check the output.
set -euo pipefail

# ------------------------------ tunables ------------------------------------
# Training output root. If the shards actually live under
# checkpoint-62/global_step62, point this at .../checkpoint-62 instead.
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4"
# Checkpoint tag: name of the shard subdirectory beneath CKPT_ROOT.
TAG="global_step62"
# Hosts whose copy of ${CKPT_ROOT}/${TAG} gets pulled onto this machine.
HOSTS=(
  tn01 tn02 tn03
  tn04 tn05 tn06
)
# Where the merged weights are written.
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
# Size cap for each output safetensors file.
MAX_SHARD_SIZE="5GB"
# Both option strings are expanded unquoted where they are used, so they
# word-split into separate arguments; no single option may contain a space.
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
# ---------------------------------------------------------------------------
echo "== 预检查 SSH 与(非必需)远端目录存在 =="
# Verify passwordless (BatchMode) SSH to every host before any copying starts.
# Remote directory existence is deliberately not checked here: a missing
# directory is not fatal and the sync step skips such hosts.
# Every unreachable host is reported (on stderr) before exiting, not only the
# first one, so all of them can be fixed in one go.
SSH_FAILED=()
# SSH_OPTS is a flat option string and is word-split on purpose.
# shellcheck disable=SC2086
for h in "${HOSTS[@]}"; do
  # -n: keep ssh from reading this script's stdin.
  if ! ssh -n ${SSH_OPTS} "$h" "true" >/dev/null; then
    echo "!! 无法免密 SSH 到 $h" >&2
    SSH_FAILED+=("$h")
  fi
done
if (( ${#SSH_FAILED[@]} > 0 )); then
  exit 1
fi
echo "== 1/4 同步各节点的 ${TAG} 整个目录(带进度)=="
# Pull every remote host's ${CKPT_ROOT}/${TAG}/ into the same local directory
# so the shards spread across nodes end up side by side for the merge.
# Best-effort by design: a host that fails is reported on stderr and the loop
# moves on. Failures are no longer swallowed silently, because with
# --partial --inplace an interrupted rsync leaves a truncated shard behind
# that would still look like a valid file to later steps.
mkdir -p "${CKPT_ROOT}/${TAG}"
LOCAL_HOST="$(hostname -s || hostname)"
SYNC_FAILED=()
# SSH_OPTS / RSYNC_OPTS are flat option strings and are word-split on purpose.
# shellcheck disable=SC2086
for h in "${HOSTS[@]}"; do
  if [[ "$h" == "$LOCAL_HOST" ]]; then
    echo " - 跳过本机 $h"
    continue
  fi
  # Separate "directory missing" (test -d exits 1) from "could not ask"
  # (ssh reports its own failures with 255) instead of skipping both silently.
  # NOTE: the remote path is single-quoted, so it must not contain a quote.
  rc=0
  ssh -n ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'" || rc=$?
  case "$rc" in
    0)
      echo " - 从 $h 拉取 ${CKPT_ROOT}/${TAG}/"
      # No include/exclude filters, so shard files of any naming style come along.
      if ! rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
        "${h}:${CKPT_ROOT}/${TAG}/" "${CKPT_ROOT}/${TAG}/"; then
        echo "!! 从 $h 同步失败,已跳过该节点(可能留下不完整文件)" >&2
        SYNC_FAILED+=("$h")
      fi
      ;;
    1)
      echo " - $h 上不存在 ${CKPT_ROOT}/${TAG},跳过"
      ;;
    *)
      echo "!! 检查 $h 上的 ${CKPT_ROOT}/${TAG} 失败(ssh 退出码 ${rc}),跳过" >&2
      SYNC_FAILED+=("$h")
      ;;
  esac
done
if (( ${#SYNC_FAILED[@]} > 0 )); then
  echo "!! 以下节点未能完整同步:${SYNC_FAILED[*]}" >&2
fi
echo "== 2/4 校验是否有分片“文件”(不是目录)=="
# Print how many of the given paths are regular files. Callers pass a glob:
# when it matches nothing bash leaves the literal pattern, which fails
# [[ -f ]] and is not counted; matching directories are skipped the same way.
count_regular_files() {
  local n=0 p
  for p in "$@"; do
    if [[ -f "$p" ]]; then
      n=$(( n + 1 ))
    fi
  done
  printf '%d\n' "$n"
}
# Shard naming styles: mp_rank_*_model_states.pt and *mp_rank*model_states.pt
# (the latter with an extra pp/zero rank prefix). The second glob already
# matches every name the first one does, so a single glob counts each file
# exactly once; summing the two globs would count plain mp_rank_* files twice.
# Globbing directly also avoids parsing `ls`, which would list the contents of
# any directory whose name matched.
CNT="$(count_regular_files "${CKPT_ROOT}/${TAG}"/*mp_rank*model_states.pt)"
echo " - 发现 model_states 分片文件数:${CNT}"
if (( CNT == 0 )); then
  echo "!! 未检测到任何 *_model_states.pt请在各机上 ls 看看 ${CKPT_ROOT}/${TAG} 的实际文件名,再调整匹配规则" >&2
  exit 1
fi
# DeepSpeed's ZeRO-2/3 fp32 reconstruction also reads the *_optim_states.pt
# partitions; warn early (non-fatal) if none were gathered.
OPT_CNT="$(count_regular_files "${CKPT_ROOT}/${TAG}"/*optim_states.pt)"
echo " - 发现 optim_states 分片文件数:${OPT_CNT}"
if (( OPT_CNT == 0 )); then
  echo "!! 未检测到任何 *_optim_states.pt,ZeRO 合并通常需要这些文件" >&2
fi
echo "== 3/4 合并为 safetensors 输出到:${OUT_DIR} =="
mkdir -p "${OUT_DIR}"
if [[ -f "${CKPT_ROOT}/zero_to_fp32.py" ]]; then
  # Prefer the official script saved with this checkpoint: it comes from the
  # same DeepSpeed version that wrote the shards.
  python "${CKPT_ROOT}/zero_to_fp32.py" \
    "${CKPT_ROOT}" \
    "${OUT_DIR}" \
    --tag "${TAG}" \
    --safe_serialization \
    --max_shard_size "${MAX_SHARD_SIZE}"
else
  # Fall back to the installed DeepSpeed package's API.
  # The heredoc delimiter is quoted, so no shell text is spliced into the Python
  # source. Values arrive via the environment instead, so paths containing
  # quotes or ending in a backslash cannot break the generated code.
  MERGE_CKPT_ROOT="${CKPT_ROOT}" MERGE_OUT_DIR="${OUT_DIR}" \
    MERGE_TAG="${TAG}" MERGE_MAX_SHARD_SIZE="${MAX_SHARD_SIZE}" \
    python - <<'PY'
import os
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

out_dir = os.environ["MERGE_OUT_DIR"]
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir=os.environ["MERGE_CKPT_ROOT"],
    output_dir=out_dir,
    tag=os.environ["MERGE_TAG"],
    safe_serialization=True,
    max_shard_size=os.environ["MERGE_MAX_SHARD_SIZE"],
)
print("合并完成:", out_dir)
PY
fi
echo "== 3.1 拷贝 config/tokenizer 工件(如存在)=="
# Copy Hugging Face config/tokenizer files from CKPT_ROOT into OUT_DIR so the
# merged directory is self-contained. Files missing in CKPT_ROOT are skipped,
# and files that already exist in OUT_DIR are never overwritten.
# The "don't overwrite" rule is an explicit test rather than `cp -n`: the exit
# status of `cp -n` when it skips an existing file has differed across GNU
# coreutils releases. As the command after `&&` it was subject to `set -e`,
# so a re-run could abort here.
for f in \
  config.json generation_config.json \
  tokenizer_config.json tokenizer.json tokenizer.model \
  merges.txt vocab.json special_tokens_map.json added_tokens.json \
  chat_template.jinja chat_template.json; do
  src="${CKPT_ROOT}/${f}"
  dst="${OUT_DIR}/${f}"
  if [[ -f "$src" && ! -e "$dst" ]]; then
    cp -- "$src" "$dst"
  fi
done
echo "== 4/4 自检索引与config=="
# Sanity-check the merged output:
#   * report the safetensors index and verify that every shard file it names
#     exists, or report a single-file model.safetensors;
#   * try to load the copied config with transformers.
# OUT_DIR is assigned but never exported by this script, so it is passed to
# Python explicitly. Without this, os.environ has no OUT_DIR and
# os.path.join(None, ...) raises TypeError.
# The quoted heredoc keeps the Python text verbatim; its indentation matters.
OUT_DIR="${OUT_DIR}" python - <<'PY'
import json
import os
import sys

out_dir = os.environ["OUT_DIR"]
idx = os.path.join(out_dir, "model.safetensors.index.json")
single = os.path.join(out_dir, "model.safetensors")
if os.path.exists(idx):
    with open(idx) as f:
        j = json.load(f)
    weight_map = j.get("weight_map", {})
    print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(weight_map)})")
    # Every shard file referenced by the index must sit next to it.
    missing = sorted(set(weight_map.values()) - set(os.listdir(out_dir)))
    if missing:
        print("WARN: 索引引用的分片文件缺失:", missing, file=sys.stderr)
elif os.path.exists(single):
    # When the weights fit in one shard, typically only model.safetensors is
    # written and no index file exists.
    print(f"OK: 找到单文件 safetensors:{single}")
else:
    print("WARN: 未找到 model.safetensors.index.json", file=sys.stderr)
try:
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained(out_dir)
    print("OK: 读取到 config", cfg.model_type, "hidden:", getattr(cfg, 'hidden_size', None), "layers:", getattr(cfg, 'num_hidden_layers', None))
except Exception as e:
    print("WARN: 读取 config 失败(若无 config.json 可忽略):", e, file=sys.stderr)
PY
echo "== 完成:${OUT_DIR} =="