import logging
import os
from typing import Optional

import torch

from FlagEmbedding.abc.finetune.embedder import AbsEmbedderTrainer

logger = logging.getLogger(__name__)


class IREmbedderTrainer(AbsEmbedderTrainer):
    """
    Trainer class for base encoder models.
    """
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model to the given directory.

        Args:
            output_dir (Optional[str], optional): Output directory to save the model.
                Defaults to ``None``, in which case ``self.args.output_dir`` is used.

        Raises:
            NotImplementedError: If the wrapped model does not expose a ``save`` method.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)

        # Save the trained model and configuration using `save_pretrained()`;
        # they can then be reloaded with `from_pretrained()`.
        if not hasattr(self.model, 'save'):
            raise NotImplementedError(
                f'Model {self.model.__class__.__name__} '
                f'does not support the save interface')
        self.model.save(output_dir)

        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save the checkpoint for the sentence-transformers library.
        # if self.is_world_process_zero():
        #     save_ckpt_for_sentence_transformers(output_dir,
        #                                         pooling_mode=self.args.sentence_pooling_method,
        #                                         normlized=self.args.normlized)
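

# Illustrative usage sketch (commented out so this module stays import-safe).
# It assumes AbsEmbedderTrainer keeps the Hugging Face `Trainer` constructor
# and that `save_model()` routes through `_save()` above; the model, dataset,
# and tokenizer names are hypothetical placeholders, not part of this module.
#
#     from transformers import TrainingArguments
#
#     trainer = IREmbedderTrainer(
#         model=embedder_model,         # hypothetical: any object exposing .save(output_dir)
#         args=TrainingArguments(output_dir="./checkpoints"),
#         train_dataset=train_dataset,  # hypothetical training dataset
#         tokenizer=tokenizer,          # hypothetical tokenizer; saved on rank 0 only
#     )
#     trainer.train()
#     trainer.save_model()              # dispatches to _save() -> model.save(output_dir)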