import logging

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput

from .arguments import ModelArguments, DataArguments

logger = logging.getLogger(__name__)


class CrossEncoder(nn.Module):
    def __init__(self, hf_model: PreTrainedModel, model_args: ModelArguments,
                 data_args: DataArguments, train_args: TrainingArguments):
        super().__init__()
        self.hf_model = hf_model
        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args
        self.config = self.hf_model.config

        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        # One target index per query group; the positive passage is expected to
        # sit at position 0 of each group, so the target is always 0.
        self.register_buffer(
            'target_label',
            torch.zeros(self.train_args.per_device_train_batch_size, dtype=torch.long)
        )

    def gradient_checkpointing_enable(self, **kwargs):
        self.hf_model.gradient_checkpointing_enable(**kwargs)

    def forward(self, batch):
        ranker_out: SequenceClassifierOutput = self.hf_model(**batch, return_dict=True)
        logits = ranker_out.logits

        if self.training:
            # Reshape the flat per-pair logits into (queries, passages per query)
            # and apply listwise softmax cross-entropy within each group.
            scores = logits.view(
                self.train_args.per_device_train_batch_size,
                self.data_args.train_group_size
            )
            loss = self.cross_entropy(scores, self.target_label)

            return SequenceClassifierOutput(
                loss=loss,
                **ranker_out,
            )
        else:
            return ranker_out

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments, data_args: DataArguments,
            train_args: TrainingArguments, *args, **kwargs
    ):
        # Remaining positional and keyword arguments are forwarded to
        # AutoModelForSequenceClassification.from_pretrained.
        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
        reranker = cls(hf_model, model_args, data_args, train_args)
        return reranker

    def save_pretrained(self, output_dir: str):
        # Move a copy of the weights to CPU before saving so the checkpoint is
        # written device-independently.
        state_dict = self.hf_model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()}
        )
        self.hf_model.save_pretrained(output_dir, state_dict=state_dict)
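
# --- Usage sketch (illustrative only, not executed by this module) ----------
# A minimal sketch of how this wrapper might be built and fed during training.
# The checkpoint name, ``num_labels=1``, and the way argument dataclasses are
# obtained are assumptions for illustration, not requirements of this repo.
#
#   model_args, data_args, train_args = ...  # e.g. parsed with HfArgumentParser
#   reranker = CrossEncoder.from_pretrained(
#       model_args, data_args, train_args,
#       'bert-base-uncased',        # assumed base checkpoint
#       num_labels=1,               # one relevance score per (query, passage) pair
#   )
#
#   # Each training batch is expected to contain
#   # per_device_train_batch_size * train_group_size tokenized (query, passage)
#   # pairs, grouped per query with the positive passage first, so the
#   # registered ``target_label`` of zeros picks out the positive in every group.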