# embed-bge-m3/FlagEmbedding/research/llm_dense_retriever/finetune/load_model.py


import os
import re

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model, PeftModel


def find_largest_checkpoint(checkpoint_dir):
    """Return the path of the `checkpoint-<step>` entry with the largest step, or None if none exists."""
    checkpoint_pattern = re.compile(r'checkpoint-(\d+)')
    max_number = -1
    max_checkpoint_file = None
    for file in os.listdir(checkpoint_dir):
        match = checkpoint_pattern.search(file)
        if match:
            number = int(match.group(1))
            if number > max_number:
                max_number = number
                max_checkpoint_file = file
    if max_checkpoint_file:
        return os.path.join(checkpoint_dir, max_checkpoint_file)
    else:
        return None
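
# For example, an output_dir containing "checkpoint-500" and "checkpoint-1500"
# resolves to "<output_dir>/checkpoint-1500". This is used below as a fallback
# when no adapter was saved at the top level of output_dir.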


def get_model(model_args, output_dir, resize, resize_tokens):
    """Load the base model and wrap it for (LoRA) fine-tuning."""
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
        )
    else:
        raise ValueError(
            "You are instantiating a new config instance from scratch. This is not supported by this script."
        )
    config.use_cache = False

    if model_args.model_name_or_path:
        model = AutoModel.from_pretrained(
            model_args.model_name_or_path,
            # torch_dtype=torch.bfloat16,
            use_flash_attention_2=model_args.use_flash_attn,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
        )
    else:
        print("Training new model from scratch")
        model = AutoModel.from_config(config)

    # If a previously trained adapter is given, merge it into the base model first.
    if model_args.raw_peft is not None:
        model.set_input_embeddings(torch.load(os.path.join(model_args.raw_peft, 'embedding', 'emb.pth')))
        model = PeftModel.from_pretrained(model, model_args.raw_peft)
        model = model.merge_and_unload()

    if resize:
        # Grow the vocabulary and persist the resized input embeddings so they
        # can be restored when the adapter is merged later.
        model.resize_token_embeddings(resize_tokens)
        os.makedirs(os.path.join(output_dir, 'embedding'), exist_ok=True)
        torch.save(model.embed_tokens, os.path.join(output_dir, 'embedding', 'emb.pth'))
        target_modules = model_args.target_modules
    else:
        target_modules = model_args.target_modules
        if 'embed_tokens' in target_modules:
            target_modules.remove('embed_tokens')

    if model_args.from_peft is not None:
        # Resume training from an existing adapter (and its saved embeddings, if any).
        if os.path.exists(os.path.join(model_args.from_peft, 'embedding')):
            model.set_input_embeddings(torch.load(os.path.join(model_args.from_peft, 'embedding', 'emb.pth')))
            os.makedirs(os.path.join(output_dir, 'embedding'), exist_ok=True)
            torch.save(model.embed_tokens, os.path.join(output_dir, 'embedding', 'emb.pth'))
        model = PeftModel.from_pretrained(model, model_args.from_peft, is_trainable=True)
        model.print_trainable_parameters()
    else:
        if model_args.use_lora:
            peft_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                inference_mode=False,
                r=model_args.lora_rank,
                target_modules=target_modules,
                modules_to_save=model_args.modules_to_save,
                lora_alpha=model_args.lora_alpha,
                lora_dropout=model_args.lora_dropout,
            )
            model = get_peft_model(model, peft_config)
            model.print_trainable_parameters()

    return model


def save_merged_model(model_args, output_dir):
    """Reload the base model, merge the trained adapter from output_dir, and save the full model."""
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
        )
    else:
        raise ValueError(
            "You are instantiating a new config instance from scratch. This is not supported by this script."
        )
    config.use_cache = False

    if model_args.model_name_or_path:
        model = AutoModel.from_pretrained(
            model_args.model_name_or_path,
            # torch_dtype=torch.bfloat16,
            use_flash_attention_2=model_args.use_flash_attn,
            token=model_args.token,
            cache_dir=model_args.cache_dir,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
        )
    else:
        print("Training new model from scratch")
        model = AutoModel.from_config(config)

    # Merge the adapter the training run started from, if any.
    if model_args.raw_peft is not None:
        model.set_input_embeddings(torch.load(os.path.join(model_args.raw_peft, 'embedding', 'emb.pth')))
        model = PeftModel.from_pretrained(model, model_args.raw_peft)
        model = model.merge_and_unload()

    # Restore resized input embeddings saved during training, if present.
    if os.path.exists(os.path.join(output_dir, 'embedding', 'emb.pth')):
        model.set_input_embeddings(torch.load(os.path.join(output_dir, 'embedding', 'emb.pth')))

    # Prefer the adapter saved at the top level of output_dir; fall back to the
    # latest checkpoint-* directory if none was saved there.
    try:
        model = PeftModel.from_pretrained(model, output_dir)
        model = model.merge_and_unload()
    except Exception:
        model = PeftModel.from_pretrained(model, find_largest_checkpoint(output_dir))
        model = model.merge_and_unload()

    model.save_pretrained(os.path.join(output_dir, 'full_model'))

    tokenizer = AutoTokenizer.from_pretrained(output_dir)
    tokenizer.save_pretrained(os.path.join(output_dir, 'full_model'))
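

# A minimal usage sketch (not part of the original module), showing how
# `get_model` might be driven directly. The attribute names on `args` are
# inferred from how `model_args` is accessed above; the real dataclass is
# defined elsewhere in this repo. The base model id and LoRA hyperparameters
# below are placeholders, not the project's defaults.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        config_name=None,
        model_name_or_path="mistralai/Mistral-7B-v0.1",  # placeholder base model
        token=None,
        cache_dir=None,
        use_flash_attn=False,
        raw_peft=None,
        from_peft=None,
        use_lora=True,
        lora_rank=32,          # illustrative values
        lora_alpha=64,
        lora_dropout=0.1,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        modules_to_save=None,
    )
    # Wrap the base model with a fresh LoRA adapter (no vocab resize, no resume).
    model = get_model(args, output_dir="./lora_output", resize=False, resize_tokens=None)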