hailin 2025-09-14 17:01:53 +08:00
parent 815827e031
commit c51a3bbedb
1 changed file with 201 additions and 58 deletions


@@ -9,7 +9,7 @@ AGGREGATOR_HOST="tn06"   # machine this script runs on / aggregation host
 EXPECTED_SHARDS_PER_HOST=4   # number of shards each host should write (per your parallel layout)
 MAX_SHARD_SIZE="5GB"
-# ★★★ New: reference model directory (the Qwen3-32B, or its Instruct variant, used as the LoRA base) ★★★
+# ★★★ Reference model directory (the Qwen3-32B, or its Instruct variant, used as the LoRA base) ★★★
 REF_MODEL_DIR="/home/test/Qwen3-32B"
 STRICT_PRECHECK=true         # true: exit if the precheck fails; false: warn only
@@ -22,7 +22,7 @@ EXPECTED_TOTAL_SHARDS=$(( EXPECTED_SHARDS_PER_HOST * ${#HOSTS[@]} ))
 STAGING_BASE="${CKPT_ROOT}/_staging"
 STAGING_TAG_DIR="${STAGING_BASE}/${TAG}"
 OUT_DIR="${CKPT_ROOT}/merged-${TAG}"
-TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}"   # temporary FP32 (pytorch_model.bin) directory
+TMP_PT_DIR="${CKPT_ROOT}/_tmp-fp32-pt-${TAG}"   # temporary FP32 output directory
 export OUT_DIR TMP_PT_DIR MAX_SHARD_SIZE REF_MODEL_DIR
 # =================================
@@ -90,30 +90,62 @@ if (( CNT != EXPECTED_TOTAL_SHARDS )); then
   exit 3
 fi
-echo "== 4/7 Merge shards -> temporary FP32 (PyTorch .bin, avoids safetensors errors from shared weights) =="
+echo "== 4/7 Merge shards -> temporary FP32 (prefer a single pytorch_model.bin, fall back to shards if unsupported) =="
 rm -rf "${TMP_PT_DIR}"
 mkdir -p "${TMP_PT_DIR}"
 python - <<PY
+import os, json, glob, torch
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+TMP_PT_DIR = r"${TMP_PT_DIR}"
+STAGING_BASE = r"${STAGING_BASE}"
+TAG = r"${TAG}"
+sd_single = os.path.join(TMP_PT_DIR, "pytorch_model.bin")
+idx_json = os.path.join(TMP_PT_DIR, "pytorch_model.bin.index.json")
+# Preferred: the interface that writes a single file directly
+ok = False
+try:
-convert_zero_checkpoint_to_fp32_state_dict(
-    checkpoint_dir=r"${STAGING_BASE}",
-    output_dir=r"${TMP_PT_DIR}",
-    tag=r"${TAG}",
-    safe_serialization=False,  # land as a .bin (FP32) first
-)
-print("Merge complete (FP32 .bin):", r"${TMP_PT_DIR}")
+    convert_zero_checkpoint_to_fp32_state_dict(
+        checkpoint_dir=STAGING_BASE,
+        tag=TAG,
+        output_file=sd_single,
+        safe_serialization=False,
+    )
+    ok = True
+except TypeError:
+    # Fallback: write a directory (multiple shards)
+    convert_zero_checkpoint_to_fp32_state_dict(
+        checkpoint_dir=STAGING_BASE,
+        tag=TAG,
+        output_dir=TMP_PT_DIR,
+        safe_serialization=False,
+    )
+# Report whether a single file or sharded output (with index) was written
+if os.path.exists(sd_single):
+    print("Merge complete (FP32, single file):", sd_single)
+else:
+    shards = sorted(glob.glob(os.path.join(TMP_PT_DIR, "pytorch_model-*.bin")))
+    if os.path.exists(idx_json):
+        with open(idx_json) as f: j = json.load(f)
+        n = len(set(j.get("weight_map", {}).values()))
+        print(f"Merge complete (FP32, sharded): {n} shards, index at {idx_json}")
+    else:
+        print(f"Merge complete (FP32, sharded): {len(shards)} shards (no index.json)")
 PY
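A quick optional sanity check after step 4 (not part of the commit): total up the FP32 tensors that were just written, whichever layout was produced, and compare the parameter count against the reference model. The directory below is an assumption standing in for ${TMP_PT_DIR}.

    import glob, os, torch

    tmp_dir = "/path/to/_tmp-fp32-pt-TAG"   # assumed stand-in for ${TMP_PT_DIR}
    single = os.path.join(tmp_dir, "pytorch_model.bin")
    files = [single] if os.path.exists(single) else sorted(
        glob.glob(os.path.join(tmp_dir, "pytorch_model-*.bin")))

    total = 0
    for f in files:
        part = torch.load(f, map_location="cpu")   # one FP32 state dict (or shard)
        total += sum(t.numel() for t in part.values() if torch.is_tensor(t))
    print(f"{len(files)} file(s), {total/1e9:.2f}B parameters in FP32")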
-# ★★★ Changed: copy config/tokenizer from the reference model, force-overwriting; do NOT copy from CKPT_ROOT (a LoRA directory usually has no config.json) ★★★
+# ★★★ Copy config/tokenizer from the reference model into the temporary FP32 directory (needed for loading) ★★★
 echo "== 4.1/7 Copy config/tokenizer from the reference model into the temporary FP32 directory (needed for loading) =="
 for f in config.json generation_config.json tokenizer_config.json tokenizer.json merges.txt vocab.json special_tokens_map.json added_tokens.json; do
   [[ -f "${REF_MODEL_DIR}/${f}" ]] && cp -f "${REF_MODEL_DIR}/${f}" "${TMP_PT_DIR}/" || true
 done
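An optional fail-fast guard one could add at this point (not in the commit): step 5 builds the model structure from REF_MODEL_DIR, so a missing config.json there is better caught now than mid-load. Variable names mirror the script's own.

    if [[ ! -f "${REF_MODEL_DIR}/config.json" ]]; then
      echo "ERR: ${REF_MODEL_DIR}/config.json not found; check REF_MODEL_DIR" >&2
      exit 4
    fi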
echo "== 5/7 装载 REF 模型结构 + 灌入 FP32 权重;如检测到 LoRA则导出 adapter否则保存 BF16 分片 safetensors ==" echo "== 5/7 装载 REF 结构 + 灌入 FP32若检测到 LoRA → 导出 adapter否则保存 BF16 分片 safetensors =="
python - <<'PY' python - <<'PY'
import os, re, json, sys, torch, shutil import os, re, json, sys, torch, shutil, glob
from transformers import AutoConfig, AutoModelForCausalLM from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import save_file from safetensors.torch import save_file
@ -122,42 +154,152 @@ REF_DIR = os.environ["REF_MODEL_DIR"]
OUT_DIR = os.environ["OUT_DIR"] OUT_DIR = os.environ["OUT_DIR"]
MAX_SHARD_SIZE = os.environ.get("MAX_SHARD_SIZE","5GB") MAX_SHARD_SIZE = os.environ.get("MAX_SHARD_SIZE","5GB")
print("[load] ref model from:", REF_DIR) sd_single = os.path.join(TMP_PT_DIR, "pytorch_model.bin")
cfg = AutoConfig.from_pretrained(REF_DIR, trust_remote_code=True) # 确保 model_type=qwen3 idx_json = os.path.join(TMP_PT_DIR, "pytorch_model.bin.index.json")
def log(*a, **k): print(*a, **k, flush=True)
def parse_index(idx_path):
with open(idx_path) as f:
j = json.load(f)
weight_map = j.get("weight_map", {})
shard_to_keys = {}
for k, shard in weight_map.items():
shard_to_keys.setdefault(shard, []).append(k)
return weight_map, shard_to_keys
def detect_lora_from_index(weight_map):
# 直接从权重名判断,不用先加载大权重
has = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in weight_map)
return has
def stream_collect_lora_from_shards(shard_to_keys):
lora_state = {}
lora_keys = []
for shard, keys in sorted(shard_to_keys.items()):
pick = [k for k in keys if (".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k)]
if not pick: continue
part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
for k in pick:
if k in part: lora_state[k] = part[k]
else:
# 极少数 DS 版本权重名不完全一致,容错跳过
pass
lora_keys.extend(pick)
return lora_state, lora_keys
def stream_load_full_into_model(model, shard_to_keys):
missing_total = 0
unexpected_total = 0
for shard, keys in sorted(shard_to_keys.items()):
part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
m, u = model.load_state_dict(part, strict=False)
missing_total += len(m)
unexpected_total += len(u)
log(f"[load] missing_total={missing_total} unexpected_total={unexpected_total}")
log("[load] ref model from:", REF_DIR)
cfg = AutoConfig.from_pretrained(REF_DIR, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(
     REF_DIR, config=cfg, trust_remote_code=True,
     torch_dtype=torch.float32, low_cpu_mem_usage=False, device_map={"": "cpu"}
 )
-sd_path = os.path.join(TMP_PT_DIR, "pytorch_model.bin")
-if not os.path.exists(sd_path):
-    print("ERR: not found:", sd_path, file=sys.stderr); sys.exit(9)
-state = torch.load(sd_path, map_location="cpu")
-# strip a possible 'module.' prefix
+state = None
+has_lora = False
+if os.path.exists(sd_single):
+    log("[fp32] detected single file:", sd_single)
+    state = torch.load(sd_single, map_location="cpu")
+    # strip the 'module.' prefix
     state = { (k.split("module.",1)[-1]): v for k, v in state.items() }
+    # detect LoRA keys
     has_lora = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in state)
-print("[check] contains LoRA keys:", has_lora)
+elif os.path.exists(idx_json):
+    log("[fp32] detected sharded weights with index:", idx_json)
+    weight_map, shard_to_keys = parse_index(idx_json)
+    has_lora = detect_lora_from_index(weight_map)
     if has_lora:
-        # Export the LoRA adapter (do not merge it into the base)
+        log("[check] contains LoRA keys: True (stream collecting)")
+        lora_state, lora_keys = stream_collect_lora_from_shards(shard_to_keys)
+        if not lora_state:
+            print("ERR: LoRA detected but no LoRA weights were collected; aborting.", file=sys.stderr); sys.exit(10)
+        # estimate r / alpha
+        lora_A_keys = [k for k in lora_keys if k.endswith(".lora_A.weight")]
+        if lora_A_keys:
+            # find the shard holding the first lora_A and read its shape
+            k0 = lora_A_keys[0]
+            shard = next(s for s, ks in shard_to_keys.items() if k0 in ks)
+            part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
+            r = part[k0].shape[0]
+            a = part.get(k0.replace(".lora_A.weight", ".lora_alpha"))
+            alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
+        else:
+            r, alpha = 16, 16
+        adapters_dir = os.path.join(OUT_DIR, "adapters")
+        os.makedirs(adapters_dir, exist_ok=True)
+        save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
+        targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
+        target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
+        adapter_cfg = {
+            "peft_type": "LORA",
+            "base_model_name_or_path": REF_DIR,
+            "r": int(r),
+            "lora_alpha": int(alpha),
+            "lora_dropout": 0.0,
+            "bias": "none",
+            "task_type": "CAUSAL_LM",
+            "target_modules": target_modules
+        }
+        with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
+            json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
+        # copy tokenizer/generation/config
+        for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
+            src = os.path.join(REF_DIR, f)
+            if os.path.exists(src):
+                dst = os.path.join(adapters_dir, f)
+                if not os.path.exists(dst):
+                    try: shutil.copy(src, dst)
+                    except Exception: pass
+        log("[save] exported LoRA adapter →", adapters_dir)
+        log("INFO: inference can be run with Transformers/vLLM/SGLang as 'REF_MODEL + adapters/adapter_model.safetensors'.")
+        sys.exit(0)
+    else:
+        log("[check] contains LoRA keys: False (stream loading into model)")
+        stream_load_full_into_model(model, shard_to_keys)
+else:
+    # Neither a single file nor index.json: last resort, load shards by filename
+    shard_glob = sorted(glob.glob(os.path.join(TMP_PT_DIR, "pytorch_model-*.bin")))
+    if not shard_glob:
+        print("ERR: no FP32 output found (pytorch_model.bin / .index.json / shards)", file=sys.stderr); sys.exit(9)
+    log(f"[fp32] detected {len(shard_glob)} shards (no index.json), brute-load")
+    # Brute-force merge (may use a lot of memory, but only a fallback)
+    state = {}
+    for sf in shard_glob:
+        part = torch.load(sf, map_location="cpu")
+        state.update(part)
+    state = { (k.split("module.",1)[-1]): v for k, v in state.items() }
+    has_lora = any((".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k) for k in state)
+if state is not None and has_lora:
+    log("[check] contains LoRA keys: True")
     lora_state = {k: v for k, v in state.items()
                   if (".lora_A" in k) or (".lora_B" in k) or (".lora_alpha" in k)}
     if not lora_state:
         print("ERR: LoRA detected but no LoRA weights were collected; aborting.", file=sys.stderr); sys.exit(10)
     lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight")]
-    r = state[lora_A_keys[0]].shape[0] if lora_A_keys else 16
-    # alpha: prefer the first one found; otherwise fall back to r
-    alpha = r
-    for k in lora_A_keys:
-        a = state.get(k.replace(".lora_A.weight", ".lora_alpha"))
-        if a is not None:
-            alpha = int(a.item()); break
-    # collect target_modules (leaf module names)
+    if lora_A_keys:
+        r = state[lora_A_keys[0]].shape[0]
+        a = state.get(lora_A_keys[0].replace(".lora_A.weight", ".lora_alpha"))
+        alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
+    else:
+        r, alpha = 16, 16
     targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
     target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
@@ -178,7 +320,6 @@ if has_lora:
     with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
         json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
-    # copy tokenizer/generation/config so the inference side can use the directory directly
     for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
         src = os.path.join(REF_DIR, f)
         if os.path.exists(src):
@@ -187,12 +328,15 @@ if has_lora:
             try: shutil.copy(src, dst)
             except Exception: pass
-    print("[save] exported LoRA adapter →", adapters_dir)
-    print("INFO: inference can be run with Transformers/vLLM/SGLang as 'REF_MODEL + adapters/adapter_model.safetensors'.")
-else:
-    # No LoRA: save BF16 sharded safetensors via the dense-weight path
+    log("[save] exported LoRA adapter →", adapters_dir)
+    log("INFO: inference can be run with Transformers/vLLM/SGLang as 'REF_MODEL + adapters/adapter_model.safetensors'.")
+    sys.exit(0)
+# Reaching here means no LoRA: pour the FP32 weights into the model and save BF16 safetensors
+if state is not None:
     missing, unexpected = model.load_state_dict(state, strict=False)
-    print(f"[load] missing={len(missing)} unexpected={len(unexpected)}")
+    log(f"[load] missing={len(missing)} unexpected={len(unexpected)}")
     # untie (if needed)
     try:
         emb = model.get_input_embeddings().weight if hasattr(model, "get_input_embeddings") else None
@@ -200,14 +344,14 @@ else:
         if emb is not None and head is not None and emb.data_ptr() == head.data_ptr():
             with torch.no_grad():
                 model.lm_head.weight = torch.nn.Parameter(head.detach().clone())
-            print("[fix] untied lm_head from embed_tokens")
+            log("[fix] untied lm_head from embed_tokens")
     except Exception as e:
-        print("[fix] skip untie check:", e)
+        log("[fix] skip untie check:", e)
     model.to(dtype=torch.bfloat16)
     os.makedirs(OUT_DIR, exist_ok=True)
     model.save_pretrained(OUT_DIR, safe_serialization=True, max_shard_size=MAX_SHARD_SIZE)
-    print("[save] BF16 safetensors →", OUT_DIR)
+    log("[save] BF16 safetensors →", OUT_DIR)
 PY
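The INFO message above points at the "REF_MODEL + adapter" serving path. A minimal sketch of that path with Transformers plus PEFT; the paths below are assumptions that mirror REF_MODEL_DIR and ${OUT_DIR}/adapters from this script, and peft is assumed to be installed.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base_dir = "/home/test/Qwen3-32B"            # REF_MODEL_DIR
    adapters = "/path/to/merged-TAG/adapters"    # assumed: ${OUT_DIR}/adapters from this script

    tok = AutoTokenizer.from_pretrained(adapters, trust_remote_code=True)   # tokenizer files were copied next to the adapter
    base = AutoModelForCausalLM.from_pretrained(base_dir, torch_dtype=torch.bfloat16,
                                                device_map="auto", trust_remote_code=True)
    model = PeftModel.from_pretrained(base, adapters)   # reads adapter_config.json + adapter_model.safetensors
    # model = model.merge_and_unload()                  # optional: fold the LoRA into the base weights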
echo "== 5.1/7 拷贝(/补齐)最终目录的 tokenizer/config 工件(如存在)==" echo "== 5.1/7 拷贝(/补齐)最终目录的 tokenizer/config 工件(如存在)=="
@@ -217,7 +361,7 @@ done
 echo "== 6/7 Self-check (index and config) =="
 python - <<'PY'
-import os, json, sys
+import os, json
 out_dir = os.environ.get("OUT_DIR")
 idx = os.path.join(out_dir, "model.safetensors.index.json")
 if os.path.exists(idx):
@@ -231,8 +375,7 @@ else:
         # the LoRA branch may produce no model shards (only adapters/ is exported)
         print("NOTE: no model shards found (normal if adapters/ was exported)")
     else:
-        print("WARN: model.safetensors.index.json not found and shard count != 1", file=sys.stderr)
+        print("WARN: model.safetensors.index.json not found and shard count != 1")
 try:
     from transformers import AutoConfig
     cfg = AutoConfig.from_pretrained(out_dir, trust_remote_code=True)