import torch
from torch import nn
from peft import LoraConfig, PeftModel, TaskType, get_peft_model

from mistral_config import CostWiseMistralConfig
from mistral_model import CostWiseHead, CostWiseMistralForCausalLM


def get_model(model_args, training_args, output_token_id):
    """Build a CostWiseMistralForCausalLM, optionally attaching layer-wise
    scoring heads and PEFT/LoRA adapters.

    `output_token_id` is the vocabulary id whose lm_head row is used to
    initialize every layer-wise head.
    """
    config = CostWiseMistralConfig.from_pretrained(
        model_args.model_name_or_path,
        token=model_args.token,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True,
    )

    if model_args.use_flash_attn:
        model = CostWiseMistralForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
            use_flash_attention_2=True,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            trust_remote_code=True,
            config=config,
        )
    else:
        model = CostWiseMistralForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            use_flash_attention_2=False,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            trust_remote_code=True,
            config=config,
        )
    # KV caching is unnecessary during training and conflicts with
    # gradient checkpointing.
    model.config.use_cache = False

    if model_args.layer_wise:
        # One single-logit scoring head per selected decoder layer, from
        # `start_layer` up to (and including) the final layer, stepping by
        # `layer_sep`.
        lm_head = nn.ModuleList([
            CostWiseHead(model.config.hidden_size, 1)
            for _ in range(
                model_args.start_layer,
                model.config.num_hidden_layers + 1,
                model_args.layer_sep,
            )
        ])
        # Initialize every head from the lm_head row of `output_token_id`,
        # so each head starts out reproducing that token's logit.
        state_dict_back = model.lm_head.state_dict()
        state_dict_back["weight"] = state_dict_back["weight"][
            output_token_id : output_token_id + 1, :
        ]
        for head in lm_head:
            head.linear_head.load_state_dict(state_dict_back)
        model.set_output_embeddings(lm_head)
        # Persist the head layout in the config so inference code can
        # rebuild the same heads.
        model.config.start_layer = model_args.start_layer
        model.config.layer_sep = model_args.layer_sep
        model.config.layer_wise = model_args.layer_wise

    # Optionally merge a previously trained adapter into the base weights
    # before attaching a new trainable one.
    if model_args.raw_peft is not None:
        model = PeftModel.from_pretrained(model, model_args.raw_peft)
        model = model.merge_and_unload()

    if model_args.from_peft is not None:
        # Resume training from an existing adapter.
        model = PeftModel.from_pretrained(model, model_args.from_peft, is_trainable=True)
        model.print_trainable_parameters()
    elif model_args.use_lora:
        # Attach a fresh LoRA adapter.
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=model_args.lora_rank,
            target_modules=model_args.target_modules,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            modules_to_save=model_args.lora_extra_parameters,
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    print(model)
    return model
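

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring get_model together. In the real training entry
# point, model_args and training_args come from HfArgumentParser-parsed
# dataclasses; every attribute value below is a hypothetical placeholder.
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    model_args = SimpleNamespace(
        model_name_or_path="mistralai/Mistral-7B-v0.1",  # hypothetical base model
        token=None,
        cache_dir=None,
        use_flash_attn=False,
        layer_wise=True,
        start_layer=8,   # hypothetical: first layer to attach a head to
        layer_sep=1,     # hypothetical: attach a head every layer_sep layers
        raw_peft=None,
        from_peft=None,
        use_lora=True,
        lora_rank=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha=64,
        lora_dropout=0.05,
        lora_extra_parameters=None,
    )
    training_args = SimpleNamespace(fp16=False)

    # Seed each CostWiseHead from the lm_head row of a judgment token,
    # e.g. "Yes" for a relevance classifier.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    yes_id = tokenizer("Yes", add_special_tokens=False)["input_ids"][0]
    model = get_model(model_args, training_args, output_token_id=yes_id)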