This commit is contained in:
hailin 2025-09-23 21:09:08 +08:00
parent d2bdb8af9e
commit 8cc7a41d48
1 changed file with 79 additions and 35 deletions

@@ -198,6 +198,16 @@ def stream_load_full_into_model(model, shard_to_keys):
unexpected_total += len(u)
log(f"[load] missing_total={missing_total} unexpected_total={unexpected_total}")
def is_lora_A(k:str)->bool:
    return k.endswith(".lora_A.weight") or k.endswith(".lora_A.default.weight")
def alpha_key_for(kA:str)->str:
    if kA.endswith(".lora_A.weight"):
        return kA.replace(".lora_A.weight", ".lora_alpha")
    if kA.endswith(".lora_A.default.weight"):
        return kA.replace(".lora_A.default.weight", ".lora_alpha.default")
    return ""
log("[load] ref model from:", REF_DIR)
cfg = AutoConfig.from_pretrained(REF_DIR, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
@@ -223,15 +233,15 @@ elif os.path.exists(idx_json):
lora_state, lora_keys = stream_collect_lora_from_shards(shard_to_keys)
if not lora_state:
print("ERR: 识别到 LoRA 但未筛出权重;中止。", file=sys.stderr); sys.exit(10)
# 估 r / alpha
lora_A_keys = [k for k in lora_keys if k.endswith(".lora_A.weight")]
# estimate r / alpha (also handles .default keys)
lora_A_keys = [k for k in lora_keys if is_lora_A(k)]
if lora_A_keys:
# find the shard containing the first lora_A key and read its shape
k0 = lora_A_keys[0]
shard = next(s for s, ks in shard_to_keys.items() if k0 in ks)
part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
r = part[k0].shape[0]
a = part.get(k0.replace(".lora_A.weight", ".lora_alpha"))
ak = alpha_key_for(k0)
a = part.get(ak)
alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
else:
r, alpha = 16, 16
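# Explanatory note (not from the checkpoint spec): lora_A.weight is stored with
# shape (r, in_features), so shape[0] is the LoRA rank; if no lora_alpha tensor
# accompanies it, alpha falls back to r, i.e. a scaling factor of alpha/r == 1.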
@@ -240,15 +250,31 @@ elif os.path.exists(idx_json):
os.makedirs(adapters_dir, exist_ok=True)
save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
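# Illustrative only (hypothetical key): "base_model.model.layers.0.mlp.gate_proj.lora_A.default.weight"
# strips to "base_model.model.layers.0.mlp.gate_proj", and its last path component
# "gate_proj" is what ends up in target_modules.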
# ☆ Prefer reusing the adapter_config.json saved at training time (if present); only override base_model_name_or_path
copied_cfg = False
try:
CKPT_ROOT = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))
src_cfg = os.path.join(CKPT_ROOT, "adapter_config.json")
if os.path.exists(src_cfg):
with open(src_cfg, "r", encoding="utf-8") as f:
adapter_cfg = json.load(f)
adapter_cfg["base_model_name_or_path"] = REF_DIR
with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
copied_cfg = True
except Exception:
copied_cfg = False
if not copied_cfg:
adapter_cfg = {
"peft_type": "LORA",
"base_model_name_or_path": REF_DIR,
"r": int(r),
"lora_alpha": int(alpha),
"lora_dropout": 0.0,
"lora_dropout": 0.05,
"bias": "none",
"task_type": "CAUSAL_LM",
"target_modules": target_modules
@@ -266,7 +292,7 @@ elif os.path.exists(idx_json):
except Exception: pass
log("[save] 导出了 LoRA 适配器 →", adapters_dir)
log("INFO: 可用 Transformers/vLLM/SGLang 以『REF_MODEL + adapters/adapter_model.safetensors』方式推理。")
log("INFO: 可用 Transformers/vLLM/SGLang 以『REF模型 + adapters/adapter_model.safetensors』方式推理。")
sys.exit(0)
else:
log("[check] contains LoRA keys: False (stream loading into model)")
@@ -292,27 +318,45 @@ if state is not None and has_lora:
if not lora_state:
print("ERR: 识别到 LoRA 但未筛出权重;中止。", file=sys.stderr); sys.exit(10)
lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight")]
lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight") or k.endswith(".lora_A.default.weight")]
if lora_A_keys:
r = state[lora_A_keys[0]].shape[0]
a = state.get(lora_A_keys[0].replace(".lora_A.weight", ".lora_alpha"))
k0 = lora_A_keys[0]
r = state[k0].shape[0]
ak = k0.replace(".lora_A.weight", ".lora_alpha").replace(".lora_A.default.weight", ".lora_alpha.default")
a = state.get(ak)
alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
else:
r, alpha = 16, 16
targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
adapters_dir = os.path.join(OUT_DIR, "adapters")
os.makedirs(adapters_dir, exist_ok=True)
save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
# ☆ Prefer reusing the adapter_config.json saved at training time (if present); only override base_model_name_or_path
copied_cfg = False
try:
CKPT_ROOT = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))
src_cfg = os.path.join(CKPT_ROOT, "adapter_config.json")
if os.path.exists(src_cfg):
with open(src_cfg, "r", encoding="utf-8") as f:
adapter_cfg = json.load(f)
adapter_cfg["base_model_name_or_path"] = REF_DIR
with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
copied_cfg = True
except Exception:
copied_cfg = False
if not copied_cfg:
adapter_cfg = {
"peft_type": "LORA",
"base_model_name_or_path": REF_DIR,
"r": int(r),
"lora_alpha": int(alpha),
"lora_dropout": 0.0,
"lora_dropout": 0.05,
"bias": "none",
"task_type": "CAUSAL_LM",
"target_modules": target_modules
@@ -329,7 +373,7 @@ if state is not None and has_lora:
except Exception: pass
log("[save] 导出了 LoRA 适配器 →", adapters_dir)
log("INFO: 可用 Transformers/vLLM/SGLang 以『REF_MODEL + adapters/adapter_model.safetensors』方式推理。")
log("INFO: 可用 Transformers/vLLM/SGLang 以『REF模型 + adapters/adapter_model.safetensors』方式推理。")
sys.exit(0)
# Reaching this point means "no LoRA": inject the FP32 weights into the model and save BF16 safetensors
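A minimal sketch of the "REF model + adapters/adapter_model.safetensors" inference path mentioned in the INFO logs above, assuming the peft package is installed and using placeholder paths for REF_DIR and OUT_DIR:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_dir = "/path/to/REF_MODEL"              # placeholder for REF_DIR
adapters_dir = "/path/to/OUT_DIR/adapters"   # directory this script writes to

tok = AutoTokenizer.from_pretrained(base_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base_dir, torch_dtype=torch.bfloat16, trust_remote_code=True
)
# Attach the exported LoRA adapter on top of the frozen base weights.
model = PeftModel.from_pretrained(model, adapters_dir)

inputs = tok("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=16)
print(tok.decode(out[0], skip_special_tokens=True))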