From 8cc7a41d481053c57c7c78ee1f693a223755aefc Mon Sep 17 00:00:00 2001
From: hailin
Date: Tue, 23 Sep 2025 21:09:08 +0800
Subject: [PATCH] .

---
 lora_merge_zero3_safetensors.sh | 114 ++++++++++++++++++++++----------
 1 file changed, 79 insertions(+), 35 deletions(-)

diff --git a/lora_merge_zero3_safetensors.sh b/lora_merge_zero3_safetensors.sh
index f93ba77..098d1f4 100755
--- a/lora_merge_zero3_safetensors.sh
+++ b/lora_merge_zero3_safetensors.sh
@@ -198,6 +198,16 @@ def stream_load_full_into_model(model, shard_to_keys):
         unexpected_total += len(u)
     log(f"[load] missing_total={missing_total} unexpected_total={unexpected_total}")
 
+def is_lora_A(k:str)->bool:
+    return k.endswith(".lora_A.weight") or k.endswith(".lora_A.default.weight")
+
+def alpha_key_for(kA:str)->str:
+    if kA.endswith(".lora_A.weight"):
+        return kA.replace(".lora_A.weight", ".lora_alpha")
+    if kA.endswith(".lora_A.default.weight"):
+        return kA.replace(".lora_A.default.weight", ".lora_alpha.default")
+    return ""
+
 log("[load] ref model from:", REF_DIR)
 cfg = AutoConfig.from_pretrained(REF_DIR, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(
@@ -223,15 +233,15 @@ elif os.path.exists(idx_json):
     lora_state, lora_keys = stream_collect_lora_from_shards(shard_to_keys)
     if not lora_state:
         print("ERR: 识别到 LoRA 但未筛出权重;中止。", file=sys.stderr); sys.exit(10)
-    # 估 r / alpha
-    lora_A_keys = [k for k in lora_keys if k.endswith(".lora_A.weight")]
+    # 估 r / alpha(兼容 .default)
+    lora_A_keys = [k for k in lora_keys if is_lora_A(k)]
     if lora_A_keys:
-        # 找到第一个 lora_A 所在分片并读取 shape
         k0 = lora_A_keys[0]
         shard = next(s for s, ks in shard_to_keys.items() if k0 in ks)
         part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
         r = part[k0].shape[0]
-        a = part.get(k0.replace(".lora_A.weight", ".lora_alpha"))
+        ak = alpha_key_for(k0)
+        a = part.get(ak)
         alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
     else:
         r, alpha = 16, 16
@@ -240,21 +250,37 @@ elif os.path.exists(idx_json):
     os.makedirs(adapters_dir, exist_ok=True)
     save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
-    targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
+    targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
     target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
 
-    adapter_cfg = {
-        "peft_type": "LORA",
-        "base_model_name_or_path": REF_DIR,
-        "r": int(r),
-        "lora_alpha": int(alpha),
-        "lora_dropout": 0.0,
-        "bias": "none",
-        "task_type": "CAUSAL_LM",
-        "target_modules": target_modules
-    }
-    with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
-        json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
+    # ☆ 优先复用训练时的 adapter_config.json(若存在),只覆盖 base_model_name_or_path
+    copied_cfg = False
+    try:
+        CKPT_ROOT = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))
+        src_cfg = os.path.join(CKPT_ROOT, "adapter_config.json")
+        if os.path.exists(src_cfg):
+            with open(src_cfg, "r", encoding="utf-8") as f:
+                adapter_cfg = json.load(f)
+            adapter_cfg["base_model_name_or_path"] = REF_DIR
+            with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
+                json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
+            copied_cfg = True
+    except Exception:
+        copied_cfg = False
+
+    if not copied_cfg:
+        adapter_cfg = {
+            "peft_type": "LORA",
+            "base_model_name_or_path": REF_DIR,
+            "r": int(r),
+            "lora_alpha": int(alpha),
+            "lora_dropout": 0.05,
+            "bias": "none",
+            "task_type": "CAUSAL_LM",
+            "target_modules": target_modules
+        }
+        with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
+            json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
 
     # 复制 tokenizer/generation/config
     for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
         src = os.path.join(REF_DIR, f)
