This commit is contained in:
parent
daf614ede4
commit
f115389dd0
|
|
@@ -1,87 +1,91 @@
|
||||||
#!/usr/bin/env bash
# merge_zero3_safetensors.sh
#
# Pulls the checkpoint shards of one training step from every node onto this
# machine, merges them into safetensors files with DeepSpeed's zero_to_fp32
# helper, copies config/tokenizer artifacts next to the weights and runs a
# quick sanity check on the output directory.
set -euo pipefail

# ===== Tunable parameters =====
# Checkpoint root directory; the same absolute path is used locally and on
# every remote node.
CKPT_ROOT="/home/test/checkpoints/q3-32b-ds4"
# Tag to merge (name of the sub-directory under CKPT_ROOT).
TAG="global_step62"
# Nodes that took part in training.
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
# Output directory for the merged model.
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
# Maximum size of each safetensors shard.
MAX_SHARD_SIZE="5GB"
# Options applied to every ssh invocation. Kept as one whitespace-separated
# string because it is also embedded verbatim in rsync's -e "ssh ..." argument.
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
# Options applied to every rsync invocation (whitespace-separated string).
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
# ==============================

echo "== 预检查 SSH 与远端目录 =="

#######################################
# Pre-flight check for one node: verifies key-based SSH login works and
# reports whether the node holds the checkpoint directory for TAG.
# Globals:   SSH_OPTS, CKPT_ROOT, TAG (read)
# Arguments: $1 - host name
# Outputs:   diagnostics on stderr
# Returns:   1 when the SSH login fails; 0 otherwise (a missing remote
#            directory only produces a warning)
#######################################
precheck_host() {
  local host=$1
  local -a ssh_args=()
  local remote_dir_q

  # SSH_OPTS is a whitespace-separated option string; split it once into an
  # array so each option is passed as its own word without relying on an
  # unquoted expansion (which would also be subject to globbing).
  read -r -a ssh_args <<<"${SSH_OPTS}"

  # The command string is re-parsed by the remote shell, so shell-quote the
  # path instead of wrapping it in hand-written single quotes.
  printf -v remote_dir_q '%q' "${CKPT_ROOT}/${TAG}"

  if ! ssh "${ssh_args[@]}" "${host}" true >/dev/null 2>&1; then
    echo "!! 无法免密 SSH 到 ${host}(检查 ~/.ssh/config/authorized_keys/防火墙)" >&2
    return 1
  fi

  # Deliberately non-fatal: the message asks the operator to confirm whether
  # this node produced any shards at all.
  if ! ssh "${ssh_args[@]}" "${host}" "test -d ${remote_dir_q}"; then
    echo "!! ${host} 上缺少目录 ${CKPT_ROOT}/${TAG},确认训练是否在该机产生了分片" >&2
  fi
  return 0
}

for h in "${HOSTS[@]}"; do
  precheck_host "${h}" || exit 1
done

echo "== 1/4 开始按节点同步分片(仅 ${TAG},带进度)=="
mkdir -p "${CKPT_ROOT}"
# Short host name of this machine; falls back to plain `hostname` when the
# -s form fails.
LOCAL_HOST="$(hostname -s || hostname)"

#######################################
# Pulls the shards of TAG from one remote node into the local CKPT_ROOT.
# Globals:   RSYNC_OPTS, SSH_OPTS, CKPT_ROOT, TAG (read)
# Arguments: $1 - remote host name
# Returns:   rsync's exit status
#######################################
sync_from_host() {
  local host=$1
  local -a rsync_args=()

  # RSYNC_OPTS is a whitespace-separated option string; split it into an
  # array rather than expanding it unquoted.
  read -r -a rsync_args <<<"${RSYNC_OPTS}"

  # Filter rules: descend into the tag directory only, take the mp_rank_*
  # directories with all of their content, and drop everything else.
  # NOTE(review): DeepSpeed's zero_to_fp32 globs *_optim_states.pt and
  # *_model_states.pt directly inside the tag directory, so those files are
  # pulled too; confirm against the layout your training run actually wrote.
  rsync "${rsync_args[@]}" -e "ssh ${SSH_OPTS}" \
    --include="${TAG}/" \
    --include="${TAG}/mp_rank_*/" \
    --include="${TAG}/mp_rank_*/**" \
    --include="${TAG}/*_states.pt" \
    --exclude="*" \
    "${host}:${CKPT_ROOT}/" "${CKPT_ROOT}/"
}

for h in "${HOSTS[@]}"; do
  if [[ "${h}" == "${LOCAL_HOST}" ]]; then
    echo " - 跳过本机 ${h}"
    continue
  fi
  echo " - 从 ${h} 拉取 ${CKPT_ROOT}/${TAG}/mp_rank_*/"
  sync_from_host "${h}"
done

echo "== 2/4 校验是否凑齐分片目录 =="
# The tag directory must exist locally after the sync step.
if [[ ! -d "${CKPT_ROOT}/${TAG}" ]]; then
  echo "!! 未发现 ${CKPT_ROOT}/${TAG}" >&2
  exit 1
