#!/usr/bin/env bash
set -euo pipefail
# ===== Tunable parameters =====
CKPT_ROOT="/home/test/checkpoints/q3-32b-lora" # if the checkpoint actually lives at .../checkpoint-62/global_step62, point CKPT_ROOT at .../checkpoint-62
TAG="global_step30"
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
AGGREGATOR_HOST="tn06" # the machine this script runs on / where shards are aggregated
EXPECTED_SHARDS_PER_HOST=4 # number of shards each host should have written (per your parallel layout)
MAX_SHARD_SIZE="5GB"
# ★★★ New: reference model directory (the Qwen3-32B base, or its Instruct variant, used for the LoRA run) ★★★
REF_MODEL_DIR="/home/test/Qwen3-32B"
STRICT_PRECHECK=true # true: abort if the precheck fails; false: warn only
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
# ====================
# ===== Derived parameters (normally no need to change) =====
EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} ))
STAGING_BASE="${CKPT_ROOT}/_staging"
STAGING_TAG_DIR="${STAGING_BASE}/${TAG}"
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}" # temporary FP32 (pytorch_model.bin) directory
export OUT_DIR TMP_PT_DIR MAX_SHARD_SIZE REF_MODEL_DIR
# =================================
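# Layout this script assumes (derived from the find/rsync calls below): every host in
# HOSTS keeps its own ZeRO-3 shards as ${CKPT_ROOT}/${TAG}/*model_states.pt, and the
# aggregator pulls all of them into ${STAGING_TAG_DIR} before the FP32 conversion.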
echo "== 预检查 SSH =="
for h in "${HOSTS[@]}"; do
ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! 无法免密 SSH 到 $h"; exit 1; }
done
echo "== 0/7 逐节点分片预检(统计各机 ${CKPT_ROOT}/${TAG} 下的 *model_states.pt=="
remote_total=0
agg_cnt=0
for h in "${HOSTS[@]}"; do
  c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0)
  c=$(echo "$c" | tr -d ' ')
  printf " - %-8s: %s shard(s)\n" "$h" "$c"
  if [[ "$h" == "$AGGREGATOR_HOST" ]]; then
    agg_cnt=$c
  else
    remote_total=$(( remote_total + c ))
    if (( c < EXPECTED_SHARDS_PER_HOST )); then
      echo "!! Warning: $h has only $c shard(s) (expected ${EXPECTED_SHARDS_PER_HOST})" >&2
    fi
  fi
done
expected_remote_total=$(( EXPECTED_TOTAL_SHARDS - EXPECTED_SHARDS_PER_HOST ))
echo " - 远端合计(不含 ${AGGREGATOR_HOST}$remote_total(期望 ${expected_remote_total}"
echo " - ${AGGREGATOR_HOST} 自身:$agg_cnt(期望 ${EXPECTED_SHARDS_PER_HOST}"
precheck_ok=true
if (( remote_total != expected_remote_total )); then
  echo "!! Remote shard total mismatch: actual ${remote_total} / expected ${expected_remote_total}" >&2
  precheck_ok=false
fi
if (( agg_cnt < EXPECTED_SHARDS_PER_HOST )); then
  echo "!! Too few local shards on ${AGGREGATOR_HOST}: actual ${agg_cnt} / expected ${EXPECTED_SHARDS_PER_HOST}" >&2
  precheck_ok=false
fi
if [[ "${STRICT_PRECHECK}" == "true" && "${precheck_ok}" == "false" ]]; then
echo "!! STRICT_PRECHECK 开启:预检不通过,停止执行" >&2
exit 2
fi
[[ "${precheck_ok}" == "true" ]] && echo "OK: 预检通过(远端=${remote_total}、本机=${agg_cnt},总计期望=${EXPECTED_TOTAL_SHARDS}" || echo "WARN: 预检未通过(分片数量与期望不符),已启用宽松模式,继续执行..."
echo "== 1/7 准备 staging 目录(干净环境)=="
rm -rf "${STAGING_TAG_DIR}"
mkdir -p "${STAGING_TAG_DIR}"
echo "== 2/7 收集分片到 staging =="
for h in "${HOSTS[@]}"; do
  if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
    echo " - collecting ${h}:${CKPT_ROOT}/${TAG}/ -> ${STAGING_TAG_DIR}/"
    rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
      "${h}:${CKPT_ROOT}/${TAG}/" "${STAGING_TAG_DIR}/" || true
  else
    echo " - ${h} has no ${CKPT_ROOT}/${TAG}, skipping"
  fi
done
echo "== 3/7 在 staging 校验总分片数(应为 ${EXPECTED_TOTAL_SHARDS}=="
mapfile -t SHARDS < <(find "${STAGING_TAG_DIR}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
CNT=${#SHARDS[@]}
echo " - staging 中发现分片数:${CNT}"
if (( CNT != EXPECTED_TOTAL_SHARDS )); then
  echo "!! Shard count mismatch: ${CNT} in staging / expected ${EXPECTED_TOTAL_SHARDS}. Check for missing shards or inconsistent naming." >&2
  exit 3
fi
echo "== 4/7 合并分片 -> 临时 FP32PyTorch .bin避免共享权重导致 safetensors 报错 =="
rm -rf "${TMP_PT_DIR}"
mkdir -p "${TMP_PT_DIR}"
python - <<PY
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir=r"${STAGING_BASE}",
    output_dir=r"${TMP_PT_DIR}",
    tag=r"${TAG}",
    safe_serialization=False,  # write a plain .bin first (FP32)
)
print("Merge finished (FP32 .bin):", r"${TMP_PT_DIR}")
PY
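# Optional sanity check (a minimal sketch): list what the converter actually wrote before
# loading it. Note the next step assumes a single consolidated pytorch_model.bin; depending
# on the DeepSpeed version and model size, the converter may shard the output instead.
ls -lh "${TMP_PT_DIR}" || true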
# ★★★ Change: copy config/tokenizer from the reference model, force-overwriting; do NOT copy from CKPT_ROOT (a LoRA checkpoint directory usually has no config.json) ★★★
echo "== 4.1/7 Copy config/tokenizer from the reference model into the temporary FP32 directory (required for loading) =="
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
  [[ -f "${REF_MODEL_DIR}/${f}" ]] && cp -f "${REF_MODEL_DIR}/${f}" "${TMP_PT_DIR}/" || true
done
echo "== 5/7 装载 REF 模型结构 + 灌入 FP32 权重;如检测到 LoRA则导出 adapter否则保存 BF16 分片 safetensors =="
python - <<'PY'
import os, re, json, sys, torch, shutil
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import save_file
TMP_PT_DIR = os.environ["TMP_PT_DIR"]
REF_DIR = os.environ["REF_MODEL_DIR"]
OUT_DIR = os.environ["OUT_DIR"]
MAX_SHARD_SIZE = os.environ.get("MAX_SHARD_SIZE","5GB")
print("[load] ref model from:", REF_DIR)
cfg = AutoConfig.from_pretrained(REF_DIR, trust_remote_code=True)  # ensures model_type=qwen3
model = AutoModelForCausalLM.from_pretrained(
    REF_DIR, config=cfg, trust_remote_code=True,
    torch_dtype=torch.float32, low_cpu_mem_usage=False, device_map={"": "cpu"}
)
sd_path = os.path.join(TMP_PT_DIR, "pytorch_model.bin")
if not os.path.exists(sd_path):
print("ERR: 未找到", sd_path, file=sys.stderr); sys.exit(9)
state = torch.load(sd_path, map_location="cpu")
# 去掉可能的 'module.' 前缀
state = { (k.split("module.",1)[-1]): v for k, v in state.items() }
# detect LoRA keys
has_lora = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in state)
print("[check] contains LoRA keys:", has_lora)
if has_lora:
    # -- export a LoRA adapter (do not merge it into the base model) --
    lora_state = {k: v for k, v in state.items()
                  if (".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k)}
    if not lora_state:
        print("ERR: LoRA detected but no LoRA weights were selected; aborting.", file=sys.stderr); sys.exit(10)
    lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight")]
    r = state[lora_A_keys[0]].shape[0] if lora_A_keys else 16
    # alpha: take the first one found; fall back to r if none is stored
    alpha = r
    for k in lora_A_keys:
        a = state.get(k.replace(".lora_A.weight", ".lora_alpha"))
        if a is not None:
            alpha = int(a.item()); break
    # collect target_modules (leaf module names)
    targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
    target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
    adapters_dir = os.path.join(OUT_DIR, "adapters")
    os.makedirs(adapters_dir, exist_ok=True)
    save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
    adapter_cfg = {
        "peft_type": "LORA",
        "base_model_name_or_path": REF_DIR,
        "r": int(r),
        "lora_alpha": int(alpha),
        "lora_dropout": 0.0,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "target_modules": target_modules
    }
    with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
        json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
    # copy tokenizer/generation/config files so the inference side can use the directory directly
    for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
        src = os.path.join(REF_DIR, f)
        if os.path.exists(src):
            dst = os.path.join(adapters_dir, f)
            if not os.path.exists(dst):
                try: shutil.copy(src, dst)
                except Exception: pass
    print("[save] exported LoRA adapter →", adapters_dir)
    print("INFO: serve with Transformers/vLLM/SGLang as 'REF_MODEL + adapters/adapter_model.safetensors'.")
else:
    # -- no LoRA: treat as dense weights and save BF16 sharded safetensors --
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"[load] missing={len(missing)} unexpected={len(unexpected)}")
    # untie lm_head from embed_tokens if they share storage (shared tensors break safetensors saving)
    try:
        emb = model.get_input_embeddings().weight if hasattr(model, "get_input_embeddings") else None
        head = model.lm_head.weight if hasattr(model, "lm_head") else None
        if emb is not None and head is not None and emb.data_ptr() == head.data_ptr():
            with torch.no_grad():
                model.lm_head.weight = torch.nn.Parameter(head.detach().clone())
            print("[fix] untied lm_head from embed_tokens")
    except Exception as e:
        print("[fix] skip untie check:", e)
    model.to(dtype=torch.bfloat16)
    os.makedirs(OUT_DIR, exist_ok=True)
    model.save_pretrained(OUT_DIR, safe_serialization=True, max_shard_size=MAX_SHARD_SIZE)
    print("[save] BF16 safetensors →", OUT_DIR)
PY
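# Example (hedged sketch, not executed): loading the exported adapter with PEFT as a smoke
# test. It assumes `peft` and `accelerate` are installed and enough memory for the 32B base;
# paths come from the variables exported above. Left commented out for reference only.
# python - <<'PEFT_DEMO'
# import os
# from transformers import AutoModelForCausalLM, AutoTokenizer
# from peft import PeftModel
# base = AutoModelForCausalLM.from_pretrained(os.environ["REF_MODEL_DIR"], torch_dtype="auto", device_map="auto")
# model = PeftModel.from_pretrained(base, os.path.join(os.environ["OUT_DIR"], "adapters"))
# tok = AutoTokenizer.from_pretrained(os.environ["REF_MODEL_DIR"])
# out = model.generate(**tok("hello", return_tensors="pt").to(base.device), max_new_tokens=8)
# print(tok.decode(out[0]))
# PEFT_DEMO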
echo "== 5.1/7 拷贝(/补齐)最终目录的 tokenizer/config 工件(如存在)=="
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
[[ -f "${REF_MODEL_DIR}/${f}" ]] && cp -n "${REF_MODEL_DIR}/${f}" "${OUT_DIR}/" || true
done
echo "== 6/7 自检(索引与 config=="
python - <<'PY'
import os, json, sys
out_dir = os.environ.get("OUT_DIR")
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    with open(idx) as f: j = json.load(f)
    print(f"OK: found safetensors index: {idx} ({len(j.get('weight_map', {}))} weight entries)")
else:
    sfts = [x for x in os.listdir(out_dir) if x.endswith(".safetensors")]
    if len(sfts) == 1:
        print(f"NOTE: single safetensors shard: {sfts[0]}")
    elif len(sfts) == 0:
        # the LoRA branch may produce no model shards (only adapters/ is exported)
        print("NOTE: no model shards found (normal if adapters/ was exported)")
    else:
        print("WARN: model.safetensors.index.json not found and shard count != 1", file=sys.stderr)
try:
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained(out_dir, trust_remote_code=True)
    print("OK: read config:", cfg.model_type, "hidden:", getattr(cfg, 'hidden_size', None), "layers:", getattr(cfg, 'num_hidden_layers', None))
except Exception as e:
    print("NOTE: no config in the final directory (normal for a pure adapters export):", e)
PY
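# Serving example (hedged sketch): recent vLLM releases can attach the exported adapter to
# the base model at launch; "jd_lora" below is just an illustrative adapter name.
# vllm serve "${REF_MODEL_DIR}" --enable-lora --lora-modules jd_lora="${OUT_DIR}/adapters"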
echo "== 7/7 清理提示 =="
echo "临时 FP32 目录:${TMP_PT_DIR}"
echo "输出目录(无 LoRA=密集权重;有 LoRA=adapters/${OUT_DIR}"
echo "完成。"