This commit is contained in:
parent d2bdb8af9e
commit 8cc7a41d48
@@ -198,6 +198,16 @@ def stream_load_full_into_model(model, shard_to_keys):
         unexpected_total += len(u)
     log(f"[load] missing_total={missing_total} unexpected_total={unexpected_total}")
 
+def is_lora_A(k:str)->bool:
+    return k.endswith(".lora_A.weight") or k.endswith(".lora_A.default.weight")
+
+def alpha_key_for(kA:str)->str:
+    if kA.endswith(".lora_A.weight"):
+        return kA.replace(".lora_A.weight", ".lora_alpha")
+    if kA.endswith(".lora_A.default.weight"):
+        return kA.replace(".lora_A.default.weight", ".lora_alpha.default")
+    return ""
+
 log("[load] ref model from:", REF_DIR)
 cfg = AutoConfig.from_pretrained(REF_DIR, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(
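The two helpers added in this hunk only normalize PEFT key names, covering both the plain layout and the ".default" adapter-name layout. A quick illustration of the intended mapping, assuming the definitions above; the key strings are made up for the example:

    # assumes is_lora_A / alpha_key_for from the hunk above; keys are hypothetical
    k_plain   = "model.layers.0.self_attn.q_proj.lora_A.weight"
    k_default = "model.layers.0.self_attn.q_proj.lora_A.default.weight"
    assert is_lora_A(k_plain) and is_lora_A(k_default)
    assert alpha_key_for(k_plain)   == "model.layers.0.self_attn.q_proj.lora_alpha"
    assert alpha_key_for(k_default) == "model.layers.0.self_attn.q_proj.lora_alpha.default"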
@@ -223,15 +233,15 @@ elif os.path.exists(idx_json):
     lora_state, lora_keys = stream_collect_lora_from_shards(shard_to_keys)
     if not lora_state:
         print("ERR: LoRA keys detected but no weights were collected; aborting.", file=sys.stderr); sys.exit(10)
-    # estimate r / alpha
-    lora_A_keys = [k for k in lora_keys if k.endswith(".lora_A.weight")]
+    # estimate r / alpha (also handles ".default" keys)
+    lora_A_keys = [k for k in lora_keys if is_lora_A(k)]
     if lora_A_keys:
         # find the shard containing the first lora_A key and read its shape
         k0 = lora_A_keys[0]
         shard = next(s for s, ks in shard_to_keys.items() if k0 in ks)
         part = torch.load(os.path.join(TMP_PT_DIR, shard), map_location="cpu")
         r = part[k0].shape[0]
-        a = part.get(k0.replace(".lora_A.weight", ".lora_alpha"))
+        ak = alpha_key_for(k0)
+        a = part.get(ak)
         alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
     else:
         r, alpha = 16, 16
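The estimation above relies on the fact that a PEFT lora_A weight has shape [r, in_features], so the rank is its first dimension, and that alpha falls back to r when no lora_alpha tensor was stored. A minimal self-contained sketch of the same logic on a hypothetical in-memory shard (names and sizes are illustrative):

    import torch

    # hypothetical mini shard standing in for one torch.load(...) result
    part = {
        "model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.zeros(16, 4096),
        "model.layers.0.self_attn.q_proj.lora_B.default.weight": torch.zeros(4096, 16),
        "model.layers.0.self_attn.q_proj.lora_alpha.default": torch.tensor(32.0),
    }

    k0 = next(k for k in part if k.endswith(".lora_A.default.weight"))
    r = part[k0].shape[0]                       # rank = first dim of lora_A
    a = part.get(k0.replace(".lora_A.default.weight", ".lora_alpha.default"))
    alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
    print(r, alpha)                             # -> 16 32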
@@ -240,15 +250,31 @@ elif os.path.exists(idx_json):
     os.makedirs(adapters_dir, exist_ok=True)
     save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
 
-    targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
+    targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
     target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
 
+    # ☆ Prefer reusing the adapter_config.json from training (if present), overriding only base_model_name_or_path
+    copied_cfg = False
+    try:
+        CKPT_ROOT = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))
+        src_cfg = os.path.join(CKPT_ROOT, "adapter_config.json")
+        if os.path.exists(src_cfg):
+            with open(src_cfg, "r", encoding="utf-8") as f:
+                adapter_cfg = json.load(f)
+            adapter_cfg["base_model_name_or_path"] = REF_DIR
+            with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
+                json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
+            copied_cfg = True
+    except Exception:
+        copied_cfg = False
+
+    if not copied_cfg:
         adapter_cfg = {
             "peft_type": "LORA",
             "base_model_name_or_path": REF_DIR,
             "r": int(r),
             "lora_alpha": int(alpha),
-            "lora_dropout": 0.0,
+            "lora_dropout": 0.05,
             "bias": "none",
             "task_type": "CAUSAL_LM",
             "target_modules": target_modules
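The widened regex above is what lets targets cover both naming styles before the module names are reduced to target_modules. A small sketch with made-up keys showing the reduction:

    import re

    # hypothetical LoRA keys in both naming styles
    lora_A_keys = [
        "model.layers.0.self_attn.q_proj.lora_A.weight",
        "model.layers.0.self_attn.v_proj.lora_A.default.weight",
    ]

    targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
    target_modules = sorted(set(t.split(".")[-1] for t in targets))
    print(target_modules)   # -> ['q_proj', 'v_proj']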
@@ -266,7 +292,7 @@ elif os.path.exists(idx_json):
     except Exception: pass
 
     log("[save] exported LoRA adapter →", adapters_dir)
-    log("INFO: inference can be run with Transformers/vLLM/SGLang using 'REF_MODEL + adapters/adapter_model.safetensors'.")
+    log("INFO: inference can be run with Transformers/vLLM/SGLang using 'REF model + adapters/adapter_model.safetensors'.")
     sys.exit(0)
 else:
     log("[check] contains LoRA keys: False (stream loading into model)")
@@ -292,27 +318,45 @@ if state is not None and has_lora:
     if not lora_state:
         print("ERR: LoRA keys detected but no weights were collected; aborting.", file=sys.stderr); sys.exit(10)
 
-    lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight")]
+    lora_A_keys = [k for k in lora_state if k.endswith(".lora_A.weight") or k.endswith(".lora_A.default.weight")]
     if lora_A_keys:
-        r = state[lora_A_keys[0]].shape[0]
-        a = state.get(lora_A_keys[0].replace(".lora_A.weight", ".lora_alpha"))
+        k0 = lora_A_keys[0]
+        r = state[k0].shape[0]
+        ak = k0.replace(".lora_A.weight", ".lora_alpha").replace(".lora_A.default.weight", ".lora_alpha.default")
+        a = state.get(ak)
         alpha = int(a.item()) if (a is not None and hasattr(a, "item")) else int(r)
     else:
        r, alpha = 16, 16
 
-    targets = sorted(set(re.sub(r"\.lora_(A|B)\.weight$", "", k) for k in lora_A_keys))
+    targets = sorted(set(re.sub(r"\.lora_(A|B)(?:\.default)?\.weight$", "", k) for k in lora_A_keys))
     target_modules = sorted(set(t.split(".")[-1] for t in targets)) or ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
 
     adapters_dir = os.path.join(OUT_DIR, "adapters")
     os.makedirs(adapters_dir, exist_ok=True)
     save_file(lora_state, os.path.join(adapters_dir, "adapter_model.safetensors"))
 
+    # ☆ Prefer reusing the adapter_config.json from training (if present), overriding only base_model_name_or_path
+    copied_cfg = False
+    try:
+        CKPT_ROOT = os.path.abspath(os.path.join(TMP_PT_DIR, os.pardir))
+        src_cfg = os.path.join(CKPT_ROOT, "adapter_config.json")
+        if os.path.exists(src_cfg):
+            with open(src_cfg, "r", encoding="utf-8") as f:
+                adapter_cfg = json.load(f)
+            adapter_cfg["base_model_name_or_path"] = REF_DIR
+            with open(os.path.join(adapters_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
+                json.dump(adapter_cfg, f, ensure_ascii=False, indent=2)
+            copied_cfg = True
+    except Exception:
+        copied_cfg = False
+
+    if not copied_cfg:
         adapter_cfg = {
             "peft_type": "LORA",
             "base_model_name_or_path": REF_DIR,
             "r": int(r),
             "lora_alpha": int(alpha),
-            "lora_dropout": 0.0,
+            "lora_dropout": 0.05,
             "bias": "none",
             "task_type": "CAUSAL_LM",
             "target_modules": target_modules
@@ -329,7 +373,7 @@ if state is not None and has_lora:
     except Exception: pass
 
     log("[save] exported LoRA adapter →", adapters_dir)
-    log("INFO: inference can be run with Transformers/vLLM/SGLang using 'REF_MODEL + adapters/adapter_model.safetensors'.")
+    log("INFO: inference can be run with Transformers/vLLM/SGLang using 'REF model + adapters/adapter_model.safetensors'.")
     sys.exit(0)
 
 # -- reaching this point means "no LoRA": inject the FP32 weights into the model and save BF16 safetensors --
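As the INFO log notes, the exported directory is meant to be consumed as "reference model + adapter" rather than as merged weights. A hedged usage sketch with Transformers + PEFT; "REF_DIR" and "OUT_DIR/adapters" are placeholders for the paths this script writes:

    # sketch only: attach the exported adapter to the reference model with PEFT
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained(
        "REF_DIR", torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    tok = AutoTokenizer.from_pretrained("REF_DIR", trust_remote_code=True)

    # OUT_DIR/adapters holds adapter_model.safetensors + adapter_config.json
    model = PeftModel.from_pretrained(base, "OUT_DIR/adapters")
    # model = model.merge_and_unload()   # optional: bake the LoRA into the base weights

vLLM and SGLang can likewise serve the base model with this directory passed as a LoRA adapter; the exact flags depend on the serving stack and are not shown here.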