fi

# Number of mp_rank_* directories directly under the tag directory.
MP_CNT=$(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type d -name "mp_rank_*" | wc -l | tr -d ' ')
echo " - 已发现 mp_rank_* 目录数:${MP_CNT}"

# NOTE(review): DeepSpeed's zero_to_fp32 reads *_optim_states.pt files that
# sit directly in the tag directory, so a checkpoint in that layout has no
# mp_rank_* directories at all. Count those files as well and only give up
# when neither kind of shard is present; confirm which layout applies here.
OPTIM_CNT=$(find "${CKPT_ROOT}/${TAG}" -maxdepth 1 -type f -name "*_optim_states.pt" | wc -l | tr -d ' ')
echo " - 已发现 *_optim_states.pt 分片文件数:${OPTIM_CNT}"

if (( MP_CNT == 0 && OPTIM_CNT == 0 )); then
  echo "!! 没有任何 mp_rank_* 分片,请检查同步" >&2
  exit 1
fi

echo "== 3/4 合并为 safetensors 输出到:${OUT_DIR} =="

#######################################
# Merges the ZeRO shards of TAG into safetensors files under OUT_DIR.
# Globals:   CKPT_ROOT, OUT_DIR, TAG, MAX_SHARD_SIZE (read)
# Returns:   exit status of the Python interpreter
#######################################
merge_zero3_checkpoint() {
  # Values are handed to Python through the environment and the heredoc is
  # quoted, so paths containing quotes or backslashes cannot corrupt (or
  # inject into) the Python source.
  CKPT_ROOT="${CKPT_ROOT}" OUT_DIR="${OUT_DIR}" TAG="${TAG}" \
    MAX_SHARD_SIZE="${MAX_SHARD_SIZE}" python - <<'PY'
import os
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

ckpt_dir = os.environ["CKPT_ROOT"]
out_dir = os.environ["OUT_DIR"]
tag = os.environ["TAG"]

os.makedirs(out_dir, exist_ok=True)
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir=ckpt_dir,
    output_dir=out_dir,
    tag=tag,  # given explicitly rather than left as None
    safe_serialization=True,  # write safetensors files
    max_shard_size=os.environ["MAX_SHARD_SIZE"],
)
print("合并完成:", out_dir)
PY
}

merge_zero3_checkpoint

echo "== 3.1 拷贝 config/tokenizer 工件(如存在)=="

#######################################
# Copies the config/tokenizer files that exist in the source directory into
# the destination directory, never overwriting a file that is already there.
# Arguments: $1 - source directory
#            $2 - destination directory
# Returns:   0 unless a copy fails
#######################################
copy_model_artifacts() {
  local src_dir=$1
  local dst_dir=$2
  local name

  for name in config.json generation_config.json tokenizer_config.json \
    tokenizer.json merges.txt vocab.json special_tokens_map.json \
    added_tokens.json; do
    # Explicit existence test instead of `cp -n`: the exit status of
    # `cp -n` for an already-present destination differs between coreutils
    # releases, which could abort a rerun under `set -e`.
    if [[ -f "${src_dir}/${name}" && ! -e "${dst_dir}/${name}" && ! -L "${dst_dir}/${name}" ]]; then
      cp -- "${src_dir}/${name}" "${dst_dir}/"
    fi
  done
  return 0
}

copy_model_artifacts "${CKPT_ROOT}" "${OUT_DIR}"

echo "== 4/4 自检(索引与config)=="

#######################################
# Sanity-checks a merged model directory: reports whether the safetensors
# index exists and whether transformers can load the config. Problems are
# printed as warnings on stderr only.
# Arguments: $1 - merged output directory
# Returns:   exit status of the Python interpreter
#######################################
self_check_output() {
  # The heredoc is quoted, so Python cannot see shell variables; the
  # directory has to be passed through the environment of this command.
  # Without it os.environ.get("OUT_DIR") is None and os.path.join raises.
  OUT_DIR="$1" python - <<'PY'
import os, json, sys

out_dir = os.environ["OUT_DIR"]

# Check that the safetensors index is present and count its entries.
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    with open(idx) as f:
        j = json.load(f)
    print(f"OK: 找到 safetensors 索引:{idx}(参数条目 {len(j.get('weight_map', {}))})")
else:
    print("WARN: 未找到 model.safetensors.index.json", file=sys.stderr)

# Loading the config is best-effort: it may be absent, and transformers may
# not be installed.
try:
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained(out_dir)
    print("OK: 读取到 config:", cfg.model_type, "hidden:", getattr(cfg,'hidden_size',None), "layers:", getattr(cfg,'num_hidden_layers',None))
except Exception as e:
    print("WARN: 读取 config 失败(若无 config.json 可忽略):", e, file=sys.stderr)
PY
}

self_check_output "${OUT_DIR}"

echo "== 完成:${OUT_DIR} =="
Loading…
Reference in New Issue