@@ -266,7 +292,7 @@ elif os.path.exists(idx_json):
             except Exception: pass
     log("[save] 导出了 LoRA 适配器 →", adapters_dir)
-    log("INFO: 可用 Transformers/vLLM/SGLang 以『REF_MODEL + adapters/adapter_model.safetensors』方式推理。")
+    log("INFO: 可用 Transformers/vLLM/SGLang 以『REF模型 + adapters/adapter_model.safetensors』方式推理。")
     sys.exit(0)
 else:
     log("[check] contains LoRA keys: False (stream loading into model)")
 
@@ -292,33 +318,51 @@ if state is not None and has_lora:
     if not lora_state:
         print("ERR: 识别到 LoRA 但未筛出权重;中止。", file=sys.stderr); sys.exit(10)
 
-    lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight")]
+    lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight") or k.endswith(".lora_A.default.weight")]
     if lora_A_keys:
-        r = state[lora_A_keys[0]].shape[0]
-        a = state.get(lora_A_keys[0].replace(".lora_A.weight", ".lora_alpha"))
+        k0 = lora_A_keys[0]
+        r = state[k0].shape[0]
+        ak = k0.replace(".lora_A.weight", ".lora_alpha").replace(".lora_A.default.weight", ".lora_alpha.default")
+        a = state.get(ak)
         alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
     else:
         r, alpha = 16, 16
 
-    targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
+    targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
     target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
 
     adapters_dir = os.path.join(OUT_DIR, "adapters")
     os.makedirs(adapters_dir, exist_ok=True)
     save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
-    adapter_cfg = {
-        "peft_type": "LORA",
-        "base_model_name_or_path": REF_DIR,
-        "r": int(r),
-        "lora_alpha": int(alpha),
-        "lora_dropout": 0.0,
-        "bias": "none",
-        "task_type": "CAUSAL_LM",
-        "target_modules": target_modules
-    }
-    with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
-        json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
+    # ☆ 优先复用训练时的 adapter_config.json(若存在),只覆盖 base_model_name_or_path
+    copied_cfg = False
+    try:
+        CKPT_ROOT = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))
+        src_cfg = os.path.join(CKPT_ROOT, "adapter_config.json")
+        if os.path.exists(src_cfg):
+            with open(src_cfg, "r", encoding="utf-8") as f:
+                adapter_cfg = json.load(f)
+            adapter_cfg["base_model_name_or_path"] = REF_DIR
+            with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
+                json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
+            copied_cfg = True
+    except Exception:
+        copied_cfg = False
+
+    if not copied_cfg:
+        adapter_cfg = {
+            "peft_type": "LORA",
+            "base_model_name_or_path": REF_DIR,
+            "r": int(r),
+            "lora_alpha": int(alpha),
+            "lora_dropout": 0.05,
+            "bias": "none",
+            "task_type": "CAUSAL_LM",
+            "target_modules": target_modules
+        }
+        with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
+            json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
 
     for f in ("tokenizer_config.json","tokenizer.json","merges.txt","vocab.json","special_tokens_map.json","added_tokens.json","generation_config.json","config.json"):
         src = os.path.join(REF_DIR, f)
 
@@ -329,7 +373,7 @@ if state is not None and has_lora:
             except Exception: pass
     log("[save] 导出了 LoRA 适配器 →", adapters_dir)
-    log("INFO: 可用 Transformers/vLLM/SGLang 以『REF_MODEL + adapters/adapter_model.safetensors』方式推理。")
+    log("INFO: 可用 Transformers/vLLM/SGLang 以『REF模型 + adapters/adapter_model.safetensors』方式推理。")
     sys.exit(0)
 
 # ——走到这里表示“无 LoRA”,把 FP32 权重注入模型并保存 BF16 safetensors——