# embed-bge-m3/FlagEmbedding/research/Reinforced_IR/finetune/retriever/modeling.py
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import AutoModel, AutoTokenizer
from transformers.file_utils import ModelOutput

from FlagEmbedding.abc.finetune.embedder.AbsModeling import AbsEmbedderModel, EmbedderOutput
from FlagEmbedding.finetune.embedder.encoder_only.base.modeling import BiEncoderOnlyEmbedderModel

logger = logging.getLogger(__name__)


class BiIREmbedderModel(BiEncoderOnlyEmbedderModel):
    """Embedder class for encoder-only models.

    Args:
        base_model (AutoModel): The base model to train on.
        tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
        negatives_cross_device (bool, optional): If True, compute the negative loss across devices. Defaults to ``False``.
        temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
        answer_temperature (float, optional): Temperature for the query-answer contrastive loss.
            Defaults to ``None``, in which case ``0.05`` is used.
        sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split into sub-batches.
            Defaults to ``-1``.
        kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``"kl_div"``.
        sentence_pooling_method (str, optional): Pooling method to get the sentence embedding. Defaults to ``'cls'``.
        normalize_embeddings (bool, optional): If True, normalize the embedding vectors. Defaults to ``False``.
        normalize_answer (bool, optional): If True, add an MSE penalty that pushes answer embedding norms
            towards 1. Defaults to ``True``.
        training_type (str, optional): Which loss terms to use; the substrings ``'retrieval'``, ``'answer'`` and
            ``'passage'`` each enable the corresponding term. Defaults to ``'retrieval_answer'``.
    """

    TRANSFORMER_CLS = AutoModel

    def __init__(
        self,
        base_model: AutoModel,
        tokenizer: AutoTokenizer = None,
        negatives_cross_device: bool = False,
        temperature: float = 1.0,
        answer_temperature: float = None,
        sub_batch_size: int = -1,
        kd_loss_type: str = 'kl_div',
        sentence_pooling_method: str = 'cls',
        normalize_embeddings: bool = False,
        normalize_answer: bool = True,
        training_type: str = 'retrieval_answer'
    ):
        super().__init__(
            base_model,
            tokenizer=tokenizer,
            negatives_cross_device=negatives_cross_device,
            temperature=temperature,
            sub_batch_size=sub_batch_size,
            kd_loss_type=kd_loss_type,
            sentence_pooling_method=sentence_pooling_method,
            normalize_embeddings=normalize_embeddings
        )
        self.sentence_pooling_method = sentence_pooling_method
        self.normalize_embeddings = normalize_embeddings
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
        self.training_type = training_type
        self.answer_temperature = answer_temperature if answer_temperature is not None else 0.05
        self.normalize_answer = normalize_answer

    def forward(
        self,
        queries: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
        answers: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
        passages: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
        teacher_scores: Union[None, List[float]] = None,
        teacher_scores_answers: Union[None, List[float]] = None,
        no_in_batch_neg_flag: bool = False,
    ):
        """The computation performed at every call.

        Args:
            queries (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input queries. Defaults to ``None``.
            answers (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input answers paired with the
                queries; only encoded when ``training_type`` contains ``'answer'`` or ``'passage'``. Defaults to ``None``.
            passages (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input passages. Defaults to ``None``.
            teacher_scores (Union[None, List[float]], optional): Teacher scores for distillation. Defaults to ``None``.
            teacher_scores_answers (Union[None, List[float]], optional): Teacher scores for the answers; not used in
                this forward pass. Defaults to ``None``.
            no_in_batch_neg_flag (bool, optional): If True, use no in-batch negatives and no cross-device negatives.
                Defaults to ``False``.

        Returns:
            EmbedderOutput: Output of the forward call of the model.
        """
        q_reps = self.encode(queries)      # (batch_size, dim)
        p_reps = self.encode(passages)     # (batch_size * group_size, dim)
        a_reps = None
        if 'answer' in self.training_type or 'passage' in self.training_type:
            a_reps = self.encode(answers)  # (batch_size, dim)

        group_size = p_reps.size(0) // q_reps.size(0)

        if self.training:
            if teacher_scores is not None:
                teacher_scores = torch.tensor(teacher_scores, device=q_reps.device)
                # teacher_scores = (teacher_scores + 3) * 2
                # teacher_scores = - teacher_scores.reciprocal() / 0.2  # / self.temperature
                teacher_scores = teacher_scores.view(q_reps.size(0), -1).detach()  # (batch_size, group_size)
                teacher_targets = F.softmax(teacher_scores, dim=-1)  # (batch_size, group_size)
            else:
                teacher_targets = None

            if no_in_batch_neg_flag:
                compute_loss_func = self._compute_no_in_batch_neg_loss
            else:
                if self.negatives_cross_device:
                    compute_loss_func = self._compute_cross_device_neg_loss
                else:
                    compute_loss_func = self._compute_in_batch_neg_loss

            loss = 0
            if self.normalize_answer and a_reps is not None:
                # Regularize answer embeddings towards unit L2 norm.
                current_norm = torch.norm(a_reps, p=2, dim=1)
                mse_loss = F.mse_loss(current_norm, torch.ones_like(current_norm))
                loss += mse_loss

            if 'retrieval' in self.training_type:
                scores, q_loss = compute_loss_func(q_reps, p_reps, teacher_targets=teacher_targets)
                loss += q_loss

            if 'answer' in self.training_type:
                # The query-answer loss uses its own, typically sharper, temperature.
                tmp_temperature = self.temperature
                self.temperature = self.answer_temperature
                _, a_loss = compute_loss_func(q_reps, a_reps)
                self.temperature = tmp_temperature
                loss += a_loss
                # if self.process_rank == 0:
                #     print('The norm of queries:', torch.norm(q_reps, dim=1).tolist()[:10])
                #     print('The norm of answer:', torch.norm(a_reps, dim=1).tolist()[:10])
                #     print('query passage scores:', torch.matmul(q_reps, p_reps.t())[:10])
                #     print('answer passage scores:', torch.matmul(a_reps, p_reps.t())[:10])

            if 'passage' in self.training_type:
                _, p_loss = compute_loss_func(a_reps, p_reps)
                loss += 0.25 * p_loss
        else:
            loss = None

        return EmbedderOutput(
            loss=loss,
        )

    @staticmethod
    def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None):
        """Compute the distillation loss.

        Args:
            kd_loss_type (str): Type of knowledge distillation loss, supports "kl_div" and "m3_kd_loss".
            teacher_targets (torch.Tensor): Targets from the teacher model.
            student_scores (torch.Tensor): Scores of the student model.
            group_size (int, optional): Number of passages per query group; required for "m3_kd_loss".
                Defaults to ``None``.

        Raises:
            ValueError: Invalid kd_loss_type.

        Returns:
            torch.Tensor: A scalar of the computed distillation loss.
        """
        if kd_loss_type == 'kl_div':
            # teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
            # student_scores: (batch_size, group_size) / (world_size * batch_size, group_size)
            return - torch.mean(
                torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1)
            )
        elif kd_loss_type == 'm3_kd_loss':
            # teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
            # student_scores: (batch_size, batch_size * group_size) /
            #                 (world_size * batch_size, world_size * batch_size * group_size)
            labels = torch.arange(student_scores.size(0), device=student_scores.device, dtype=torch.long)
            labels = labels * group_size

            loss = 0
            mask = torch.zeros_like(student_scores)
            for i in range(group_size):
                temp_target = labels + i
                temp_scores = student_scores + mask
                # Weighted variant (disabled): scale the per-sample loss by the teacher target of position i.
                # temp_loss = F.cross_entropy(temp_scores, temp_target, reduction="none")  # (batch_size,)
                # loss += torch.mean(teacher_targets[:, i] * temp_loss)
                temp_loss = F.cross_entropy(temp_scores, temp_target, reduction="mean")
                loss += temp_loss
                # Mask out the target already used so it cannot be selected again in later iterations.
                mask = torch.scatter(mask, dim=-1, index=temp_target.unsqueeze(-1),
                                     value=torch.finfo(student_scores.dtype).min)
            return loss / group_size
        else:
            raise ValueError(f"Invalid kd_loss_type: {kd_loss_type}")

    def save(self, output_dir: str):
        """Save the model to the directory.

        Args:
            output_dir (str): Directory for saving the model.
        """
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()}
        )
        self.model.save_pretrained(output_dir, state_dict=state_dict)
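

# Minimal usage sketch (illustration only, not part of the original module). It assumes a
# BERT-style encoder checkpoint such as "BAAI/bge-base-en-v1.5" and hand-tokenized single-example
# batches; in practice this model is built and fed by the FlagEmbedding fine-tuning runner and
# its data collator instead.
#
#     from transformers import AutoModel, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
#     base_model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
#     model = BiIREmbedderModel(
#         base_model,
#         tokenizer=tokenizer,
#         temperature=0.02,
#         normalize_embeddings=True,
#         training_type='retrieval_answer',
#     )
#     queries = tokenizer(["what is dense retrieval?"], return_tensors="pt", padding=True)
#     passages = tokenizer(["Dense retrieval maps queries and documents into a shared vector space ..."],
#                          return_tensors="pt", padding=True)
#     answers = tokenizer(["Dense retrieval encodes text with a neural encoder ..."],
#                         return_tensors="pt", padding=True)
#     output = model(queries=queries, answers=answers, passages=passages)
#     print(output.loss)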