embed-bge-m3/FlagEmbedding/research/Matroyshka_reranker/finetune/self_distillation/modeling.py

import logging
import random
from dataclasses import dataclass
from typing import Dict, Optional, List, Union

import torch
from torch import nn, Tensor
from transformers import AutoTokenizer
from transformers.file_utils import ModelOutput

logger = logging.getLogger(__name__)


@dataclass
class RerankerOutput(ModelOutput):
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None


def last_logit_pool(logits: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # If every sequence attends at the last position, the batch is left-padded
    # and the final position already holds each sequence's last real token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return logits[:, -1]
    else:
        # Right padding: gather the logits at each sequence's last non-padded token.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = logits.shape[0]
        return torch.stack([logits[i, sequence_lengths[i]] for i in range(batch_size)], dim=0)
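

# A minimal illustration (not part of the original module) of what
# last_logit_pool returns for a right-padded batch: the logits at each
# sequence's final non-padded position.
#
#   logits = torch.randn(2, 5, 32)              # (batch, seq_len, vocab)
#   mask = torch.tensor([[1, 1, 1, 0, 0],
#                        [1, 1, 1, 1, 1]])
#   pooled = last_logit_pool(logits, mask)       # shape (2, 32)
#   assert torch.equal(pooled[0], logits[0, 2])  # last real token of row 0
#   assert torch.equal(pooled[1], logits[1, 4])  # last real token of row 1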


class BiEncoderModel(nn.Module):
    def __init__(self,
                 model: nn.Module = None,
                 tokenizer: AutoTokenizer = None,
                 compress_method: str = 'mean',
                 train_batch_size: int = 4,
                 cutoff_layers: List[int] = [2, 4],
                 compress_layers: List[int] = [6],
                 compress_ratios: List[int] = [2],
                 train_method: str = 'distill'
                 ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.config = self.model.config

        self.train_batch_size = train_batch_size
        self.compress_method = compress_method
        # Token id of "Yes"; its logit at the final position is used as the relevance score.
        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][-1]
        self.cutoff_layers = cutoff_layers
        self.compress_layers = compress_layers
        self.compress_ratios = compress_ratios
        self.train_method = train_method

    def gradient_checkpointing_enable(self, **kwargs):
        self.model.gradient_checkpointing_enable(**kwargs)

    def enable_input_require_grads(self, **kwargs):
        self.model.enable_input_require_grads(**kwargs)

    def encode(self, features, query_lengths, prompt_lengths):
        if features is None:
            return None
        outputs = self.model(input_ids=features['input_ids'],
                             attention_mask=features['attention_mask'],
                             position_ids=features['position_ids'] if 'position_ids' in features.keys() else None,
                             output_hidden_states=True,
                             # compress_layer=random.choice(self.compress_layers),
                             # Randomly disable each compression layer (0) or keep its index.
                             compress_layer=[random.choice([0, 1]) * i for i in self.compress_layers],
                             compress_ratio=random.choice(self.compress_ratios),
                             cutoff_layers=self.cutoff_layers,
                             query_lengths=query_lengths,
                             prompt_lengths=prompt_lengths)
        if self.config.layer_wise:
            # One pooled score tensor per exit layer requested in cutoff_layers.
            scores = []
            for i in range(len(outputs.logits)):
                logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
                scores.append(logits)
        else:
            logits = last_logit_pool(outputs.logits, outputs.attention_masks)
            scores = logits[:, self.yes_loc]
        return scores

    def encode_full(self, features, query_lengths, prompt_lengths):
        # Same as encode(), but with compress_ratio fixed to 1 (no compression),
        # so these scores can serve as self-distillation targets.
        if features is None:
            return None
        outputs = self.model(input_ids=features['input_ids'],
                             attention_mask=features['attention_mask'],
                             position_ids=features['position_ids'] if 'position_ids' in features.keys() else None,
                             output_hidden_states=True,
                             # compress_layer=random.choice(self.compress_layers),
                             compress_layer=[random.choice([0, 1]) * i for i in self.compress_layers],
                             compress_ratio=1,
                             cutoff_layers=self.cutoff_layers,
                             query_lengths=query_lengths,
                             prompt_lengths=prompt_lengths)
        if self.config.layer_wise:
            scores = []
            for i in range(len(outputs.logits)):
                logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
                scores.append(logits)
        else:
            logits = last_logit_pool(outputs.logits, outputs.attention_masks)
            scores = logits[:, self.yes_loc]
        return scores

    def forward(self,
                pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
                query_lengths: List[int] = None,
                prompt_lengths: List[int] = None,
                teacher_scores: List[float] = None):
        ranker_logits = self.encode(pair, query_lengths, prompt_lengths)  # (batch_size * num, dim)
        if '_layer' in self.train_method:
            # A second, uncompressed pass whose scores act as self-distillation targets.
            full_ranker_logits = self.encode_full(pair, query_lengths, prompt_lengths)
        else:
            full_ranker_logits = True  # placeholder; a full-ratio pass is only used by the '_layer' variants

        if self.training:
            if isinstance(ranker_logits, list):
                loss = 0
                if 'distill' in self.train_method:
                    # Teacher targets come from an external teacher model.
                    teacher_scores = torch.tensor(teacher_scores, device=ranker_logits[0].device)
                    teacher_scores = teacher_scores.view(self.train_batch_size, -1)
                    teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)
                else:
                    # Otherwise the deepest exit layer's own scores act as the teacher.
                    teacher_scores = ranker_logits[-1].view(self.train_batch_size, -1)
                    teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)

                teacher_targets_new = None
                # Iterate from the deepest exit layer to the shallowest.
                for idx, logits in enumerate(ranker_logits[::-1]):
                    grouped_logits = logits.view(self.train_batch_size, -1)
                    # The positive passage is always the first one in each group.
                    target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
                    loss += self.compute_loss(grouped_logits, target)

                    if 'distill' in self.train_method:
                        student_scores = logits.view(
                            self.train_batch_size,
                            -1
                        )
                        if full_ranker_logits is not None:
                            student_scores_full = full_ranker_logits[::-1][idx].view(
                                self.train_batch_size,
                                -1
                            )
                        else:
                            student_scores_full = None

                        if 'teacher' in self.train_method:
                            # Every exit layer is distilled from the external teacher.
                            loss += - torch.mean(
                                torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
                        elif idx == 0 and student_scores_full is not None:
                            # Deepest layer: distill from the teacher, then use its
                            # uncompressed scores as targets for shallower layers.
                            loss += - torch.mean(
                                torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
                            teacher_targets_new = torch.softmax(student_scores_full.detach(), dim=-1)
                            continue

                        if 'final_layer' in self.train_method:
                            # All shallower layers distill from the deepest layer's targets.
                            if teacher_targets_new is not None:
                                loss += - torch.mean(
                                    torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets_new, dim=-1))
                        elif 'last_layer' in self.train_method:
                            # Each layer distills from the previous (deeper) layer's targets.
                            if teacher_targets_new is not None:
                                loss += - torch.mean(
                                    torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets_new, dim=-1))
                            teacher_targets_new = torch.softmax(student_scores_full.detach(), dim=-1)
                        elif 'fix_layer' in self.train_method:
                            # Refresh the distillation targets every 8 exit layers.
                            if teacher_targets_new is not None:
                                loss += - torch.mean(
                                    torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets_new, dim=-1))
                            if idx % 8 == 0:
                                teacher_targets_new = torch.softmax(student_scores_full.detach(), dim=-1)
            else:
                grouped_logits = ranker_logits.view(self.train_batch_size, -1)
                target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
                loss = self.compute_loss(grouped_logits, target)

                if self.train_method == 'distill':
                    teacher_scores = torch.tensor(teacher_scores, device=ranker_logits.device)
                    teacher_scores = teacher_scores.view(self.train_batch_size, -1)
                    teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)
                    student_scores = ranker_logits.view(
                        self.train_batch_size,
                        -1
                    )
                    loss += - torch.mean(torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
        else:
            loss = None

        # print(loss)
        return RerankerOutput(
            loss=loss,
            scores=ranker_logits,
        )
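
    # Note on the distillation term (explanatory comment, not original code): the quantity
    #     -mean(sum(log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
    # is a soft-label cross-entropy over each group of candidate passages. It is
    # minimized when the student's softmax over the group matches the teacher's
    # (or the deeper, uncompressed layer's) score distribution, which is what ties
    # the compressed and shallow exits to their stronger counterparts.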

    def compute_loss(self, scores, target):
        return self.cross_entropy(scores, target)

    def save(self, output_dir: str):
        # self.model.save_pretrained(output_dir)
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()})
        self.model.save_pretrained(output_dir, state_dict=state_dict)

    def save_pretrained(self, **kwargs):
        return self.model.save_pretrained(**kwargs)
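

# A rough usage sketch (not from the original repository). It assumes a
# hypothetical `load_compress_model` helper that returns the customized backbone
# accepting the extra kwargs used above (compress_layer, compress_ratio,
# cutoff_layers, query_lengths, prompt_lengths); that helper, the checkpoint
# path, the train_method string, and the batch layout are illustrative
# assumptions only. The train_method value is simply assembled from the
# substrings that forward() checks for ('distill', '_layer', 'final_layer', ...).
#
#   backbone, tokenizer = load_compress_model('path/to/backbone-checkpoint')
#   model = BiEncoderModel(model=backbone,
#                          tokenizer=tokenizer,
#                          train_batch_size=4,
#                          cutoff_layers=[2, 4],
#                          compress_layers=[6],
#                          compress_ratios=[2],
#                          train_method='distill_final_layer')
#   output = model(pair=batch_features,            # tokenized query-passage pairs
#                  query_lengths=query_lengths,
#                  prompt_lengths=prompt_lengths,
#                  teacher_scores=teacher_scores)  # one float per pair
#   output.loss.backward()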