from transformers.trainer import *  # provides os, torch, Optional, logger, Trainer, etc.
try:
    from transformers.deepspeed import is_deepspeed_zero3_enabled
except ImportError:  # module moved in newer transformers releases
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from peft import get_peft_model_state_dict


class BiTrainer(Trainer):
    use_lora: bool

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # Without LoRA, fall back to the default Trainer checkpointing.
        if not self.use_lora:
            super()._save(output_dir, state_dict)
            return

        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not hasattr(self.model, 'save'):
            raise NotImplementedError(
                f'MODEL {self.model.__class__.__name__} '
                f'does not support save interface')
        else:
            self.model.save(output_dir)

        # if self.tokenizer is not None and self.is_world_process_zero():
        #     self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        if is_deepspeed_zero3_enabled():
            # Under ZeRO-3, strip the wrapper's `model.` prefix before extracting
            # the LoRA weights, then write the adapter only from the main process.
            if state_dict is None:
                state_dict = self.model.state_dict()
            prefix = 'model.'
            assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
            lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict)
            if self.args.process_index <= 0:
                torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
                print(f"Saved adapter model at {output_dir}")

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        outputs = model(**inputs)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss
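

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original pipeline):
# it only demonstrates the save interface that `_save` expects, namely a
# wrapper exposing `save(output_dir)` with the underlying HF module kept in
# `.model`. The `_DemoWrapper` class, checkpoint name, and output paths are
# placeholders for demonstration; no training is run here.
if __name__ == "__main__":
    from transformers import AutoModel, TrainingArguments

    class _DemoWrapper(torch.nn.Module):
        """Hypothetical wrapper satisfying the `save()` interface `_save` checks for."""

        def __init__(self, name: str):
            super().__init__()
            self.model = AutoModel.from_pretrained(name)

        def save(self, output_dir: str):
            # Delegate to HF serialization so the checkpoint can be reloaded
            # later with `from_pretrained()`.
            self.model.save_pretrained(output_dir)

    demo_args = TrainingArguments(output_dir="./demo_output", per_device_train_batch_size=2)
    trainer = BiTrainer(model=_DemoWrapper("bert-base-uncased"), args=demo_args)
    # Route saving through the custom `_save` path; the LoRA-adapter branch
    # only runs when DeepSpeed ZeRO-3 is enabled.
    trainer.use_lora = True
    trainer.save_model("./demo_output")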