#!/usr/bin/env bash
set -euo pipefail

# ===== Tunable parameters =====
CKPT_ROOT="/home/test/checkpoints/q3-32b-lora"   # If the layout is actually .../checkpoint-62/global_step62, point CKPT_ROOT at .../checkpoint-62
TAG="global_step200"
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
AGGREGATOR_HOST="tn06"           # Host where this script runs and where shards are aggregated
EXPECTED_SHARDS_PER_HOST=4       # Shards each host is expected to write (per your parallel layout)
MAX_SHARD_SIZE="5GB"

# ★★★ Reference model directory (the Qwen3-32B, or its Instruct variant, that the LoRA was trained on) ★★★
REF_MODEL_DIR="/home/test/Qwen3-32B"

STRICT_PRECHECK=true             # true: abort if the precheck fails; false: warn only
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
# ====================

# ===== Derived parameters (usually no need to change) =====
EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} ))
STAGING_BASE="${CKPT_ROOT}/_staging"
STAGING_TAG_DIR="${STAGING_BASE}/${TAG}"
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}"    # Temporary FP32 output directory
export OUT_DIR TMP_PT_DIR MAX_SHARD_SIZE REF_MODEL_DIR
# =================================

echo "== SSH precheck =="
for h in "${HOSTS[@]}"; do
  ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! passwordless SSH to $h failed"; exit 1; }
done

echo "== 0/7 Per-node shard precheck (count *model_states.pt under ${CKPT_ROOT}/${TAG} on each host) =="
remote_total=0
agg_cnt=0
for h in "${HOSTS[@]}"; do
  c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0)
  c=$(echo "$c" | tr -d ' ')
  printf " - %-8s: %s shards\n" "$h" "$c"
  if [[ "$h" == "$AGGREGATOR_HOST" ]]; then
    agg_cnt=$c
  else
    remote_total=$(( remote_total + c ))
    if (( c < EXPECTED_SHARDS_PER_HOST )); then
      echo "!! Warning: $h has only $c shards (expected ${EXPECTED_SHARDS_PER_HOST})" >&2
    fi
  fi
done
expected_remote_total=$(( EXPECTED_TOTAL_SHARDS - EXPECTED_SHARDS_PER_HOST ))
echo " - Remote total (excluding ${AGGREGATOR_HOST}): $remote_total (expected ${expected_remote_total})"
echo " - ${AGGREGATOR_HOST} itself: $agg_cnt (expected ${EXPECTED_SHARDS_PER_HOST})"

precheck_ok=true
if (( remote_total != expected_remote_total )); then
  echo "!! Remote shard total mismatch: actual ${remote_total} / expected ${expected_remote_total}" >&2
  precheck_ok=false
fi
if (( agg_cnt < EXPECTED_SHARDS_PER_HOST )); then
  echo "!! ${AGGREGATOR_HOST} has too few local shards: actual ${agg_cnt} / expected ${EXPECTED_SHARDS_PER_HOST}" >&2
  precheck_ok=false
fi
if [[ "${STRICT_PRECHECK}" == "true" && "${precheck_ok}" == "false" ]]; then
  echo "!! STRICT_PRECHECK is on: precheck failed, aborting" >&2
  exit 2
fi
if [[ "${precheck_ok}" == "true" ]]; then
  echo "OK: precheck passed (remote=${remote_total}, local=${agg_cnt}, expected total=${EXPECTED_TOTAL_SHARDS})"
else
  echo "WARN: precheck failed (shard counts differ from expectation); lenient mode enabled, continuing..."
fi

echo "== 1/7 Prepare a clean staging directory =="
rm -rf "${STAGING_TAG_DIR}"
mkdir -p "${STAGING_TAG_DIR}"

echo "== 2/7 Collect shards into staging =="
for h in "${HOSTS[@]}"; do
  if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
    echo " - collecting ${h}:${CKPT_ROOT}/${TAG}/ -> ${STAGING_TAG_DIR}/"
    rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
      "${h}:${CKPT_ROOT}/${TAG}/" "${STAGING_TAG_DIR}/" || true
  else
    echo " - ${h} has no ${CKPT_ROOT}/${TAG}, skipping"
  fi
done

echo "== 3/7 Verify the total shard count in staging (should be ${EXPECTED_TOTAL_SHARDS}) =="
mapfile -t SHARDS < <(find "${STAGING_TAG_DIR}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
CNT=${#SHARDS[@]}
echo " - shards found in staging: ${CNT}"
if (( CNT != EXPECTED_TOTAL_SHARDS )); then
  echo "!! Shard total mismatch: staging has ${CNT} / expected ${EXPECTED_TOTAL_SHARDS}. Check for missing shards or inconsistent naming." >&2
  exit 3
fi
echo "== 4/7 Merge shards -> temporary FP32 (prefer a single pytorch_model.bin; fall back to shards if unsupported) =="
rm -rf "${TMP_PT_DIR}"
mkdir -p "${TMP_PT_DIR}"
python - <<'PY'
import os, sys, json, re, glob, shutil
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import save_file

REF_DIR        = os.environ["REF_MODEL_DIR"]
TMP_PT_DIR     = os.environ["TMP_PT_DIR"]
OUT_DIR        = os.environ["OUT_DIR"]
MAX_SHARD_SIZE = os.environ.get("MAX_SHARD_SIZE", "5GB")

def log(*args):
    print(*args, flush=True)

# ---- Consolidate the staged ZeRO shards into FP32 under TMP_PT_DIR ----
# NOTE: assumes DeepSpeed's zero_to_fp32 converter and the staging layout created by the
# shell steps above (<CKPT_ROOT>/_staging/<TAG>); adjust the call if your DeepSpeed
# version or checkpoint layout differs.
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

ckpt_root_dir = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))          # TMP_PT_DIR sits directly under CKPT_ROOT
staging_base  = os.path.join(ckpt_root_dir, "_staging")                        # matches STAGING_BASE in the shell part
tag           = os.path.basename(TMP_PT_DIR).replace("_tmp-fp32-pt-", "", 1)   # mirrors the shell-side TMP_PT_DIR naming
try:
    # Newer DeepSpeed: second argument is an output directory; writes pytorch_model.bin,
    # or pytorch_model-*.bin plus pytorch_model.bin.index.json when sharding is needed.
    convert_zero_checkpoint_to_fp32_state_dict(staging_base, TMP_PT_DIR, tag=tag, max_shard_size=MAX_SHARD_SIZE)
except TypeError:
    # Older DeepSpeed: second argument is a single output file.
    convert_zero_checkpoint_to_fp32_state_dict(staging_base, os.path.join(TMP_PT_DIR, "pytorch_model.bin"), tag=tag)

sd_single = os.path.join(TMP_PT_DIR, "pytorch_model.bin")
idx_json  = os.path.join(TMP_PT_DIR, "pytorch_model.bin.index.json")

# ---- Streaming helpers (minimal implementations; assume HF-style shard files under TMP_PT_DIR) ----
def parse_index(path):
    """Return (weight_map, {shard_file: set(param_keys)}) from pytorch_model.bin.index.json."""
    with open(path, "r", encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]
    shard_to_keys = {}
    for k, shard in weight_map.items():
        shard_to_keys.setdefault(shard, set()).add(k)
    return weight_map, shard_to_keys

def detect_lora_from_index(weight_map) -> bool:
    return any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in weight_map)

def stream_collect_lora_from_shards(shard_to_keys):
    """Load one shard at a time and keep only the LoRA tensors (bounded peak memory)."""
    lora_state, lora_keys = {}, []
    for shard, keys in shard_to_keys.items():
        wanted = [k for k in keys if (".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k)]
        if not wanted:
            continue
        part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
        for k in wanted:
            lora_state[k] = part[k]
            lora_keys.append(k)
        del part
    return lora_state, lora_keys

def stream_load_full_into_model(model, shard_to_keys):
    """Load the dense FP32 weights into the model one shard at a time (bounded peak memory)."""
    for shard in sorted(shard_to_keys):
        part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
        part = {(k.split("module.", 1)[-1]): v for k, v in part.items()}
        model.load_state_dict(part, strict=False)
        log(f"[load] applied shard: {shard}")
        del part

def is_lora_A(k: str) -> bool:
    return k.endswith(".lora_A.weight") or k.endswith(".lora_A.default.weight")

def alpha_key_for(kA: str) -> str:
    if kA.endswith(".lora_A.weight"):
        return kA.replace(".lora_A.weight", ".lora_alpha")
    if kA.endswith(".lora_A.default.weight"):
        return kA.replace(".lora_A.default.weight", ".lora_alpha.default")
    return ""

log("[load] ref model from:", REF_DIR)
cfg = AutoConfig.from_pretrained(REF_DIR, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    REF_DIR, config=cfg, trust_remote_code=True,
    torch_dtype=torch.float32, low_cpu_mem_usage=False, device_map={"": "cpu"}
)

state = None
has_lora = False
if os.path.exists(sd_single):
    log("[fp32] detected single file:", sd_single)
    state = torch.load(sd_single, map_location="cpu")
    # strip the 'module.' prefix
    state = {(k.split("module.", 1)[-1]): v for k, v in state.items()}
    has_lora = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in state)
elif os.path.exists(idx_json):
    log("[fp32] detected sharded weights with index:", idx_json)
    weight_map, shard_to_keys = parse_index(idx_json)
    has_lora = detect_lora_from_index(weight_map)
    if has_lora:
        log("[check] contains LoRA keys: True (stream collecting)")
        lora_state, lora_keys = stream_collect_lora_from_shards(shard_to_keys)
        if not lora_state:
            print("ERR: LoRA keys detected but no weights were collected; aborting.", file=sys.stderr)
            sys.exit(10)
        # estimate r / alpha (also handles the '.default' adapter name)
        lora_A_keys = [k for k in lora_keys if is_lora_A(k)]
        if lora_A_keys:
            k0 = lora_A_keys[0]
            shard = next(s for s, ks in shard_to_keys.items() if k0 in ks)
            part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
            r = part[k0].shape[0]
            ak = alpha_key_for(k0)
            a = part.get(ak)
            alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
        else:
            r, alpha = 16, 16
        adapters_dir = os.path.join(OUT_DIR, "adapters")
        os.makedirs(adapters_dir, exist_ok=True)
        save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
        targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
        target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
        # ☆ Prefer reusing the training-time adapter_config.json (if present), overriding only base_model_name_or_path
        copied_cfg = False
        try:
            CKPT_ROOT = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))
            src_cfg = os.path.join(CKPT_ROOT, "adapter_config.json")
            if os.path.exists(src_cfg):
                with open(src_cfg, "r", encoding="utf-8") as f:
                    adapter_cfg = json.load(f)
                adapter_cfg["base_model_name_or_path"] = REF_DIR
                with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
                    json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
                copied_cfg = True
        except Exception:
            copied_cfg = False
        if not copied_cfg:
            adapter_cfg = {
                "peft_type": "LORA",
                "base_model_name_or_path": REF_DIR,
                "r": int(r),
                "lora_alpha": int(alpha),
                "lora_dropout": 0.05,
                "bias": "none",
                "task_type": "CAUSAL_LM",
                "target_modules": target_modules
            }
            with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
                json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
        # copy tokenizer/generation/config files next to the adapter
        for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
            src = os.path.join(REF_DIR, f)
            if os.path.exists(src):
                dst = os.path.join(adapters_dir, f)
                if not os.path.exists(dst):
                    try:
                        shutil.copy(src, dst)
                    except Exception:
                        pass
        log("[save] exported LoRA adapter →", adapters_dir)
        log("INFO: serve with Transformers/vLLM/SGLang as 'REF model + adapters/adapter_model.safetensors'.")
        sys.exit(0)
    else:
        log("[check] contains LoRA keys: False (stream loading into model)")
        stream_load_full_into_model(model, shard_to_keys)
else:
    # neither a single file nor an index.json: fall back to loading shards by filename
    shard_glob = sorted(glob.glob(os.path.join(TMP_PT_DIR, "pytorch_model-*.bin")))
    if not shard_glob:
        print("ERR: no single-file or sharded FP32 found (pytorch_model.bin / .index.json / shards)", file=sys.stderr)
        sys.exit(9)
    log(f"[fp32] detected {len(shard_glob)} shards (no index.json), brute-load")
    # brute-force merge (may use a lot of memory; last-resort fallback)
    state = {}
    for sf in shard_glob:
        part = torch.load(sf, map_location="cpu")
        state.update(part)
    state = {(k.split("module.", 1)[-1]): v for k, v in state.items()}
    has_lora = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in state)
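# Illustration of the PEFT-style keys the LoRA branches above and below expect
# (the module path is hypothetical and shown only for orientation; the ".default"
# infix appears when the adapter was saved from named_parameters() of the wrapped model):
#   ...layers.0.self_attn.q_proj.lora_A[.default].weight  -> shape [r, in_features]
#   ...layers.0.self_attn.q_proj.lora_B[.default].weight  -> shape [out_features, r]
#   ...layers.0.self_attn.q_proj.lora_alpha[.default]     -> scalar; effective scaling = lora_alpha / r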
if state is not None and has_lora:
    log("[check] contains LoRA keys: True")
    lora_state = {k: v for k, v in state.items() if (".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k)}
    if not lora_state:
        print("ERR: LoRA keys detected but no weights were collected; aborting.", file=sys.stderr)
        sys.exit(10)
    lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight") or k.endswith(".lora_A.default.weight")]
    if lora_A_keys:
        k0 = lora_A_keys[0]
        r = state[k0].shape[0]
        ak = k0.replace(".lora_A.weight", ".lora_alpha").replace(".lora_A.default.weight", ".lora_alpha.default")
        a = state.get(ak)
        alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
    else:
        r, alpha = 16, 16
    targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
    target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
    adapters_dir = os.path.join(OUT_DIR, "adapters")
    os.makedirs(adapters_dir, exist_ok=True)
    save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
    # ☆ Prefer reusing the training-time adapter_config.json (if present), overriding only base_model_name_or_path
    copied_cfg = False
    try:
        CKPT_ROOT = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))
        src_cfg = os.path.join(CKPT_ROOT, "adapter_config.json")
        if os.path.exists(src_cfg):
            with open(src_cfg, "r", encoding="utf-8") as f:
                adapter_cfg = json.load(f)
            adapter_cfg["base_model_name_or_path"] = REF_DIR
            with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
                json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
            copied_cfg = True
    except Exception:
        copied_cfg = False
    if not copied_cfg:
        adapter_cfg = {
            "peft_type": "LORA",
            "base_model_name_or_path": REF_DIR,
            "r": int(r),
            "lora_alpha": int(alpha),
            "lora_dropout": 0.05,
            "bias": "none",
            "task_type": "CAUSAL_LM",
            "target_modules": target_modules
        }
        with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
            json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
    for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
        src = os.path.join(REF_DIR, f)
        if os.path.exists(src):
            dst = os.path.join(adapters_dir, f)
            if not os.path.exists(dst):
                try:
                    shutil.copy(src, dst)
                except Exception:
                    pass
    log("[save] exported LoRA adapter →", adapters_dir)
    log("INFO: serve with Transformers/vLLM/SGLang as 'REF model + adapters/adapter_model.safetensors'.")
    sys.exit(0)

# No-LoRA path from here on: inject the FP32 weights into the model and save BF16 safetensors
if state is not None:
    missing, unexpected = model.load_state_dict(state, strict=False)
    log(f"[load] missing={len(missing)} unexpected={len(unexpected)}")

# untie lm_head from the input embeddings if they share storage
try:
    emb = model.get_input_embeddings().weight if hasattr(model, "get_input_embeddings") else None
    head = model.lm_head.weight if hasattr(model, "lm_head") else None
    if emb is not None and head is not None and emb.data_ptr() == head.data_ptr():
        with torch.no_grad():
            model.lm_head.weight = torch.nn.Parameter(head.detach().clone())
        log("[fix] untied lm_head from embed_tokens")
except Exception as e:
    log("[fix] skip untie check:", e)

model.to(dtype=torch.bfloat16)
os.makedirs(OUT_DIR, exist_ok=True)
model.save_pretrained(OUT_DIR, safe_serialization=True, max_shard_size=MAX_SHARD_SIZE)
log("[save] BF16 safetensors →", OUT_DIR)
PY

echo "== 5.1/7 Copy (or backfill) tokenizer/config artifacts into the final directory (if present) =="
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
  [[ -f "${REF_MODEL_DIR}/${f}" ]] && cp -n "${REF_MODEL_DIR}/${f}" "${OUT_DIR}/" || true
done

echo "== 6/7 Self-check (index and config) =="
python - <<'PY'
import os, json
out_dir = os.environ.get("OUT_DIR")
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    with open(idx) as f:
        j = json.load(f)
    print(f"OK: found safetensors index: {idx} ({len(j.get('weight_map', {}))} parameter entries)")
else:
    sfts = [x for x in os.listdir(out_dir) if x.endswith(".safetensors")]
    if len(sfts) == 1:
        print(f"NOTE: single safetensors shard: {sfts[0]}")
    elif len(sfts) == 0:
        # the LoRA branch may leave no dense shards here (only adapters/ is exported)
        print("NOTE: no model shards found (normal if adapters/ was exported)")
    else:
        print("WARN: model.safetensors.index.json not found and shard count != 1")
try:
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained(out_dir, trust_remote_code=True)
    print("OK: loaded config:", cfg.model_type, "hidden:", getattr(cfg, 'hidden_size', None), "layers:", getattr(cfg, 'num_hidden_layers', None))
except Exception as e:
    print("NOTE: final directory has no config (normal for a pure adapters export):", e)
PY

echo "== 7/7 Cleanup hints =="
echo "Temporary FP32 directory: ${TMP_PT_DIR}"
echo "Output directory (no LoRA = dense weights; with LoRA = adapters/): ${OUT_DIR}"
echo "Done."
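
# ---------------------------------------------------------------------------
# Usage sketch (not executed): how the exported adapters/ can be consumed.
# Assumptions: a vLLM build with LoRA support and an installed peft package;
# exact flags and APIs may differ across versions.
#
# 1) Serve the base model plus the adapter directly, e.g. with vLLM:
#      vllm serve "${REF_MODEL_DIR}" --enable-lora \
#        --lora-modules q3_lora="${OUT_DIR}/adapters"
#    (add --max-lora-rank if the adapter rank exceeds vLLM's default limit)
#
# 2) Or bake the adapter into dense weights with PEFT:
#      python - <<'EOF'
#      from transformers import AutoModelForCausalLM
#      from peft import PeftModel
#      base = AutoModelForCausalLM.from_pretrained("/home/test/Qwen3-32B", torch_dtype="bfloat16")
#      merged = PeftModel.from_pretrained(base, "<OUT_DIR>/adapters").merge_and_unload()
#      merged.save_pretrained("<OUT_DIR>-dense", safe_serialization=True)
#      EOF
# ---------------------------------------------------------------------------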