embed-bge-m3/FlagEmbedding/research/Matroyshka_reranker/finetune/self_distillation/load_model.py

73 lines
3.0 KiB
Python

import torch
from torch import nn
from mistral_model import CostWiseMistralForCausalLM, CostWiseHead
from mistral_config import CostWiseMistralConfig
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
def get_model(model_args, training_args, output_token_id):
    """Load the CostWise Mistral model and apply the requested head/adapter setup.

    Steps, in order: load the config and the weights, set
    ``config.use_cache`` to False, optionally replace the LM head with a list
    of single-output heads, optionally merge an existing adapter into the
    weights, then optionally attach a trainable adapter.

    Args:
        model_args: Object exposing ``model_name_or_path``, ``token``,
            ``cache_dir``, ``use_flash_attn``, ``layer_wise``,
            ``start_layer``, ``layer_sep``, ``raw_peft``, ``from_peft``,
            ``use_lora``, ``lora_rank``, ``target_modules``, ``lora_alpha``,
            ``lora_dropout`` and ``lora_extra_parameters``.
        training_args: Only ``fp16`` is read, and only when
            ``model_args.use_flash_attn`` is truthy (float16 if set,
            bfloat16 otherwise).
        output_token_id: Row index into the original ``lm_head`` weight; that
            single row is copied into every layer-wise head.

    Returns:
        The loaded model. It is a PEFT-wrapped model when
        ``model_args.from_peft`` is not None or ``model_args.use_lora`` is
        truthy.
    """
    name_or_path = model_args.model_name_or_path

    config = CostWiseMistralConfig.from_pretrained(
        name_or_path,
        token=model_args.token,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True,
    )

    # Arguments common to both attention variants.
    load_kwargs = {
        "token": model_args.token,
        "cache_dir": model_args.cache_dir,
        "from_tf": ".ckpt" in name_or_path,
        "trust_remote_code": True,
        "config": config,
    }
    if model_args.use_flash_attn:
        # An explicit half-precision dtype is only passed on this path; the
        # other path lets from_pretrained pick its default.
        load_kwargs["torch_dtype"] = (
            torch.float16 if training_args.fp16 else torch.bfloat16
        )
        load_kwargs["use_flash_attention_2"] = True
    else:
        load_kwargs["use_flash_attention_2"] = False

    model = CostWiseMistralForCausalLM.from_pretrained(name_or_path, **load_kwargs)
    model.config.use_cache = False

    if model_args.layer_wise:
        # One head per entry of range(start_layer, num_hidden_layers + 1, layer_sep).
        hidden_size = model.config.hidden_size
        head_layers = range(
            model_args.start_layer,
            model.config.num_hidden_layers + 1,
            model_args.layer_sep,
        )
        heads = nn.ModuleList([CostWiseHead(hidden_size, 1) for _ in head_layers])

        # Keep only the row for output_token_id (slice, so the result stays
        # 2-D with a single row) and load it into every head.
        seed_state = model.lm_head.state_dict()
        seed_state["weight"] = seed_state["weight"][output_token_id:output_token_id + 1, :]
        for head in heads:
            head.linear_head.load_state_dict(seed_state)

        model.set_output_embeddings(heads)
        # NOTE(review): these three config assignments are assumed to belong
        # to the layer-wise branch -- confirm against the upstream file.
        model.config.start_layer = model_args.start_layer
        model.config.layer_sep = model_args.layer_sep
        model.config.layer_wise = model_args.layer_wise

    if model_args.raw_peft is not None:
        # Fold the existing adapter into the base weights before any new
        # adapter is attached.
        model = PeftModel.from_pretrained(model, model_args.raw_peft).merge_and_unload()

    if model_args.from_peft is not None:
        # Resume from a saved adapter and keep it trainable.
        model = PeftModel.from_pretrained(model, model_args.from_peft, is_trainable=True)
        model.print_trainable_parameters()
    elif model_args.use_lora:
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=model_args.lora_rank,
            target_modules=model_args.target_modules,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            modules_to_save=model_args.lora_extra_parameters,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    print(model)
    return model