from sentence_transformers import SentenceTransformer, models
from transformers.trainer import *  # provides Trainer, Optional, os, torch and logger used below


def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normlized: bool = True):
    """Re-save a Hugging Face checkpoint so it can be loaded with the sentence-transformers library."""
    word_embedding_model = models.Transformer(ckpt_dir)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                   pooling_mode=pooling_mode)
    if normlized:
        normlize_layer = models.Normalize()
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normlize_layer],
                                    device='cpu')
    else:
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
    model.save(ckpt_dir)


class BiTrainer(Trainer):
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not hasattr(self.model, 'save'):
            raise NotImplementedError(
                f'MODEL {self.model.__class__.__name__} '
                f'does not support save interface')
        else:
            self.model.save(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Also save the checkpoint in sentence-transformers format.
        if self.is_world_process_zero():
            save_ckpt_for_sentence_transformers(output_dir,
                                                pooling_mode=self.args.sentence_pooling_method,
                                                normlized=self.args.normlized)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
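

# Usage sketch (not part of the original module): once BiTrainer has written a
# checkpoint, the output directory can be loaded directly with
# sentence-transformers, because save_ckpt_for_sentence_transformers stores the
# Transformer/Pooling(/Normalize) module config next to the Hugging Face weights.
# The path below is a hypothetical example.
if __name__ == "__main__":
    st_model = SentenceTransformer("path/to/output_dir")
    embeddings = st_model.encode(["an example sentence"])
    print(embeddings.shape)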