jd_train/merge_zero3_safetensors.sh

92 lines
3.3 KiB
Bash
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env bash
set -euo pipefail
# ===== Tunable parameters =====
# Every scalar below can be overridden from the environment, for example
#   TAG=global_step124 CKPT_ROOT=/data/ckpt bash <this script>
# and otherwise keeps the value that used to be hard-coded here.

# Checkpoint root; the same path is used on the remote nodes and locally.
CKPT_ROOT="${CKPT_ROOT:-/home/test/checkpoints/q3-32b-ds4}"
# Name of the step directory under CKPT_ROOT that is collected and merged.
TAG="${TAG:-global_step62}"
# Nodes to pull shards from (arrays cannot be inherited from the environment,
# so edit this list in place).
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
# Destination of the merged safetensors checkpoint.
OUT_DIR="${OUT_DIR:-${CKPT_ROOT}/merged-${TAG}}"
# Upper bound per output shard, handed to the DeepSpeed conversion helper.
MAX_SHARD_SIZE="${MAX_SHARD_SIZE:-5GB}"
# The two option strings are expanded unquoted where they are used, so each
# whitespace-separated word becomes one command-line argument.
SSH_OPTS="${SSH_OPTS:--o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8}"
RSYNC_OPTS="${RSYNC_OPTS:--a --info=progress2 --human-readable --partial --inplace}"
# ==============================
echo "== 预检查 SSH 与远端目录 =="
# Pre-flight: every node must accept a non-interactive SSH login. A node
# that lacks the step directory is only reported, not treated as fatal.
for host in "${HOSTS[@]}"; do
  # SSH_OPTS is left unquoted on purpose so it splits into separate options.
  # shellcheck disable=SC2086
  ssh ${SSH_OPTS} "${host}" "true" >/dev/null 2>&1 || {
    echo "!! 无法免密 SSH 到 ${host}(检查 ~/.ssh/config/authorized_keys/防火墙)" >&2
    exit 1
  }
  # shellcheck disable=SC2086
  ssh ${SSH_OPTS} "${host}" "test -d '${CKPT_ROOT}/${TAG}'" ||
    echo "!! ${host} 上缺少目录 ${CKPT_ROOT}/${TAG},确认训练是否在该机产生了分片" >&2
done
echo "== 1/4 开始按节点同步分片(仅 ${TAG},带进度)=="
mkdir -p "${CKPT_ROOT}"
LOCAL_HOST="$(hostname -s || hostname)"
# Transfer only the mp_rank_* directories of this step and skip everything
# else under the checkpoint root.
# NOTE(review): stock DeepSpeed usually writes shard files (for example
# *_model_states.pt / *_optim_states.pt) directly under the tag directory;
# confirm that this run really produces mp_rank_*/ sub-directories, because
# otherwise these filters select nothing.
sync_filters=(
  --include="${TAG}/"
  --include="${TAG}/mp_rank_*/"
  --include="${TAG}/mp_rank_*/**"
  --exclude="*"
)
for h in "${HOSTS[@]}"; do
  if [[ "${h}" == "${LOCAL_HOST}" ]]; then
    echo " - 跳过本机 ${h}"
    continue
  fi
  echo " - 从 ${h} 拉取 ${CKPT_ROOT}/${TAG}/mp_rank_*/"
  # RSYNC_OPTS is left unquoted on purpose so it splits into separate options.
  # shellcheck disable=SC2086
  rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" "${sync_filters[@]}" \
    "${h}:${CKPT_ROOT}/" "${CKPT_ROOT}/"
done
echo "== 2/4 校验是否凑齐分片目录 =="
# The step directory must exist locally and contain at least one mp_rank_*
# sub-directory before the merge is attempted.
step_dir="${CKPT_ROOT}/${TAG}"
[[ -d "${step_dir}" ]] || {
  echo "!! 未发现 ${CKPT_ROOT}/${TAG}" >&2
  exit 1
}
MP_CNT=$(find "${step_dir}" -maxdepth 1 -type d -name "mp_rank_*" | wc -l)
# Some wc implementations pad the count with spaces; strip them.
MP_CNT="${MP_CNT// /}"
echo " - 已发现 mp_rank_* 目录数:${MP_CNT}"
if (( MP_CNT == 0 )); then
  echo "!! 没有任何 mp_rank_* 分片,请检查同步" >&2
  exit 1
