embed-bge-m3/FlagEmbedding/research/reranker/modeling.py

import logging

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput

from .arguments import ModelArguments, DataArguments

logger = logging.getLogger(__name__)


class CrossEncoder(nn.Module):
    """Cross-encoder reranker: wraps a Hugging Face sequence-classification model
    and trains it with a listwise cross-entropy loss over groups of passages."""

    def __init__(self, hf_model: PreTrainedModel, model_args: ModelArguments, data_args: DataArguments,
                 train_args: TrainingArguments):
        super().__init__()
        self.hf_model = hf_model
        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args
        self.config = self.hf_model.config

        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        # The positive passage is expected at index 0 of each group, so the
        # target class for every query in the batch is always 0.
        self.register_buffer(
            'target_label',
            torch.zeros(self.train_args.per_device_train_batch_size, dtype=torch.long)
        )

    def gradient_checkpointing_enable(self, **kwargs):
        # Delegate gradient checkpointing to the wrapped Hugging Face model.
        self.hf_model.gradient_checkpointing_enable(**kwargs)

    def forward(self, batch):
        ranker_out: SequenceClassifierOutput = self.hf_model(**batch, return_dict=True)
        logits = ranker_out.logits

        if self.training:
            # Reshape the flat per-passage logits into (batch_size, group_size) and
            # treat reranking as classifying which passage in each group is the positive.
            scores = logits.view(
                self.train_args.per_device_train_batch_size,
                self.data_args.train_group_size
            )
            loss = self.cross_entropy(scores, self.target_label)

            return SequenceClassifierOutput(
                loss=loss,
                **ranker_out,
            )
        else:
            return ranker_out

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments, data_args: DataArguments, train_args: TrainingArguments,
            *args, **kwargs
    ):
        # Remaining positional/keyword arguments are forwarded directly to
        # AutoModelForSequenceClassification.from_pretrained.
        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
        reranker = cls(hf_model, model_args, data_args, train_args)
        return reranker

    def save_pretrained(self, output_dir: str):
        # Copy all parameters to CPU before saving the wrapped model's checkpoint.
        state_dict = self.hf_model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()}
        )
        self.hf_model.save_pretrained(output_dir, state_dict=state_dict)
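

# A minimal usage sketch with hypothetical values; any fields required by the
# dataclasses in .arguments beyond those shown are omitted here, and the
# checkpoint name and hyperparameters are placeholders.
#
#   model_args = ModelArguments(model_name_or_path='BAAI/bge-reranker-base')
#   data_args = DataArguments(train_group_size=8)
#   train_args = TrainingArguments(output_dir='./reranker_out', per_device_train_batch_size=4)
#   reranker = CrossEncoder.from_pretrained(
#       model_args, data_args, train_args,
#       model_args.model_name_or_path, num_labels=1,
#   )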