import logging
import os

import torch
from torch import nn
from transformers import BertForMaskedLM, AutoModelForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput

from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder

logger = logging.getLogger(__name__)


class RetroMAEForPretraining(nn.Module):
    def __init__(
            self,
            bert: BertForMaskedLM,
            model_args: ModelArguments,
    ):
        super(RetroMAEForPretraining, self).__init__()
        self.lm = bert

        # The decoder reuses the encoder's embedding layer (word + position
        # embeddings); fall back to the BERT attribute for unknown backbones.
        if hasattr(self.lm, 'bert'):
            self.decoder_embeddings = self.lm.bert.embeddings
        elif hasattr(self.lm, 'roberta'):
            self.decoder_embeddings = self.lm.roberta.embeddings
        else:
            self.decoder_embeddings = self.lm.bert.embeddings

        # Single-layer enhanced decoder head used for the RetroMAE decoding task.
        self.c_head = BertLayerForDecoder(bert.config)
        self.c_head.apply(self.lm._init_weights)

        self.cross_entropy = nn.CrossEntropyLoss()
        self.model_args = model_args

    def gradient_checkpointing_enable(self, **kwargs):
        self.lm.gradient_checkpointing_enable(**kwargs)

    def forward(self,
                encoder_input_ids, encoder_attention_mask, encoder_labels,
                decoder_input_ids, decoder_attention_mask, decoder_labels):
        # Encoder pass: standard MLM over the lightly masked encoder input.
        lm_out: MaskedLMOutput = self.lm(
            encoder_input_ids, encoder_attention_mask,
            labels=encoder_labels,
            output_hidden_states=True,
            return_dict=True
        )
        cls_hiddens = lm_out.hidden_states[-1][:, :1]  # B 1 D

        # Decoder keys/values: the sentence (CLS) representation followed by the
        # token embeddings of the aggressively masked decoder input.
        decoder_embedding_output = self.decoder_embeddings(input_ids=decoder_input_ids)
        hiddens = torch.cat([cls_hiddens, decoder_embedding_output[:, 1:]], dim=1)

        # decoder_position_ids = self.lm.bert.embeddings.position_ids[:, :decoder_input_ids.size(1)]
        # decoder_position_embeddings = self.lm.bert.embeddings.position_embeddings(decoder_position_ids)  # B L D
        # query = decoder_position_embeddings + cls_hiddens

        # Decoder queries: the CLS hidden state broadcast to every position, then
        # passed through the embedding layer so position embeddings are added.
        cls_hiddens = cls_hiddens.expand(hiddens.size(0), hiddens.size(1), hiddens.size(2))
        query = self.decoder_embeddings(inputs_embeds=cls_hiddens)

        # Convert the decoder attention mask to the additive format expected by
        # the attention layer.
        matrix_attention_mask = self.lm.get_extended_attention_mask(
            decoder_attention_mask,
            decoder_attention_mask.shape,
            decoder_attention_mask.device
        )

        hiddens = self.c_head(query=query,
                              key=hiddens,
                              value=hiddens,
                              attention_mask=matrix_attention_mask)[0]
        pred_scores, loss = self.mlm_loss(hiddens, decoder_labels)

        # Total loss = decoder MLM loss + encoder MLM loss.
        return (loss + lm_out.loss,)

    def mlm_loss(self, hiddens, labels):
        # Reuse the backbone's MLM prediction head (BERT exposes `cls`,
        # RoBERTa exposes `lm_head`).
        if hasattr(self.lm, 'cls'):
            pred_scores = self.lm.cls(hiddens)
        elif hasattr(self.lm, 'lm_head'):
            pred_scores = self.lm.lm_head(hiddens)
        else:
            raise NotImplementedError
        masked_lm_loss = self.cross_entropy(
            pred_scores.view(-1, self.lm.config.vocab_size),
            labels.view(-1)
        )
        return pred_scores, masked_lm_loss

    def save_pretrained(self, output_dir: str):
        # Save the encoder separately so it can be reloaded as a plain
        # Hugging Face model, plus the full state dict (including c_head).
        self.lm.save_pretrained(os.path.join(output_dir, "encoder_model"))
        torch.save(self.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments,
            *args, **kwargs
    ):
        hf_model = AutoModelForMaskedLM.from_pretrained(*args, **kwargs)
        model = cls(hf_model, model_args)
        return model
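
# --- Usage sketch (illustrative, not part of the training pipeline) ---
# A minimal sketch of constructing the model via the `from_pretrained`
# classmethod above. The `ModelArguments` fields and the batch tensor names
# are assumptions; they depend on `.arguments` and on the RetroMAE data
# collator actually used by the training script.
#
#     model_args = ModelArguments(...)  # hypothetical field values
#     model = RetroMAEForPretraining.from_pretrained(model_args, "bert-base-uncased")
#     total_loss = model(encoder_input_ids, encoder_attention_mask, encoder_labels,
#                        decoder_input_ids, decoder_attention_mask, decoder_labels)[0]
#     total_loss.backward()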