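"""RetroMAE-style pretraining model.

Wraps a Hugging Face masked-LM encoder (BERT or RoBERTa) together with a single
enhanced decoder layer (BertLayerForDecoder). The encoder is trained with the
standard MLM objective; the decoder reconstructs masked decoder-side tokens from
the encoder's [CLS] representation, and the two losses are summed.
"""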
import logging
import os

import torch
from torch import nn
from transformers import AutoModelForMaskedLM, BertForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput

from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder

logger = logging.getLogger(__name__)


class RetroMAEForPretraining(nn.Module):
    def __init__(
            self,
            bert: BertForMaskedLM,
            model_args: ModelArguments,
    ):
        super(RetroMAEForPretraining, self).__init__()
        self.lm = bert

        # The decoder reuses the encoder's embedding layer (word, position and
        # token-type embeddings), so both sides share the same embedding space.
        if hasattr(self.lm, 'bert'):
            self.decoder_embeddings = self.lm.bert.embeddings
        elif hasattr(self.lm, 'roberta'):
            self.decoder_embeddings = self.lm.roberta.embeddings
        else:
            raise ValueError(
                f"Unsupported masked-LM backbone: {type(self.lm).__name__}; "
                "expected a BERT- or RoBERTa-style model."
            )

        # Single enhanced decoder layer, freshly initialized with the encoder's
        # weight-initialization scheme.
        self.c_head = BertLayerForDecoder(bert.config)
        self.c_head.apply(self.lm._init_weights)

        self.cross_entropy = nn.CrossEntropyLoss()

        self.model_args = model_args

    def gradient_checkpointing_enable(self, **kwargs):
        self.lm.gradient_checkpointing_enable(**kwargs)

    def forward(self,
                encoder_input_ids, encoder_attention_mask, encoder_labels,
                decoder_input_ids, decoder_attention_mask, decoder_labels):
        # Encoder pass: standard MLM loss on the encoder input.
        lm_out: MaskedLMOutput = self.lm(
            encoder_input_ids, encoder_attention_mask,
            labels=encoder_labels,
            output_hidden_states=True,
            return_dict=True
        )
        cls_hiddens = lm_out.hidden_states[-1][:, :1]  # B 1 D

        # Decoder keys/values: the encoder's [CLS] hidden state at position 0,
        # followed by the embeddings of the remaining decoder input tokens.
        decoder_embedding_output = self.decoder_embeddings(input_ids=decoder_input_ids)
        hiddens = torch.cat([cls_hiddens, decoder_embedding_output[:, 1:]], dim=1)

        # decoder_position_ids = self.lm.bert.embeddings.position_ids[:, :decoder_input_ids.size(1)]
        # decoder_position_embeddings = self.lm.bert.embeddings.position_embeddings(decoder_position_ids)  # B L D
        # query = decoder_position_embeddings + cls_hiddens

        # Decoder queries: the [CLS] vector broadcast to every position, then passed
        # through the shared embedding layer so position embeddings are added.
        cls_hiddens = cls_hiddens.expand(hiddens.size(0), hiddens.size(1), hiddens.size(2))
        query = self.decoder_embeddings(inputs_embeds=cls_hiddens)

        # Convert the decoder attention mask into the additive extended mask
        # expected by the attention layer.
        matrix_attention_mask = self.lm.get_extended_attention_mask(
            decoder_attention_mask,
            decoder_attention_mask.shape,
            decoder_attention_mask.device
        )

        hiddens = self.c_head(query=query,
                              key=hiddens,
                              value=hiddens,
                              attention_mask=matrix_attention_mask)[0]
        pred_scores, loss = self.mlm_loss(hiddens, decoder_labels)

        # Total loss = decoder reconstruction (MLM) loss + encoder MLM loss.
        return (loss + lm_out.loss,)

    def mlm_loss(self, hiddens, labels):
        # BERT exposes its MLM head as `cls`, RoBERTa as `lm_head`.
        if hasattr(self.lm, 'cls'):
            pred_scores = self.lm.cls(hiddens)
        elif hasattr(self.lm, 'lm_head'):
            pred_scores = self.lm.lm_head(hiddens)
        else:
            raise NotImplementedError

        masked_lm_loss = self.cross_entropy(
            pred_scores.view(-1, self.lm.config.vocab_size),
            labels.view(-1)
        )
        return pred_scores, masked_lm_loss

    def save_pretrained(self, output_dir: str):
        # Save the encoder in HF format so it can be loaded on its own after
        # pretraining; the full state dict (including the decoder) is saved alongside.
        self.lm.save_pretrained(os.path.join(output_dir, "encoder_model"))
        torch.save(self.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments,
            *args, **kwargs
    ):
        hf_model = AutoModelForMaskedLM.from_pretrained(*args, **kwargs)
        model = cls(hf_model, model_args)
        return model
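
# Usage sketch: a minimal wiring example, assuming a standard BERT checkpoint name
# and a ModelArguments instance from .arguments; the data collator that produces the
# encoder/decoder tensors (masked ids, attention masks, labels) is not shown here and
# depends on the pretraining script.
#
#   model = RetroMAEForPretraining.from_pretrained(model_args, "bert-base-uncased")
#   (loss,) = model(
#       encoder_input_ids, encoder_attention_mask, encoder_labels,
#       decoder_input_ids, decoder_attention_mask, decoder_labels,
#   )
#   loss.backward()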