embed-bge-m3/FlagEmbedding/research/baai_general_embedding/retromae_pretrain/trainer.py

47 lines
1.8 KiB
Python

import logging
import os
from typing import Dict, Optional
import torch
from transformers import Trainer
logger = logging.getLogger(__name__)
class PreTrainer(Trainer):
    """Hugging Face ``Trainer`` subclass used for RetroMAE pre-training.

    Overrides two hooks of the base trainer:

    * ``log`` -- stamps every log record with the global step (and rounded
      epoch) before storing it and forwarding it to callbacks.
    * ``_save`` -- delegates model saving to the model's own
      ``save_pretrained`` and stores the tokenizer in an ``encoder_model``
      subdirectory of the checkpoint.
    """

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log. Mutated in place: ``"step"`` (and
                ``"epoch"`` when known, rounded to 2 decimals) are added
                before the record is stored and passed to callbacks.
            start_time (`Optional[float]`, *optional*):
                Accepted for compatibility with newer ``transformers``
                releases whose training loop calls ``self.log(logs, start_time)``.
                It is ignored here.
        """
        # Unlike the base implementation, the step is written into ``logs``
        # itself so that callbacks receive it as well, not only log_history.
        logs["step"] = self.state.global_step
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        # ``logs`` already carries "step", so a shallow copy is equivalent to
        # merging the step in again; the copy keeps log_history independent
        # of later mutations of ``logs``.
        self.state.log_history.append(dict(logs))
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write a checkpoint (model, tokenizer, training args) to ``output_dir``.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
                Created if it does not exist.
            state_dict: Optional state dict supplied by the caller (the base
                ``Trainer`` passes one in some setups). It is honoured only in
                the fallback branch for models without ``save_pretrained``;
                in that branch the model's own ``state_dict()`` is used when
                none is given.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not hasattr(self.model, "save_pretrained"):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            # Respect a caller-supplied state dict (as the base Trainer does)
            # instead of unconditionally re-reading it from the local module.
            if state_dict is None:
                state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
        else:
            # NOTE(review): called with output_dir only; presumably the model's
            # save_pretrained does not accept a state_dict argument -- confirm
            # before forwarding ``state_dict`` here.
            self.model.save_pretrained(output_dir)

        # Newer transformers expose the tokenizer as ``processing_class``
        # (``Trainer.tokenizer`` is deprecated there); older releases only
        # have ``tokenizer``.
        if hasattr(self, "processing_class"):
            tokenizer = self.processing_class
        else:
            tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is not None:
            # Saved under "encoder_model", presumably next to the encoder
            # weights written by the model's save_pretrained -- confirm.
            tokenizer.save_pretrained(os.path.join(output_dir, "encoder_model"))

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))