fi
echo "== 3/4 合并为 safetensors 输出到:${OUT_DIR} =="
# Merge the ZeRO shards of ${TAG} under ${CKPT_ROOT} into sharded safetensors
# files in ${OUT_DIR} using DeepSpeed's conversion helper.
# The settings travel through the environment of this single command and the
# heredoc delimiter is quoted, so no shell value is spliced into the Python
# source (a quote or trailing backslash in a path can no longer break it).
CKPT_ROOT="${CKPT_ROOT}" TAG="${TAG}" OUT_DIR="${OUT_DIR}" \
MAX_SHARD_SIZE="${MAX_SHARD_SIZE}" python - <<'PY'
import os

from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

ckpt_dir = os.environ["CKPT_ROOT"]
out_dir = os.environ["OUT_DIR"]
tag = os.environ["TAG"]
max_shard_size = os.environ["MAX_SHARD_SIZE"]

os.makedirs(out_dir, exist_ok=True)
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir=ckpt_dir,
    output_dir=out_dir,
    tag=tag,
    safe_serialization=True,
    max_shard_size=max_shard_size,
)
print("合并完成:", out_dir)
PY
echo "== 3.1 拷贝 config/tokenizer 工件(如存在)=="

#######################################
# Copy the model config / tokenizer files that exist in one directory into
# another, never overwriting a file that is already present there.
# Arguments: $1 - directory to look in
#            $2 - existing directory to copy into
# Returns:   non-zero only if a copy itself fails
#######################################
copy_model_artifacts() {
  local src_dir=$1 dst_dir=$2 name
  for name in config.json generation_config.json tokenizer_config.json \
    tokenizer.json merges.txt vocab.json special_tokens_map.json \
    added_tokens.json; do
    # Explicit existence test instead of `cp -n`: some coreutils releases
    # make `cp -n` exit non-zero when it skips an existing file, which would
    # abort a re-run of the script under `set -e`.
    if [[ -f "${src_dir}/${name}" && ! -e "${dst_dir}/${name}" ]]; then
      cp -- "${src_dir}/${name}" "${dst_dir}/"
    fi
  done
}

copy_model_artifacts "${CKPT_ROOT}" "${OUT_DIR}"
echo "== 4/4 自检索引与config=="

#######################################
# Sanity-check a merged checkpoint directory: report the safetensors index
# (if present) and try to load the model config through transformers.
# Arguments: $1 - directory that holds the merged checkpoint
# Outputs:   "OK: ..." lines on stdout, "WARN: ..." lines on stderr
# Returns:   exit status of the Python interpreter
#######################################
self_check_merged_dir() {
  # The directory is passed in the environment of this one command. A plain
  # shell variable is not visible to the child process, so reading
  # os.environ without this would yield None and os.path.join would raise.
  OUT_DIR="$1" python - <<'PY'
import json
import os
import sys

out_dir = os.environ["OUT_DIR"]
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    with open(idx, encoding="utf-8") as f:
        j = json.load(f)
    n_params = len(j.get("weight_map", {}))
    print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {n_params})")
else:
    print("WARN: 未找到 model.safetensors.index.json", file=sys.stderr)
try:
    # Imported lazily so a missing transformers install only produces a warning.
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(out_dir)
    print("OK: 读取到 config", cfg.model_type, "hidden:", getattr(cfg, 'hidden_size', None), "layers:", getattr(cfg, 'num_hidden_layers', None))
except Exception as e:
    print("WARN: 读取 config 失败(若无 config.json 可忽略):", e, file=sys.stderr)
PY
}

self_check_merged_dir "${OUT_DIR}"
echo "== 完成:${OUT_DIR} =="