import logging
from typing import Tuple

from transformers import (
    AutoModel, AutoConfig,
    AutoTokenizer, PreTrainedTokenizer
)
from FlagEmbedding.abc.finetune.embedder import (
    AbsEmbedderRunner,
    AbsEmbedderModel,
    EmbedderTrainerCallbackForDataRefresh,
    AbsEmbedderModelArguments,
    AbsEmbedderTrainingArguments
)

from modeling import BiIREmbedderModel
from trainer import IREmbedderTrainer
from dataset import (
    IREmbedderTrainDataset, IREmbedderCollator,
    IREmbedderSameDatasetTrainDataset, IREmbedderSameDatasetCollator
)

logger = logging.getLogger(__name__)


class IREmbedderRunner(AbsEmbedderRunner):
    """
    Finetune runner for IR embedding models.
    """
    def load_train_dataset(self):
        """Load the training dataset.

        Returns:
            Dataset: The loaded training dataset.
        """
        if self.data_args.same_dataset_within_batch:
            train_dataset = IREmbedderSameDatasetTrainDataset(
                args=self.data_args,
                default_batch_size=self.training_args.per_device_train_batch_size,
                seed=self.training_args.seed,
                tokenizer=self.tokenizer,
                process_index=self.training_args.process_index,
                num_processes=self.training_args.world_size
            )
            # The same-dataset sampler yields full batches itself, so the dataloader
            # must use batch size 1 and a single worker.
            self.training_args.per_device_train_batch_size = 1
            self.training_args.dataloader_num_workers = 0   # avoid multi-processing
        else:
            train_dataset = IREmbedderTrainDataset(
                args=self.data_args,
                tokenizer=self.tokenizer
            )
        return train_dataset

    def load_data_collator(self):
        """Load the data collator matching the chosen dataset type.

        Returns:
            DataCollatorWithPadding: The loaded data collator.
        """
        if self.data_args.same_dataset_within_batch:
            EmbedCollator = IREmbedderSameDatasetCollator
        else:
            EmbedCollator = IREmbedderCollator

        data_collator = EmbedCollator(
            tokenizer=self.tokenizer,
            query_max_len=self.data_args.query_max_len,
            passage_max_len=self.data_args.passage_max_len,
            sub_batch_size=self.training_args.sub_batch_size,
            pad_to_multiple_of=self.data_args.pad_to_multiple_of,
            padding=True,
            return_tensors="pt"
        )
        return data_collator

    def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
        """Load tokenizer and model.

        Returns:
            Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Tokenizer and model instances.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_args.model_name_or_path,
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.token,
            trust_remote_code=self.model_args.trust_remote_code
        )
        base_model = AutoModel.from_pretrained(
            self.model_args.model_name_or_path,
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.token,
            trust_remote_code=self.model_args.trust_remote_code
        )

        num_labels = 1
        config = AutoConfig.from_pretrained(
            self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
            num_labels=num_labels,
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.token,
            trust_remote_code=self.model_args.trust_remote_code,
        )
        logger.info('Config: %s', config)

        model = BiIREmbedderModel(
            base_model,
            tokenizer=tokenizer,
            negatives_cross_device=self.training_args.negatives_cross_device,
            temperature=self.training_args.temperature,
            answer_temperature=self.training_args.answer_temperature,
            sub_batch_size=self.training_args.sub_batch_size,
            kd_loss_type=self.training_args.kd_loss_type,
            sentence_pooling_method=self.training_args.sentence_pooling_method,
            normalize_embeddings=self.training_args.normalize_embeddings,
            normalize_answer=self.training_args.normalize_answer,
            training_type=self.training_args.training_type
        )

        if self.training_args.gradient_checkpointing:
            model.enable_input_require_grads()

        if self.training_args.fix_position_embedding:
            for k, v in model.named_parameters():
                if "position_embeddings" in k:
                    logger.info(f"Freeze the parameters for {k}")
                    v.requires_grad = False
        return tokenizer, model

    def load_trainer(self) -> IREmbedderTrainer:
        """Load the trainer.

        Returns:
            IREmbedderTrainer: Loaded trainer instance.
        """
        trainer = IREmbedderTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            data_collator=self.data_collator,
            tokenizer=self.tokenizer
        )
        if self.data_args.same_dataset_within_batch:
            trainer.add_callback(EmbedderTrainerCallbackForDataRefresh(self.train_dataset))
        return trainer