# embed-bge-m3/FlagEmbedding/research/llm_reranker/finetune_for_instruction/modeling.py
import logging
from dataclasses import dataclass
from typing import Dict, Optional, List, Union

import torch
from torch import nn, Tensor
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import ModelOutput

logger = logging.getLogger(__name__)

@dataclass
class RerankerOutput(ModelOutput):
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None

class BiEncoderModel(nn.Module):
    def __init__(self,
                 model: AutoModelForCausalLM = None,
                 tokenizer: AutoTokenizer = None,
                 train_batch_size: int = 4,
                 ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.config = self.model.config

        self.train_batch_size = train_batch_size
        # Token id of "Yes": the reranker scores a pair by the logit the
        # LM assigns to "Yes" as the next token after the prompt.
        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][-1]

    def gradient_checkpointing_enable(self, **kwargs):
        self.model.gradient_checkpointing_enable(**kwargs)

    def enable_input_require_grads(self, **kwargs):
        self.model.enable_input_require_grads(**kwargs)

    def encode(self, features):
        if features is None:
            return None
        outputs = self.model(input_ids=features['input_ids'],
                             attention_mask=features['attention_mask'],
                             position_ids=features.get('position_ids'),
                             output_hidden_states=True)
        # `labels` is -100 everywhere except at the supervised answer token,
        # so the argmax along the sequence gives that token's position; the
        # logits one step earlier are the ones that predict it.
        _, max_indices = torch.max(features['labels'], dim=1)
        predict_indices = max_indices - 1
        logits = [outputs.logits[i, predict_indices[i], :] for i in range(outputs.logits.shape[0])]
        logits = torch.stack(logits, dim=0)
        # Relevance score = logit of the "Yes" token at the prediction position.
        scores = logits[:, self.yes_loc]
        return scores.contiguous()
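
    # Worked example of the position trick in `encode` (illustrative values):
    # for labels = [[-100, -100, 312, -100]], torch.max(..., dim=1) returns
    # index 2 (the only unmasked target), so predict_indices = 1 and the
    # score is read from the logits that predict the token at position 2.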

    def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None):
        ranker_logits = self.encode(pair)  # (batch_size * group_size,)

        if self.training:
            # Each training group holds one positive followed by its negatives,
            # so the target class is always index 0 within the group.
            grouped_logits = ranker_logits.view(self.train_batch_size, -1)
            target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
            loss = self.compute_loss(grouped_logits, target)
        else:
            loss = None

        return RerankerOutput(
            loss=loss,
            scores=ranker_logits,
        )

    def compute_loss(self, scores, target):
        return self.cross_entropy(scores, target)

    def save(self, output_dir: str):
        # Clone the state dict to CPU so the checkpoint is written from a
        # device-independent copy of the weights.
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()})
        self.model.save_pretrained(output_dir, state_dict=state_dict)

    def save_pretrained(self, **kwargs):
        self.tokenizer.save_pretrained(**kwargs)
        return self.model.save_pretrained(**kwargs)
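

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of driving this reranker wrapper at inference
# time. The checkpoint name ("sshleifer/tiny-gpt2"), the prompt string, and
# the hand-built `labels` tensor below are assumptions for illustration; in
# this repo the training collator constructs these tensors itself.
if __name__ == "__main__":
    name = "sshleifer/tiny-gpt2"  # assumed small public checkpoint for the demo
    tokenizer = AutoTokenizer.from_pretrained(name)
    tokenizer.pad_token = tokenizer.eos_token  # tiny-gpt2 has no pad token
    base = AutoModelForCausalLM.from_pretrained(name)
    reranker = BiEncoderModel(model=base, tokenizer=tokenizer, train_batch_size=1)
    reranker.eval()

    # Build one (query, passage) prompt and append the "Yes" answer token so
    # that `encode` can locate the prediction position from `labels`.
    prompt = "Query: what is bge? Passage: bge is an embedding model. Relevant?"
    enc = tokenizer(prompt, return_tensors='pt')
    yes_id = torch.tensor([[reranker.yes_loc]])
    input_ids = torch.cat([enc['input_ids'], yes_id], dim=1)
    features = {
        'input_ids': input_ids,
        'attention_mask': torch.ones_like(input_ids),
        'labels': torch.full_like(input_ids, -100),
    }
    features['labels'][0, -1] = reranker.yes_loc  # mark the answer position

    with torch.no_grad():
        out = reranker(features)
    print(out.scores)  # one relevance score per pair; higher means more relevant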