166 lines
6.4 KiB
Python
166 lines
6.4 KiB
Python
import logging
|
||
import sys
|
||
from dataclasses import dataclass
|
||
from typing import Dict, Optional, List, Union
|
||
|
||
import torch
|
||
import torch.distributed as dist
|
||
from torch import nn, Tensor
|
||
from tqdm import trange, tqdm
|
||
from transformers import AutoModel, AutoTokenizer
|
||
from transformers.file_utils import ModelOutput
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class EncoderOutput(ModelOutput):
|
||
q_reps: Optional[Tensor] = None
|
||
p_reps: Optional[Tensor] = None
|
||
loss: Optional[Tensor] = None
|
||
scores: Optional[Tensor] = None
|
||
|
||
|
||
class BiEncoderModel(nn.Module):
|
||
TRANSFORMER_CLS = AutoModel
|
||
|
||
def __init__(self,
|
||
model: AutoModel = None,
|
||
tokenizer: AutoTokenizer = None,
|
||
normlized: bool = False,
|
||
negatives_cross_device: bool = False,
|
||
temperature: float = 1.0,
|
||
sub_batch_size: int = -1
|
||
):
|
||
super().__init__()
|
||
self.model = model
|
||
self.config = model.config
|
||
self.tokenizer = tokenizer
|
||
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
|
||
|
||
self.normlized = normlized
|
||
self.temperature = temperature
|
||
if not normlized:
|
||
self.temperature = 1.0
|
||
logger.info("reset temperature = 1.0 due to using inner product to compute similarity")
|
||
|
||
self.negatives_cross_device = negatives_cross_device
|
||
if self.negatives_cross_device:
|
||
if not dist.is_initialized():
|
||
raise ValueError('Distributed training has not been initialized for representation all gather.')
|
||
# logger.info("Run in a single GPU, set negatives_cross_device=False")
|
||
# self.negatives_cross_device = False
|
||
# else:
|
||
self.process_rank = dist.get_rank()
|
||
self.world_size = dist.get_world_size()
|
||
|
||
self.sub_batch_size = sub_batch_size
|
||
|
||
def gradient_checkpointing_enable(self, **kwargs):
|
||
self.model.gradient_checkpointing_enable(**kwargs)
|
||
|
||
def enable_input_require_grads(self, **kwargs):
|
||
self.model.enable_input_require_grads(**kwargs)
|
||
|
||
def encode(self, features):
|
||
# input('continue?')
|
||
if features is None:
|
||
return None
|
||
if not isinstance(features, list):
|
||
if self.sub_batch_size is not None and self.sub_batch_size > 0:
|
||
all_p_reps = []
|
||
for i in range(0, len(features['attention_mask']), self.sub_batch_size):
|
||
end_inx = min(i + self.sub_batch_size, len(features['attention_mask']))
|
||
sub_features = {}
|
||
for k, v in features.items():
|
||
sub_features[k] = v[i:end_inx]
|
||
psg_out = self.model(**sub_features, return_dict=True, output_hidden_states=True)
|
||
### modify
|
||
p_reps = psg_out.hidden_states[-1][:, -8:, :]
|
||
p_reps = torch.mean(p_reps, dim=1)
|
||
all_p_reps.append(p_reps)
|
||
all_p_reps = torch.cat(all_p_reps, 0).contiguous()
|
||
if self.normlized:
|
||
all_p_reps = torch.nn.functional.normalize(all_p_reps, dim=-1)
|
||
return all_p_reps.contiguous()
|
||
else:
|
||
psg_out = self.model(**features, return_dict=True, output_hidden_states=True)
|
||
p_reps = psg_out.hidden_states[-1][:, -8:, :]
|
||
p_reps = torch.mean(p_reps, dim=1)
|
||
if self.normlized:
|
||
p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
|
||
return p_reps.contiguous()
|
||
else:
|
||
all_p_reps = []
|
||
for sub_features in features:
|
||
psg_out = self.model(**sub_features, return_dict=True, output_hidden_states=True)
|
||
### modify
|
||
p_reps = psg_out.hidden_states[-1][:, -8:, :]
|
||
p_reps = torch.mean(p_reps, dim=1)
|
||
all_p_reps.append(p_reps)
|
||
all_p_reps = torch.cat(all_p_reps, 0).contiguous()
|
||
if self.normlized:
|
||
all_p_reps = torch.nn.functional.normalize(all_p_reps, dim=-1)
|
||
return all_p_reps.contiguous()
|
||
|
||
|
||
def compute_similarity(self, q_reps, p_reps):
|
||
if len(p_reps.size()) == 2:
|
||
return torch.matmul(q_reps, p_reps.transpose(0, 1))
|
||
return torch.matmul(q_reps, p_reps.transpose(-2, -1))
|
||
|
||
def forward(self, query: Union[Dict[str, Tensor], List[Dict[str, Tensor]]]= None, passage: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None):
|
||
# torch.cuda.empty_cache()
|
||
# if self.process_rank == 1:
|
||
# print(query)
|
||
q_reps = self.encode(query) # (batch_size, dim)
|
||
p_reps = self.encode(passage) # (batch_size * num, dim)
|
||
|
||
if self.training:
|
||
if self.negatives_cross_device:
|
||
q_reps = self._dist_gather_tensor(q_reps)
|
||
p_reps = self._dist_gather_tensor(p_reps)
|
||
|
||
scores = self.compute_similarity(q_reps, p_reps)
|
||
scores = scores / self.temperature
|
||
scores = scores.view(q_reps.size(0), -1)
|
||
|
||
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
|
||
target = target * (p_reps.size(0) // q_reps.size(0))
|
||
loss = self.compute_loss(scores, target) # 同批内除了正样本以外的均为负样本
|
||
|
||
else:
|
||
scores = self.compute_similarity(q_reps, p_reps)
|
||
loss = None
|
||
|
||
# print(loss)
|
||
return EncoderOutput(
|
||
loss=loss,
|
||
scores=scores,
|
||
q_reps=q_reps,
|
||
p_reps=p_reps,
|
||
)
|
||
|
||
def compute_loss(self, scores, target):
|
||
return self.cross_entropy(scores, target)
|
||
|
||
def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
|
||
if t is None:
|
||
return None
|
||
t = t.contiguous()
|
||
|
||
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
|
||
dist.all_gather(all_tensors, t)
|
||
all_tensors[self.process_rank] = t # 给当前进程的q和doc加上梯度,当前的q对其他的d,更新;当前的d对其他的q,更新
|
||
all_tensors = torch.cat(all_tensors, dim=0)
|
||
|
||
return all_tensors
|
||
|
||
def save(self, output_dir: str):
|
||
state_dict = self.model.state_dict()
|
||
state_dict = type(state_dict)(
|
||
{k: v.clone().cpu()
|
||
for k,
|
||
v in state_dict.items()})
|
||
self.model.save_pretrained(output_dir, state_dict=state_dict)
|