embed-bge-m3/FlagEmbedding/research/LLARA/pretrain/load_model.py

from transformers import AutoConfig, AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
from modeling import PreLlamaModel


def get_model(model_args, use_gradient_checkpointing: bool = False):
    # Load the model config, preferring an explicit config name over the model path.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
        )
    else:
        raise ValueError(
            "You are instantiating a new config instance from scratch. This is not supported by this script."
        )

    # Gradient checkpointing recomputes activations, so the KV cache must be disabled during training.
    if use_gradient_checkpointing:
        config.use_cache = False

    if model_args.model_name_or_path:
        model = PreLlamaModel.from_pretrained(
            model_args.model_name_or_path,
            # Request FlashAttention-2 when enabled, otherwise fall back to PyTorch SDPA.
            attn_implementation="flash_attention_2" if model_args.use_flash_attn else "sdpa",
            token=model_args.token,
            cache_dir=model_args.cache_dir,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
        )
    else:
        print("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(config)

    # Optionally wrap the base model with LoRA adapters for parameter-efficient training.
    if model_args.use_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=model_args.lora_rank,
            target_modules=model_args.target_modules,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    return model
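

# Usage sketch (not part of the original module): one way get_model() might be
# called. The attribute names on the namespace are exactly those the function
# reads above; the concrete values (base checkpoint, LoRA hyperparameters,
# target modules) are illustrative assumptions, not settings from the LLARA
# training scripts.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_args = SimpleNamespace(
        config_name=None,
        model_name_or_path="meta-llama/Llama-2-7b-hf",  # hypothetical base checkpoint
        token=None,
        cache_dir=None,
        use_flash_attn=False,
        use_lora=True,
        lora_rank=32,        # example value
        lora_alpha=64,       # example value
        lora_dropout=0.1,    # example value
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # common LLaMA projection layers
    )

    model = get_model(example_args, use_gradient_checkpointing=True)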