jd_train/lora_merge_zero3_safetensor...

#!/usr/bin/env bash
set -euo pipefail
# ===== Tunable parameters =====
CKPT_ROOT="/home/test/checkpoints/q3-32b-lora" # if the actual layout is .../checkpoint-62/global_step62, change CKPT_ROOT to .../checkpoint-62
TAG="global_step30"
HOSTS=(tn01 tn02 tn03 tn04 tn05 tn06)
AGGREGATOR_HOST="tn06" # host this script runs on / where shards are aggregated
EXPECTED_SHARDS_PER_HOST=4 # shards each host should have written (per your parallel layout)
MAX_SHARD_SIZE="5GB"
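# Note (assumption about your layout): under DeepSpeed ZeRO-3 the per-rank model shards are
# typically named like zero_pp_rank_<R>_mp_rank_00_model_states.pt, which is what the
# '*model_states.pt' precheck below counts; adjust the pattern if your checkpoints differ.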
# ★★★ Reference model directory (the Qwen3-32B, or its Instruct variant, that LoRA was trained on) ★★★
REF_MODEL_DIR="/home/test/Qwen3-32B"
STRICT_PRECHECK=true # true: abort if the precheck fails; false: warn only
SSH_OPTS="-o BatchMode=yes -o StrictHostKeyChecking=accept-new -o ConnectTimeout=8"
RSYNC_OPTS="-a --info=progress2 --human-readable --partial --inplace"
# ====================
# ===== Derived parameters (usually no need to change) =====
EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} ))
STAGING_BASE="${CKPT_ROOT}/_staging"
STAGING_TAG_DIR="${STAGING_BASE}/${TAG}"
OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}" # temporary FP32 output directory
export OUT_DIR TMP_PT_DIR MAX_SHARD_SIZE REF_MODEL_DIR
# =================================
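# With the defaults above these resolve to, for example:
#   STAGING_TAG_DIR = /home/test/checkpoints/q3-32b-lora/_staging/global_step30
#   OUT_DIR         = /home/test/checkpoints/q3-32b-lora/merged-global_step30
#   EXPECTED_TOTAL_SHARDS = 4 shards/host x 6 hosts = 24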
echo "== 预检查 SSH =="
for h in "${HOSTS[@]}"; do
ssh ${SSH_OPTS} "$h" "true" >/dev/null || { echo "!! 无法免密 SSH 到 $h"; exit 1; }
done
echo "== 0/7 逐节点分片预检(统计各机 ${CKPT_ROOT}/${TAG} 下的 *model_states.pt=="
remote_total=0
agg_cnt=0
for h in "${HOSTS[@]}"; do
  c=$(ssh ${SSH_OPTS} "$h" "find '${CKPT_ROOT}/${TAG}' -maxdepth 1 -type f -name '*model_states.pt' 2>/dev/null | wc -l" || echo 0)
  c=$(echo "$c" | tr -d ' ')
  printf " - %-8s: %s shard(s)\n" "$h" "$c"
  if [[ "$h" == "$AGGREGATOR_HOST" ]]; then
    agg_cnt=$c
  else
    remote_total=$(( remote_total + c ))
    if (( c < EXPECTED_SHARDS_PER_HOST )); then
      echo "!! Warning: $h has only $c shard(s) (expected ${EXPECTED_SHARDS_PER_HOST})" >&2
    fi
  fi
done
expected_remote_total=$(( EXPECTED_TOTAL_SHARDS - EXPECTED_SHARDS_PER_HOST ))
echo " - 远端合计(不含 ${AGGREGATOR_HOST}$remote_total(期望 ${expected_remote_total}"
echo " - ${AGGREGATOR_HOST} 自身:$agg_cnt(期望 ${EXPECTED_SHARDS_PER_HOST}"
precheck_ok=true
if (( remote_total != expected_remote_total )); then
echo "!! 远端总分片不等:实际 ${remote_total} / 期望 ${expected_remote_total}" >&2
precheck_ok=false
fi
if (( agg_cnt < EXPECTED_SHARDS_PER_HOST )); then
echo "!! ${AGGREGATOR_HOST} 本机分片不足:实际 ${agg_cnt} / 期望 ${EXPECTED_SHARDS_PER_HOST}" >&2
precheck_ok=false
fi
if [[ "${STRICT_PRECHECK}" == "true" && "${precheck_ok}" == "false" ]]; then
echo "!! STRICT_PRECHECK 开启:预检不通过,停止执行" >&2
exit 2
fi
[[ "${precheck_ok}" == "true" ]] && echo "OK: 预检通过(远端=${remote_total}、本机=${agg_cnt},总计期望=${EXPECTED_TOTAL_SHARDS}" || echo "WARN: 预检未通过(分片数量与期望不符),已启用宽松模式,继续执行..."
echo "== 1/7 准备 staging 目录(干净环境)=="
rm -rf "${STAGING_TAG_DIR}"
mkdir -p "${STAGING_TAG_DIR}"
echo "== 2/7 收集分片到 staging =="
for h in "${HOSTS[@]}"; do
  if ssh ${SSH_OPTS} "$h" "test -d '${CKPT_ROOT}/${TAG}'"; then
    echo " - collecting ${h}:${CKPT_ROOT}/${TAG}/ -> ${STAGING_TAG_DIR}/"
    rsync ${RSYNC_OPTS} -e "ssh ${SSH_OPTS}" \
      "${h}:${CKPT_ROOT}/${TAG}/" "${STAGING_TAG_DIR}/" || true
  else
    echo " - ${h} has no ${CKPT_ROOT}/${TAG}, skipping"
  fi
done
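# Note: rsync stages the entire ${TAG}/ directory, so optimizer-state shards and other checkpoint
# files are copied as well (the zero_to_fp32 conversion in step 4 typically needs more than just
# *model_states.pt); only the *model_states.pt files are counted in the check below.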
echo "== 3/7 在 staging 校验总分片数(应为 ${EXPECTED_TOTAL_SHARDS}=="
mapfile -t SHARDS < <(find "${STAGING_TAG_DIR}" -maxdepth 1 -type f -name "*model_states.pt" | sort -u)
CNT=${#SHARDS[@]}
echo " - staging 中发现分片数:${CNT}"
if (( CNT != EXPECTED_TOTAL_SHARDS )); then
echo "!! 分片总数不等staging 实际 ${CNT} / 期望 ${EXPECTED_TOTAL_SHARDS}。请检查是否缺片或命名不一致。" >&2
exit 3
fi
echo "== 4/7 合并分片 -> 临时 FP32优先单文件 pytorch_model.bin不支持则分片=="
rm -rf "${TMP_PT_DIR}"
mkdir -p "${TMP_PT_DIR}"
python - <<PY
import os, json, glob, torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
TMP_PT_DIR = r"${TMP_PT_DIR}"
STAGING_BASE = r"${STAGING_BASE}"
TAG = r"${TAG}"
sd_single = os.path.join(TMP_PT_DIR, "pytorch_model.bin")
idx_json = os.path.join(TMP_PT_DIR, "pytorch_model.bin.index.json")
# Preferred: newer API that writes a single file directly
ok = False
try:
    convert_zero_checkpoint_to_fp32_state_dict(
        checkpoint_dir=STAGING_BASE,
        tag=TAG,
        output_file=sd_single,
        safe_serialization=False,
    )
    ok = True
except TypeError:
    # Fallback: write to a directory (multiple shards)
    convert_zero_checkpoint_to_fp32_state_dict(
        checkpoint_dir=STAGING_BASE,
        tag=TAG,
        output_dir=TMP_PT_DIR,
        safe_serialization=False,
    )
# If shards were written and an index exists, report it
if os.path.exists(sd_single):
    print("Merge complete (single FP32 file):", sd_single)
else:
    shards = sorted(glob.glob(os.path.join(TMP_PT_DIR, "pytorch_model-*.bin")))
    if os.path.exists(idx_json):
        with open(idx_json) as f: j = json.load(f)
        n = len(set(j.get("weight_map", {}).values()))
        print(f"Merge complete (sharded FP32): {n} shards, index {idx_json}")
    else:
        print(f"Merge complete (sharded FP32): {len(shards)} shards (no index.json)")
PY
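# After this step TMP_PT_DIR should contain either a single pytorch_model.bin or
# pytorch_model-*.bin shards (plus pytorch_model.bin.index.json, depending on the
# DeepSpeed version); step 5 handles both layouts.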
# ★★★ Copy config/tokenizer from the reference model into the temporary FP32 dir (needed for loading) ★★★
echo "== 4.1/7 Copy config/tokenizer from the reference model into the temporary FP32 dir (needed for loading) =="
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
[[ -f "${REF_MODEL_DIR}/${f}" ]] && cp -f "${REF_MODEL_DIR}/${f}" "${TMP_PT_DIR}/" || true
done
echo "== 5/7 装载 REF 结构 + 灌入 FP32若检测到 LoRA → 导出 adapter否则保存 BF16 分片 safetensors =="
python - <<'PY'
import os, re, json, sys, torch, shutil, glob
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import save_file
TMP_PT_DIR = os.environ["TMP_PT_DIR"]
REF_DIR = os.environ["REF_MODEL_DIR"]
OUT_DIR = os.environ["OUT_DIR"]
MAX_SHARD_SIZE = os.environ.get("MAX_SHARD_SIZE","5GB")
sd_single = os.path.join(TMP_PT_DIR, "pytorch_model.bin")
idx_json = os.path.join(TMP_PT_DIR, "pytorch_model.bin.index.json")
def log(*a, **k): print(*a, **k, flush=True)
def parse_index(idx_path):
    with open(idx_path) as f:
        j = json.load(f)
    weight_map = j.get("weight_map", {})
    shard_to_keys = {}
    for k, shard in weight_map.items():
        shard_to_keys.setdefault(shard, []).append(k)
    return weight_map, shard_to_keys
def detect_lora_from_index(weight_map):
    # Decide from the weight names alone, without loading the large tensors first
    has = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in weight_map)
    return has
def stream_collect_lora_from_shards(shard_to_keys):
    lora_state = {}
    lora_keys = []
    for shard, keys in sorted(shard_to_keys.items()):
        pick = [k for k in keys if (".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k)]
        if not pick: continue
        part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
        for k in pick:
            if k in part: lora_state[k] = part[k]
            else:
                # A few DeepSpeed versions name weights slightly differently; tolerate and skip
                pass
        lora_keys.extend(pick)
    return lora_state, lora_keys
def stream_load_full_into_model(model, shard_to_keys):
    missing_total = 0
    unexpected_total = 0
    for shard, keys in sorted(shard_to_keys.items()):
        part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
        m, u = model.load_state_dict(part, strict=False)
        missing_total += len(m)
        unexpected_total += len(u)
    log(f"[load] missing_total={missing_total} unexpected_total={unexpected_total}")
log("[load] ref model from:", REF_DIR)
cfg = AutoConfig.from_pretrained(REF_DIR, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    REF_DIR, config=cfg, trust_remote_code=True,
    torch_dtype=torch.float32, low_cpu_mem_usage=False, device_map={"": "cpu"}
)
state = None
has_lora = False
if os.path.exists(sd_single):
log("[fp32] detected single file:", sd_single)
state = torch.load(sd_single, map_location="cpu")
# 去 'module.' 前缀
state = { (k.split("module.",1)[-1]): v for k, v in state.items() }
has_lora = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in state)
elif os.path.exists(idx_json):
log("[fp32] detected sharded weights with index:", idx_json)
weight_map, shard_to_keys = parse_index(idx_json)
has_lora = detect_lora_from_index(weight_map)
if has_lora:
log("[check] contains LoRA keys: True (stream collecting)")
lora_state, lora_keys = stream_collect_lora_from_shards(shard_to_keys)
if not lora_state:
print("ERR: 识别到 LoRA 但未筛出权重;中止。", file=sys.stderr); sys.exit(10)
# 估 r / alpha
lora_A_keys = [k for k in lora_keys if k.endswith(".lora_A.weight")]
if lora_A_keys:
# 找到第一个 lora_A 所在分片并读取 shape
k0 = lora_A_keys[0]
shard = next(s for s, ks in shard_to_keys.items() if k0 in ks)
part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
r = part[k0].shape[0]
a = part.get(k0.replace(".lora_A.weight", ".lora_alpha"))
alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
else:
r, alpha = 16, 16
adapters_dir = os.path.join(OUT_DIR, "adapters")
os.makedirs(adapters_dir, exist_ok=True)
save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
adapter_cfg = {
"peft_type": "LORA",
"base_model_name_or_path": REF_DIR,
"r": int(r),
"lora_alpha": int(alpha),
"lora_dropout": 0.0,
"bias": "none",
"task_type": "CAUSAL_LM",
"target_modules": target_modules
}
with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
# 复制 tokenizer/generation/config
for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
src = os.path.join(REF_DIR, f)
if os.path.exists(src):
dst = os.path.join(adapters_dir, f)
if not os.path.exists(dst):
try: shutil.copy(src, dst)
except Exception: pass
log("[save] 导出了 LoRA 适配器 →", adapters_dir)
log("INFO: 可用 Transformers/vLLM/SGLang 以『REF_MODEL + adapters/adapter_model.safetensors』方式推理。")
sys.exit(0)
else:
log("[check] contains LoRA keys: False (stream loading into model)")
stream_load_full_into_model(model, shard_to_keys)
else:
    # Neither a single file nor an index.json; fall back to loading shards by filename
    shard_glob = sorted(glob.glob(os.path.join(TMP_PT_DIR, "pytorch_model-*.bin")))
    if not shard_glob:
        print("ERR: no FP32 output found (pytorch_model.bin / .index.json / shards)", file=sys.stderr); sys.exit(9)
    log(f"[fp32] detected {len(shard_glob)} shards (no index.json), brute-load")
    # Brute-force merge (memory hungry, but acceptable as a last resort)
    state = {}
    for sf in shard_glob:
        part = torch.load(sf, map_location="cpu")
        state.update(part)
    state = { (k.split("module.",1)[-1]): v for k, v in state.items() }
    has_lora = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in state)
if state is not None and has_lora:
log("[check] contains LoRA keys: True")
lora_state = {k: v for k, v in state.items()
if (".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k)}
if not lora_state:
print("ERR: 识别到 LoRA 但未筛出权重;中止。", file=sys.stderr); sys.exit(10)
lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight")]
if lora_A_keys:
r = state[lora_A_keys[0]].shape[0]
a = state.get(lora_A_keys[0].replace(".lora_A.weight", ".lora_alpha"))
alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
else:
r, alpha = 16, 16
targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
adapters_dir = os.path.join(OUT_DIR, "adapters")
os.makedirs(adapters_dir, exist_ok=True)
save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
adapter_cfg = {
"peft_type": "LORA",
"base_model_name_or_path": REF_DIR,
"r": int(r),
"lora_alpha": int(alpha),
"lora_dropout": 0.0,
"bias": "none",
"task_type": "CAUSAL_LM",
"target_modules": target_modules
}
with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
src = os.path.join(REF_DIR, f)
if os.path.exists(src):
dst = os.path.join(adapters_dir, f)
if not os.path.exists(dst):
try: shutil.copy(src, dst)
except Exception: pass
log("[save] 导出了 LoRA 适配器 →", adapters_dir)
log("INFO: 可用 Transformers/vLLM/SGLang 以『REF_MODEL + adapters/adapter_model.safetensors』方式推理。")
sys.exit(0)
# Reaching here means no LoRA: inject the FP32 weights into the model and save BF16 safetensors
if state is not None:
    missing, unexpected = model.load_state_dict(state, strict=False)
    log(f"[load] missing={len(missing)} unexpected={len(unexpected)}")
# Untie lm_head from the embeddings if they share storage
try:
    emb = model.get_input_embeddings().weight if hasattr(model, "get_input_embeddings") else None
    head = model.lm_head.weight if hasattr(model, "lm_head") else None
    if emb is not None and head is not None and emb.data_ptr() == head.data_ptr():
        with torch.no_grad():
            model.lm_head.weight = torch.nn.Parameter(head.detach().clone())
        log("[fix] untied lm_head from embed_tokens")
except Exception as e:
    log("[fix] skip untie check:", e)
model.to(dtype=torch.bfloat16)
os.makedirs(OUT_DIR, exist_ok=True)
model.save_pretrained(OUT_DIR, safe_serialization=True, max_shard_size=MAX_SHARD_SIZE)
log("[save] BF16 safetensors →", OUT_DIR)
PY
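# Optional sanity check (illustrative only, not executed here; assumes `peft` is installed):
# attach the exported adapter to the reference model and confirm it loads.
#   python - <<'EOF'
#   from transformers import AutoModelForCausalLM
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained("/home/test/Qwen3-32B", torch_dtype="bfloat16", device_map="cpu")
#   model = PeftModel.from_pretrained(base, "/home/test/checkpoints/q3-32b-lora/merged-global_step30/adapters")
#   print("adapter loaded:", type(model).__name__)
#   EOF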
echo "== 5.1/7 拷贝(/补齐)最终目录的 tokenizer/config 工件(如存在)=="
for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
[[ -f "${REF_MODEL_DIR}/${f}" ]] && cp -n "${REF_MODEL_DIR}/${f}" "${OUT_DIR}/" || true
done
echo "== 6/7 自检(索引与 config=="
python - <<'PY'
import os, json
out_dir = os.environ.get("OUT_DIR")
idx = os.path.join(out_dir, "model.safetensors.index.json")
if os.path.exists(idx):
    with open(idx) as f: j = json.load(f)
    print(f"OK: found safetensors index: {idx} ({len(j.get('weight_map', {}))} weight entries)")
else:
    sfts = [x for x in os.listdir(out_dir) if x.endswith(".safetensors")]
    if len(sfts) == 1:
        print(f"NOTE: single safetensors shard: {sfts[0]}")
    elif len(sfts) == 0:
        # The LoRA branch may leave no model shards here (only adapters/ is exported)
        print("NOTE: no model shards found (normal if adapters/ was exported)")
    else:
        print("WARN: model.safetensors.index.json not found and shard count != 1")
try:
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained(out_dir, trust_remote_code=True)
    print("OK: config loaded:", cfg.model_type, "hidden:", getattr(cfg,'hidden_size',None), "layers:", getattr(cfg,'num_hidden_layers',None))
except Exception as e:
print("NOTE: 最终目录无 config若为纯 adapters 导出则正常):", e)
PY
echo "== 7/7 清理提示 =="
echo "临时 FP32 目录:${TMP_PT_DIR}"
echo "输出目录(无 LoRA=密集权重;有 LoRA=adapters/${OUT_DIR}"
echo "完成。"