# embed-bge-m3/FlagEmbedding/research/llm_dense_retriever/finetune/modeling.py

import logging
import sys
from dataclasses import dataclass
from itertools import product
from typing import Dict, Optional, List, Union
import torch
import torch.distributed as dist
from torch import nn, Tensor
from torch.nn import BCEWithLogitsLoss
from tqdm import trange, tqdm
from transformers import AutoModel, AutoTokenizer
from transformers.file_utils import ModelOutput

logger = logging.getLogger(__name__)


@dataclass
class EncoderOutput(ModelOutput):
    q_reps: Optional[Tensor] = None
    p_reps: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None


class BiEncoderModel(nn.Module):
    TRANSFORMER_CLS = AutoModel

    def __init__(self,
                 model: AutoModel = None,
                 tokenizer: AutoTokenizer = None,
                 normlized: bool = False,
                 negatives_cross_device: bool = False,
                 temperature: float = 1.0,
                 sub_batch_size: int = -1):
        super().__init__()
        self.model = model
        self.config = model.config
        self.tokenizer = tokenizer
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.normlized = normlized
        self.temperature = temperature
        if not normlized:
            self.temperature = 1.0
            logger.info("reset temperature = 1.0 due to using inner product to compute similarity")
        self.negatives_cross_device = negatives_cross_device
        if self.negatives_cross_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        self.sub_batch_size = sub_batch_size

    def gradient_checkpointing_enable(self, **kwargs):
        self.model.gradient_checkpointing_enable(**kwargs)

    def enable_input_require_grads(self, **kwargs):
        self.model.enable_input_require_grads(**kwargs)

    def encode(self, features):
        if features is None:
            return None
        if not isinstance(features, list):
            if self.sub_batch_size is not None and self.sub_batch_size > 0:
                # encode the batch in chunks of sub_batch_size to bound peak memory
                all_p_reps = []
                for i in range(0, len(features['attention_mask']), self.sub_batch_size):
                    end_inx = min(i + self.sub_batch_size, len(features['attention_mask']))
                    sub_features = {}
                    for k, v in features.items():
                        sub_features[k] = v[i:end_inx]
                    psg_out = self.model(**sub_features, return_dict=True, output_hidden_states=False)
                    # last-token pooling: use the hidden state at the final position as the embedding
                    p_reps = psg_out.last_hidden_state[:, -1, :]
                    all_p_reps.append(p_reps)
                all_p_reps = torch.cat(all_p_reps, 0).contiguous()
                if self.normlized:
                    all_p_reps = torch.nn.functional.normalize(all_p_reps, dim=-1)
                return all_p_reps.contiguous()
            else:
                psg_out = self.model(**features, return_dict=True, output_hidden_states=False)
                p_reps = psg_out.last_hidden_state[:, -1, :]
                if self.normlized:
                    p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
                return p_reps.contiguous()
        else:
            # features is already a list of pre-split sub-batches
            all_p_reps = []
            for sub_features in features:
                psg_out = self.model(**sub_features, return_dict=True, output_hidden_states=False)
                p_reps = psg_out.last_hidden_state[:, -1, :]
                all_p_reps.append(p_reps)
            all_p_reps = torch.cat(all_p_reps, 0).contiguous()
            if self.normlized:
                all_p_reps = torch.nn.functional.normalize(all_p_reps, dim=-1)
            return all_p_reps.contiguous()
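
    # Note on pooling (added remark): `last_hidden_state[:, -1, :]` is last-token pooling,
    # which only picks up a real token when padding is on the left (or sequences are unpadded).
    # The data pipeline is assumed to set `tokenizer.padding_side = "left"`; with right
    # padding the pooled vector would come from a padding position.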

    def compute_similarity(self, q_reps, p_reps):
        # p_reps may be a single 2-D matrix shared by all queries or a 3-D per-query tensor
        if len(p_reps.size()) == 2:
            return torch.matmul(q_reps, p_reps.transpose(0, 1))
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))

    def get_local_similarity(self, q_reps, p_reps, all_scores):
        # for each query, keep only the scores of its own passage group
        # (group_size = num_passages // num_queries)
        indices = torch.arange(0, q_reps.shape[0], device=q_reps.device) * (p_reps.size(0) // q_reps.size(0))
        specific_scores = []
        for i in range(p_reps.size(0) // q_reps.size(0)):
            specific_scores.append(
                all_scores[torch.arange(q_reps.size(0), device=q_reps.device), indices + i]
            )
        return torch.stack(specific_scores, dim=1)

    def compute_local_similarity(self, q_reps, p_reps):
        all_scores = self.compute_similarity(q_reps, p_reps)
        result = self.get_local_similarity(q_reps, p_reps, all_scores)
        return result
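
    # Worked example (illustrative): with 2 queries and 4 passages laid out in groups as
    # [p0+, p0-, p1+, p1-], group_size = 4 // 2 = 2 and `indices` = [0, 2]. Iterating
    # i = 0, 1 selects all_scores[[0, 1], [0, 2]] and all_scores[[0, 1], [1, 3]], so the
    # stacked result is [[s(q0, p0+), s(q0, p0-)], [s(q1, p1+), s(q1, p1-)]] with shape
    # (num_queries, group_size) and each query's positive in column 0.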

    def forward(self,
                query: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
                passage: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
                messages: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
                teacher_scores: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None):
        q_reps = self.encode(query)    # (batch_size, dim)
        p_reps = self.encode(passage)  # (batch_size * num, dim)

        if self.training:
            if messages[0] == 'normal':
                if self.negatives_cross_device:
                    q_reps = self._dist_gather_tensor(q_reps)
                    p_reps = self._dist_gather_tensor(p_reps)
                scores = self.compute_similarity(q_reps, p_reps)
                scores = scores / self.temperature
                scores = scores.view(q_reps.size(0), -1)

                target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
                target = target * (p_reps.size(0) // q_reps.size(0))
                # every in-batch passage other than the query's own positive serves as a negative
                loss = self.compute_cross_entropy_loss(scores, target)

                if teacher_scores is not None:
                    q_len = q_reps.shape[0]
                    if self.negatives_cross_device:
                        q_len = q_len // self.world_size
                    teacher_scores = torch.tensor(teacher_scores, device=q_reps.device)
                    teacher_scores = teacher_scores.view(q_len, -1)
                    teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)

                    student_scores = self.get_local_similarity(q_reps, p_reps, scores)
                    student_scores = student_scores.view(q_reps.size(0), -1)
                    if self.negatives_cross_device:
                        # keep only this rank's rows, matching the local teacher scores
                        student_scores = student_scores[q_len * self.process_rank: q_len * (self.process_rank + 1)]
                    distill_loss = - torch.mean(
                        torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
                    loss += distill_loss
            else:
                scores = self.compute_local_similarity(q_reps, p_reps)
                scores = scores / self.temperature
                scores = scores.view(q_reps.size(0), -1)

                # each query is scored only against its own passage group; the positive sits at
                # index 0 and the remaining group members are negatives
                target = torch.zeros(scores.size(0), device=scores.device, dtype=torch.long)
                loss = self.compute_cross_entropy_loss(scores, target)
        else:
            scores = self.compute_similarity(q_reps, p_reps)
            loss = None

        return EncoderOutput(
            loss=loss,
            scores=scores,
            q_reps=q_reps,
            p_reps=p_reps,
        )
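
    # The distillation term above is a soft-label cross-entropy: with teacher targets
    # t = softmax(teacher_scores) and student log-probabilities log p = log_softmax(student_scores),
    # distill_loss = -mean_i sum_j t_ij * log p_ij, which differs from KL(t || p) only by the
    # constant teacher entropy, so it pushes the student's in-group score distribution toward
    # the teacher's.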

    def compute_cross_entropy_loss(self, scores, target):
        return self.cross_entropy(scores, target)

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        if t is None:
            return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        # all_gather returns detached copies, so put the local tensor (which still carries
        # gradients) back into its slot: the local queries are trained against the gathered
        # passages and the local passages against the gathered queries
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)

        return all_tensors

    def save(self, output_dir: str):
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu()
             for k, v in state_dict.items()})
        self.model.save_pretrained(output_dir, state_dict=state_dict)
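

# Minimal usage sketch (added; not part of the upstream module). It shows how BiEncoderModel
# might be wired to a decoder-only backbone; the checkpoint name and the toy batch below are
# placeholders, and left padding is assumed so that last-token pooling reads a real token.
if __name__ == "__main__":
    backbone_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(backbone_name)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    backbone = AutoModel.from_pretrained(backbone_name)

    model = BiEncoderModel(
        model=backbone,
        tokenizer=tokenizer,
        normlized=True,              # parameter name as spelled in __init__; enables cosine similarity
        negatives_cross_device=False,
        temperature=0.02,
        sub_batch_size=-1,
    )

    queries = tokenizer(["what is dense retrieval?"], padding=True, return_tensors="pt")
    passages = tokenizer(
        ["Dense retrieval encodes queries and passages into vectors.",
         "An unrelated passage that acts as an in-batch negative."],
        padding=True, return_tensors="pt",
    )
    # training mode + 'normal' message exercises the in-batch contrastive loss
    out = model(query=queries, passage=passages, messages=["normal"])
    print(out.loss, out.scores.shape)  # scores: (num_queries, num_passages)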