diff --git a/lora_merge_zero3_safetensors.sh b/lora_merge_zero3_safetensors.sh index f57a68f..3c0bd4d 100644 --- a/lora_merge_zero3_safetensors.sh +++ b/lora_merge_zero3_safetensors.sh @@ -9,7 +9,7 @@ AGGREGATOR_HOST="tn06" # 本脚本运行/汇总所在机器 EXPECTED_SHARDS_PER_HOST=4 # 每机应写出分片数(按你的并行布局) MAX_SHARD_SIZE="5GB" -# ★★★ 新增:参考模型目录(你用来做 LoRA 的 Qwen3-32B 或其 Instruct 变体) ★★★ +# ★★★ 参考模型目录(用来做 LoRA 的 Qwen3-32B 或其 Instruct 变体) ★★★ REF_MODEL_DIR="/home/test/Qwen3-32B" STRICT_PRECHECK=true # true: 预检不通过就退出;false: 仅告警 @@ -22,7 +22,7 @@ EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} )) STAGING_BASE="${CKPT_ROOT}/_staging" STAGING_TAG_DIR="${STAGING_BASE}/${TAG}" OUT_DIR="${CKPT_ROOT}/merged-${TAG}" -TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}" # 临时 FP32(pytorch_model.bin)目录 +TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}" # 临时 FP32 输出目录 export OUT_DIR TMP_PT_DIR MAX_SHARD_SIZE REF_MODEL_DIR # ================================= @@ -90,30 +90,62 @@ if (( CNT != EXPECTED_TOTAL_SHARDS )); then exit 3 fi -echo "== 4/7 合并分片 -> 临时 FP32(PyTorch .bin),避免共享权重导致 safetensors 报错 ==" +echo "== 4/7 合并分片 -> 临时 FP32(优先单文件 pytorch_model.bin;不支持则分片)==" rm -rf "${TMP_PT_DIR}" mkdir -p "${TMP_PT_DIR}" python - <