# Model construction for cost-wise Mistral training: load the base causal LM,
# optionally attach one CostWiseHead per selected layer, merge any pre-trained
# PEFT adapters into the base weights, and wrap the result in a fresh LoRA.
import torch
from torch import nn
from peft import LoraConfig, PeftModel, TaskType, get_peft_model

from mistral_config import CostWiseMistralConfig
from mistral_model import CostWiseHead, CostWiseMistralForCausalLM

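# Note: CostWiseHead is defined in mistral_model.py. From its use below
# (constructed as CostWiseHead(hidden_size, 1) and initialized through
# .linear_head.load_state_dict({'weight': ...}) with a single, bias-free
# lm_head row), a minimal sketch consistent with that interface would be:
#
#   class CostWiseHead(nn.Module):
#       def __init__(self, input_dim: int, output_dim: int):
#           super().__init__()
#           self.linear_head = nn.Linear(input_dim, output_dim, bias=False)
#
#       def forward(self, hidden_states):
#           return self.linear_head(hidden_states)
#
# The real implementation may differ; this is only inferred from the calls here.
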
def get_model(model_args, training_args, output_token_id):
    config = CostWiseMistralConfig.from_pretrained(
        model_args.model_name_or_path,
        token=model_args.token,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True,
    )

    # The two load paths differ only in the attention backend and dtype:
    # Flash Attention 2 requires half precision, so torch_dtype is set only
    # on that branch (the non-flash path keeps the default fp32 weights).
    load_kwargs = dict(
        use_flash_attention_2=model_args.use_flash_attn,
        token=model_args.token,
        cache_dir=model_args.cache_dir,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        trust_remote_code=True,
        config=config,
    )
    if model_args.use_flash_attn:
        load_kwargs["torch_dtype"] = torch.float16 if training_args.fp16 else torch.bfloat16
    model = CostWiseMistralForCausalLM.from_pretrained(
        model_args.model_name_or_path, **load_kwargs
    )

    # Disable the KV cache for training (incompatible with gradient checkpointing).
    model.config.use_cache = False

    if model_args.layer_wise:
        # Attach one single-logit CostWiseHead per selected layer: every
        # layer_sep-th layer from start_layer through the final hidden layer.
        lm_head = nn.ModuleList([
            CostWiseHead(model.config.hidden_size, 1)
            for _ in range(model_args.start_layer,
                           model.config.num_hidden_layers + 1,
                           model_args.layer_sep)
        ])
        # Seed every head with the lm_head row of the target output token, so
        # each intermediate layer starts from the final layer's score for it.
        state_dict_back = model.lm_head.state_dict()
        state_dict_back['weight'] = state_dict_back['weight'][output_token_id: output_token_id + 1, :]
        for head in lm_head:
            head.linear_head.load_state_dict(state_dict_back)
        model.set_output_embeddings(lm_head)
        # Record the head layout so it can be reconstructed at load time.
        model.config.start_layer = model_args.start_layer
        model.config.layer_sep = model_args.layer_sep
        model.config.layer_wise = model_args.layer_wise

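    # Worked example (assumed values): with start_layer=8, layer_sep=4 and
    # num_hidden_layers=32, range(8, 33, 4) attaches heads for layers
    # 8, 12, 16, 20, 24, 28 and 32, i.e. seven CostWiseHeads in total.
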
    if model_args.raw_peft is not None:
        # Merge previously trained adapters into the base weights one by one,
        # so the new LoRA below is applied on top of the merged model rather
        # than nested inside another PeftModel.
        for raw_peft in model_args.raw_peft:
            model = PeftModel.from_pretrained(model, raw_peft)
            model = model.merge_and_unload()

    tmp_model = None
    if model_args.use_lora:
        # Earlier revisions branched on model_args.finetune_type here: the
        # 'layer' type used exactly the config below, while other types also
        # passed layers_to_transform=model_args.compress_layers. The single
        # unconditional config is what remains in use.
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=model_args.lora_rank,
            target_modules=model_args.target_modules,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            modules_to_save=model_args.lora_extra_parameters,
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    print(model)

    # tmp_model is always None here; it is returned only to preserve the
    # caller's (model, tmp_model) interface.
    return model, tmp_model
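

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original pipeline): a minimal, illustrative
# call to get_model. The field names mirror the attributes accessed above; all
# concrete values (checkpoint name, rank, layer choices, target modules) are
# assumptions, not the repo's actual defaults.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    model_args = SimpleNamespace(
        model_name_or_path="mistralai/Mistral-7B-v0.1",  # assumed base checkpoint
        token=None,
        cache_dir=None,
        use_flash_attn=False,
        layer_wise=True,
        start_layer=8,
        layer_sep=4,
        raw_peft=None,
        use_lora=True,
        lora_rank=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha=64,
        lora_dropout=0.05,
        lora_extra_parameters=None,
    )
    training_args = SimpleNamespace(fp16=False)

    # output_token_id selects which lm_head row seeds each CostWiseHead; 0 is
    # a placeholder, a real run would pass e.g. the id of a "Yes" token.
    model, _ = get_model(model_args, training_args, output_token_id=0)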