jd_train/merge_zero3_safetensors.sh

#!/usr/bin/env bash
set -euo pipefail
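#
# Merges a DeepSpeed ZeRO-3 checkpoint whose shards are spread across several hosts into a
# single BF16 safetensors model:
#   0) precheck SSH access and per-host shard counts
#   1-3) rsync every host's ${TAG} shard directory into a local staging area and verify the total
#   4) convert the staged ZeRO checkpoint into a temporary FP32 PyTorch checkpoint (zero_to_fp32)
#   5) reload with transformers, cast to BF16, untie lm_head/embed_tokens, save sharded safetensors
#   6-7) sanity-check the output and print cleanup hints
#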
# ===== Tunable parameters =====
CKPT_ROOT="/home/test/checkpoints/q3-32b-lora" # if the checkpoint actually lives at .../checkpoint-62/global_step62, set CKPT_ROOT to .../checkpoint-62
TAG="global_step110"
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
AGGREGATOR_HOST="tn06" # the host this script runs on / where shards are aggregated
EXPECTED_SHARDS_PER_HOST=4 # shards each host should have written (per your parallel layout)
MAX_SHARD_SIZE="5GB"
STRICT_PRECHECK=true # true: abort when the precheck fails; false: warn only
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
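# BatchMode=yes makes ssh fail instead of prompting for a password, accept-new auto-accepts
# previously unseen host keys, and rsync's --partial/--inplace let interrupted transfers resume.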
# ====================
# ===== Derived parameters (normally no need to change) =====
EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} ))
STAGING_BASE="${CKPT_ROOT}/_staging"
STAGING_TAG_DIR="${STAGING_BASE}/${TAG}"
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}" # temporary FP32 (pytorch_model.bin) directory
export OUT_DIR TMP_PT_DIR MAX_SHARD_SIZE
# =================================
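# With the usual ZeRO-3 layout, every data-parallel rank writes one *_model_states.pt (plus a
# matching *_optim_states.pt), so EXPECTED_SHARDS_PER_HOST is typically the number of ranks (GPUs) per host.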
echo "== SSH precheck =="
for h in "${HOSTS[@]}"; do
  ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! Passwordless SSH to $h failed"; exit 1; }
done
echo "== 0/7 Per-node shard precheck (counting *model_states.pt under ${CKPT_ROOT}/${TAG} on each host) =="
remote_total=0
agg_cnt=0
for h in "${HOSTS[@]}"; do
  c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0)
  c=$(echo "$c" | tr -d ' ')
  printf " - %-8s: %s shard(s)\n" "$h" "$c"
  if [[ "$h" == "$AGGREGATOR_HOST" ]]; then
    agg_cnt=$c
  else
    remote_total=$(( remote_total + c ))
    # simplest per-host sanity check: each host should have at least EXPECTED_SHARDS_PER_HOST shards
    if (( c < EXPECTED_SHARDS_PER_HOST )); then
      echo "!! Warning: $h has only $c shard(s) (expected ${EXPECTED_SHARDS_PER_HOST})" >&2
    fi
  fi
done
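# The aggregator's own shards are already on local disk, so they are tracked separately from
# the remote total that still has to be pulled over the network in step 2.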
expected_remote_total=$(( EXPECTED_TOTAL_SHARDS - EXPECTED_SHARDS_PER_HOST ))
echo " - Remote total (excluding ${AGGREGATOR_HOST}): $remote_total (expected ${expected_remote_total})"
echo " - ${AGGREGATOR_HOST} itself: $agg_cnt (expected ${EXPECTED_SHARDS_PER_HOST})"
precheck_ok=true
if (( remote_total != expected_remote_total )); then
  echo "!! Remote shard total mismatch: actual ${remote_total} / expected ${expected_remote_total}" >&2
  precheck_ok=false
fi
if (( agg_cnt < EXPECTED_SHARDS_PER_HOST )); then
  echo "!! Not enough local shards on ${AGGREGATOR_HOST}: actual ${agg_cnt} / expected ${EXPECTED_SHARDS_PER_HOST}" >&2
  precheck_ok=false
fi
if [[ "${STRICT_PRECHECK}" == "true" && "${precheck_ok}" == "false" ]]; then
  echo "!! STRICT_PRECHECK is enabled: precheck failed, aborting" >&2
  exit 2
fi
[[ "${precheck_ok}" == "true" ]] && echo "OK: precheck passed (remote=${remote_total}, local=${agg_cnt}, expected total=${EXPECTED_TOTAL_SHARDS})" || echo "WARN: precheck failed (shard counts do not match expectations); lenient mode (STRICT_PRECHECK=false) is in effect, continuing..."
echo "== 1/7 Preparing a clean staging directory =="
rm -rf "${STAGING_TAG_DIR}"
mkdir -p "${STAGING_TAG_DIR}"
echo "== 2/7 Collecting shards into staging =="
for h in "${HOSTS[@]}"; do
  if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
    echo " - Collecting ${h}:${CKPT_ROOT}/${TAG}/ -> ${STAGING_TAG_DIR}/"
    rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
      "${h}:${CKPT_ROOT}/${TAG}/" "${STAGING_TAG_DIR}/" || true
  else
    echo " - ${h} has no ${CKPT_ROOT}/${TAG}, skipping"
  fi
done
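# Note: the rsync above copies everything under each host's ${TAG} directory (optimizer state
# files included), not only the *model_states.pt counted by the precheck; the FP32 merge in
# step 4 reads more than the model_states files, so do not filter the transfer.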
echo "== 3/7 Verifying total shard count in staging (should be ${EXPECTED_TOTAL_SHARDS}) =="
mapfile -t SHARDS < <(find "${STAGING_TAG_DIR}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
CNT=${#SHARDS[@]}
echo " - Shards found in staging: ${CNT}"
if (( CNT != EXPECTED_TOTAL_SHARDS )); then
  echo "!! Total shard count mismatch: ${CNT} in staging / expected ${EXPECTED_TOTAL_SHARDS}. Check for missing shards or inconsistent naming." >&2
  exit 3
fi
echo "== 4/7 Merging shards -> temporary FP32 PyTorch .bin (avoids the safetensors error caused by shared weights) =="
rm -rf "${TMP_PT_DIR}"
mkdir -p "${TMP_PT_DIR}"
# Call the API directly; safe_serialization=False -> produces pytorch_model.bin (FP32)
python - <<PY
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir=r"${STAGING_BASE}",
    output_dir=r"${TMP_PT_DIR}",
    tag=r"${TAG}",
    safe_serialization=False,  # key point: write .bin (FP32) first to sidestep the safetensors restriction on shared weights
)
print("Merge finished (FP32 .bin):", r"${TMP_PT_DIR}")
PY
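# At this point TMP_PT_DIR should hold a full-precision PyTorch state dict (pytorch_model.bin,
# possibly split into several .bin files depending on the DeepSpeed version in use).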
echo "== 4.1/7 Copying config/tokenizer artifacts into the temporary FP32 directory (needed for loading) =="
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
  [[ -f "${CKPT_ROOT}/${f}" ]] && cp -n "${CKPT_ROOT}/${f}" "${TMP_PT_DIR}/" || true
done
echo "== 5/7 FP32 -> BF16, untie lm_head <-> embed_tokens shared storage, save as sharded safetensors (${MAX_SHARD_SIZE}) =="
python - <<'PY'
import os, sys, torch
from transformers import AutoConfig, AutoModelForCausalLM
TMP_PT_DIR = os.environ["TMP_PT_DIR"]
OUT_DIR = os.environ["OUT_DIR"]
MAX_SHARD_SIZE = os.environ.get("MAX_SHARD_SIZE", "5GB")
print("[load] from:", TMP_PT_DIR)
cfg = AutoConfig.from_pretrained(TMP_PT_DIR, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    TMP_PT_DIR,
    config=cfg,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # target dtype: BF16
    low_cpu_mem_usage=True,
    device_map={"": "cpu"},      # load entirely on CPU to avoid using GPU memory
)
# If lm_head and embed_tokens share weights, untie them manually so the safetensors export
# does not complain about shared storage.
try:
    emb = model.get_input_embeddings().weight if hasattr(model, "get_input_embeddings") else None
    head = model.lm_head.weight if hasattr(model, "lm_head") else None
    if emb is not None and head is not None and emb.data_ptr() == head.data_ptr():
        with torch.no_grad():
            model.lm_head.weight = torch.nn.Parameter(head.detach().clone())
        print("[fix] Untied shared weights: lm_head.weight cloned from embed_tokens.weight")
    else:
        print("[fix] No shared storage detected between lm_head and embed_tokens")
except Exception as e:
    print("[fix] Skip untie check:", e, file=sys.stderr)
# Make sure the whole model really is BF16
model.to(dtype=torch.bfloat16)
# Sharded safetensors (works for large models)
os.makedirs(OUT_DIR, exist_ok=True)
model.save_pretrained(
    OUT_DIR,
    safe_serialization=True,        # write safetensors
    max_shard_size=MAX_SHARD_SIZE,  # per-shard size limit
)
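# When the output exceeds MAX_SHARD_SIZE, save_pretrained splits the weights into several
# .safetensors files and writes model.safetensors.index.json; the self-check in step 6 looks for that index.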
print("[save] BF16 safetensors saved to:", OUT_DIR)
PY
echo "== 5.1/7 Copying (backfilling) tokenizer artifacts into the final directory (if present) =="
for f in tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
  [[ -f "${CKPT_ROOT}/${f}" ]] && cp -n "${CKPT_ROOT}/${f}" "${OUT_DIR}/" || true
done
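# With config.json plus the tokenizer files sitting next to the weights, OUT_DIR can be loaded
# directly with from_pretrained, without touching the original checkpoint tree.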
echo "== 6/7 Self-check (index and config) =="
python - <<'PY'
import os, json, sys
out_dir = os.environ.get("OUT_DIR")
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    with open(idx) as f:
        j = json.load(f)
    print(f"OK: found safetensors index: {idx} ({len(j.get('weight_map', {}))} weight entries)")
else:
    # a single-shard model may not have an index.json
    sfts = [x for x in os.listdir(out_dir) if x.endswith(".safetensors")]
    if len(sfts) == 1:
        print(f"NOTE: single safetensors shard: {sfts[0]}")
    else:
        print("WARN: model.safetensors.index.json not found and shard count != 1", file=sys.stderr)
try:
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained(out_dir, trust_remote_code=True)
    print("OK: config loaded:", cfg.model_type, "hidden:", getattr(cfg, 'hidden_size', None), "layers:", getattr(cfg, 'num_hidden_layers', None))
except Exception as e:
    print("WARN: failed to load config (can be ignored if there is no config.json):", e, file=sys.stderr)
PY
echo "== 7/7 Cleanup hints =="
echo "Temporary FP32 directory: ${TMP_PT_DIR}"
echo "BF16 safetensors output: ${OUT_DIR}"
echo "Done."
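# A minimal smoke test of the merged output (a sketch; assumes the transformers/torch stack used
# above is available here, and OUT_DIR stands for the merged-${TAG} path printed above):
#   python -c "import torch; from transformers import AutoModelForCausalLM, AutoTokenizer; \
#     tok = AutoTokenizer.from_pretrained('OUT_DIR', trust_remote_code=True); \
#     m = AutoModelForCausalLM.from_pretrained('OUT_DIR', torch_dtype=torch.bfloat16, trust_remote_code=True); \
#     print(m.config.model_type, next(m.parameters()).dtype)"
# The temporary FP32 directory can be removed once the merged model loads cleanly.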