import logging
import os
from pathlib import Path

import torch.distributed as dist
from transformers import AutoConfig, AutoTokenizer
from transformers import (
    HfArgumentParser,
    set_seed,
)
from transformers import (
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl
)

from .arguments import ModelArguments, DataArguments, \
    RetrieverTrainingArguments as TrainingArguments
from .data import SameDatasetTrainDataset, EmbedCollator
from .modeling import BGEM3Model
from .trainer import BiTrainer

logger = logging.getLogger(__name__)


class TrainerCallbackForDataRefresh(TrainerCallback):
    def __init__(self, train_dataset):
        self.train_dataset = train_dataset

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """
        Event called at the end of an epoch.
        """
        self.train_dataset.refresh_epoch()


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)

    # Set seed
    set_seed(training_args.seed)

    num_labels = 1
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
    )
    logger.info('Config: %s', config)

    model = BGEM3Model(model_name=model_args.model_name_or_path,
                       normlized=training_args.normlized,
                       sentence_pooling_method=training_args.sentence_pooling_method,
                       negatives_cross_device=training_args.negatives_cross_device,
                       temperature=training_args.temperature,
                       enable_sub_batch=training_args.enable_sub_batch,
                       unified_finetuning=training_args.unified_finetuning,
                       use_self_distill=training_args.use_self_distill,
                       colbert_dim=training_args.colbert_dim,
                       self_distill_start_step=training_args.self_distill_start_step)

    if training_args.fix_position_embedding:
        for k, v in model.named_parameters():
            if "position_embeddings" in k:
                logging.info(f"Freeze the parameters for {k}")
                v.requires_grad = False

    if training_args.fix_encoder:
        for k, v in model.named_parameters():
            if "colbert_linear" in k or 'sparse_linear' in k:
                logging.info(f"train the parameters for {k}")
            else:
                v.requires_grad = False

    # print(f"===========================Rank {dist.get_rank()}: start loading data===========================")
    if data_args.same_task_within_batch:
        train_dataset = SameDatasetTrainDataset(args=data_args,
                                                batch_size=training_args.per_device_train_batch_size,
                                                seed=training_args.seed,
                                                num_processes=training_args.world_size,
                                                process_index=training_args.process_index)
        training_args.per_device_train_batch_size = 1
        training_args.dataloader_num_workers = 0    # avoid multi-process data loading
    else:
        raise NotImplementedError("`same_task_within_batch=False` is not supported")

    data_collator = EmbedCollator(
        tokenizer,
        query_max_len=data_args.query_max_len,
        passage_max_len=data_args.passage_max_len
    )

    trainer = BiTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )
    if data_args.same_task_within_batch:
        trainer.add_callback(TrainerCallbackForDataRefresh(train_dataset))

    Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

    # Training
    # print(f"===========================Rank {dist.get_rank()}: start training===========================")
    trainer.train()
    trainer.save_model()
    # For convenience, we also re-save the tokenizer to the same directory,
    # so that you can share your model easily on huggingface.co/models =)
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
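

# Illustrative launch command (a sketch, not the project's documented CLI).
# The accepted flags are whatever HfArgumentParser derives from ModelArguments,
# DataArguments, and RetrieverTrainingArguments, so the module path and any
# flag below that is not a standard transformers.TrainingArguments field
# (e.g. --train_data) is an assumption and may differ in your checkout:
#
#   torchrun --nproc_per_node 8 -m <package>.run \
#       --model_name_or_path BAAI/bge-m3 \
#       --train_data <path/to/train/data> \
#       --output_dir ./bge-m3-finetuned \
#       --per_device_train_batch_size 4 \
#       --learning_rate 1e-5 \
#       --num_train_epochs 1 \
#       --fp16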