import logging
import os
from typing import Dict, Optional

import torch
from transformers import Trainer

logger = logging.getLogger(__name__)


class PreTrainer(Trainer):
    """`Trainer` subclass that customizes logging and checkpoint saving.

    * `log` adds the current global step (and rounded epoch) to the logged
      values before they are recorded and forwarded to the callbacks.
    * `_save` writes the model, the tokenizer (into an ``encoder_model``
      sub-directory) and the training arguments to the output directory.
    """

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log. The dict is updated in place with a
                `"step"` entry and, when the epoch is known, an `"epoch"`
                entry rounded to two decimals.
            start_time (`Optional[float]`, defaults to `None`):
                Accepted only so that this override stays call-compatible
                with `transformers` releases whose training loop invokes
                `self.log(logs, start_time)`. It is not used here.
        """
        logs["step"] = self.state.global_step
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        # Store a snapshot in the history so later in-place changes to `logs`
        # do not alter the recorded entry. `logs` already carries "step".
        self.state.log_history.append(dict(logs))
        self.control = self.callback_handler.on_log(
            self.args, self.state, self.control, logs
        )

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write a checkpoint of the model, tokenizer and training arguments.

        Args:
            output_dir (`Optional[str]`):
                Target directory; falls back to `self.args.output_dir` when
                `None`. It is created if it does not exist.
            state_dict:
                Optional pre-computed state dict. It is used only when the
                model has no `save_pretrained` method; in that case it is
                written to ``pytorch_model.bin`` (the model's own
                `state_dict()` is used when this is `None`). It is not
                forwarded to `save_pretrained`.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if hasattr(self.model, "save_pretrained"):
            self.model.save_pretrained(output_dir)
        else:
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            # Honor a state dict handed in by the caller instead of always
            # overwriting it with the model's own.
            if state_dict is None:
                state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(os.path.join(output_dir, "encoder_model"))

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))