# embed-bge-m3/FlagEmbedding/research/Matroyshka_reranker/finetune/compensation/modeling.py
import copy
import logging
import os
import random
import sys
from dataclasses import dataclass
from typing import Dict, Optional, List, Union
import torch
from torch import nn, Tensor
from transformers import AutoTokenizer
from transformers.file_utils import ModelOutput
import torch.distributed as dist
logger = logging.getLogger(__name__)
@dataclass
class RerankerOutput(ModelOutput):
    """Output container for the reranker forward pass (extends HF ModelOutput)."""
    loss: Optional[Tensor] = None    # training loss; None at inference time
    scores: Optional[Tensor] = None  # relevance scores/logits for each input pair
def last_logit_pool(logits: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Select, for each sequence, the logits at its last attended position.

    Args:
        logits: (batch, seq_len, ...) per-position model logits.
        attention_mask: (batch, seq_len) 0/1 mask of attended tokens.

    Returns:
        (batch, ...) tensor of logits at each sequence's final real token.
    """
    # If every sequence attends to the last column, inputs are left-padded
    # and the last position is already the final real token for all rows.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return logits[:, -1]
    # Right padding: index of the last real token is (#attended - 1) per row.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    # Vectorized gather instead of a Python loop over the batch.
    batch_indices = torch.arange(logits.shape[0], device=logits.device)
    return logits[batch_indices, sequence_lengths]
def set_nested_attr(obj, attr, value):
    """Assign ``value`` to a dotted attribute path, e.g. 'a.b.c' -> obj.a.b.c."""
    *parents, leaf = attr.split('.')
    target = obj
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, value)
def get_nested_attr(obj, attr):
    """Resolve a dotted attribute path, e.g. 'a.b.c' -> obj.a.b.c."""
    current = obj
    for name in attr.split('.'):
        current = getattr(current, name)
    return current
class BiEncoderModel(nn.Module):
    """Cross-encoder reranker wrapping an LM with layer compression/cutoff.

    Each query-passage pair is scored by the logit of the 'Yes' token at the
    sequence's last position.  Training supports:
      * contrastive cross-entropy over groups of ``train_batch_size`` queries
        (the first pair in each group is the positive, target index 0), and
      * optional distillation from ``teacher_scores`` when
        ``train_method == 'distill'``.
    When ``tmp_model`` is given, only its weights are persisted by
    :meth:`save` / :meth:`save_pretrained`.
    """

    def __init__(self,
                 model: nn.Module = None,
                 tmp_model: Optional[nn.Module] = None,
                 tokenizer: AutoTokenizer = None,
                 compress_method: str = 'mean',
                 train_batch_size: int = 4,
                 cutoff_layers: Optional[List[int]] = None,
                 compress_layers: Optional[List[int]] = None,
                 compress_ratios: Optional[List[int]] = None,
                 train_method: str = 'distill'
                 ):
        super().__init__()
        self.model = model
        self.tmp_model = tmp_model
        if self.tmp_model is not None:
            # Dotted module paths of tmp_model's parameters ('.weight' suffix
            # stripped); used to locate matching sub-modules elsewhere.
            self.tmp_model_attrs = [name.replace('.weight', '')
                                    for name, _ in self.tmp_model.named_parameters()]
        self.tokenizer = tokenizer
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.pointCE = nn.BCEWithLogitsLoss(reduction='mean')
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.config = self.model.config
        self.train_batch_size = train_batch_size
        self.compress_method = compress_method
        # Vocabulary id of 'Yes'; its logit at the final position is the score.
        self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][-1]
        # None -> same defaults as before; avoids shared mutable default args.
        self.cutoff_layers = [2, 4] if cutoff_layers is None else cutoff_layers
        self.compress_layers = [6] if compress_layers is None else compress_layers
        self.compress_ratios = [2] if compress_ratios is None else compress_ratios
        self.train_method = train_method

    def gradient_checkpointing_enable(self, **kwargs):
        """Forward gradient-checkpointing enabling to the wrapped model."""
        self.model.gradient_checkpointing_enable(**kwargs)

    def enable_input_require_grads(self, **kwargs):
        """Forward input-require-grads enabling to the wrapped model."""
        self.model.enable_input_require_grads(**kwargs)

    def encode(self, features, query_lengths, prompt_lengths):
        """Run the model on tokenized pairs and extract relevance scores.

        A compression configuration is randomly sampled per call: each entry
        of ``self.compress_layers`` is independently dropped (multiplied by 0)
        or kept, and one ratio is drawn from ``self.compress_ratios``.

        Returns a list of pooled logit tensors (one per layer) when
        ``config.layer_wise`` is set, otherwise a 1-D tensor of 'Yes' logits.
        """
        if features is None:
            return None
        outputs = self.model(input_ids=features['input_ids'],
                             attention_mask=features['attention_mask'],
                             position_ids=features['position_ids'] if 'position_ids' in features.keys() else None,
                             output_hidden_states=True,
                             compress_layer=[random.choice([0, 1]) * i for i in self.compress_layers],
                             compress_ratio=random.choice(self.compress_ratios),
                             cutoff_layers=self.cutoff_layers,
                             query_lengths=query_lengths,
                             prompt_lengths=prompt_lengths)
        if self.config.layer_wise:
            # One pooled logit tensor per retained layer.
            scores = []
            for i in range(len(outputs.logits)):
                logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
                scores.append(logits)
        else:
            logits = last_logit_pool(outputs.logits, outputs.attention_masks)
            scores = logits[:, self.yes_loc]
        return scores

    def forward(self,
                pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
                query_lengths: List[int] = None,
                prompt_lengths: List[int] = None,
                teacher_scores: List[int] = None):
        """Score pairs and, in training mode, compute the loss.

        Pairs arrive flattened as ``train_batch_size`` groups; within each
        group the first pair is the positive (contrastive target index 0).
        """
        ranker_logits = self.encode(pair, query_lengths, prompt_lengths)  # (batch_size * num,)
        if not self.training:
            return RerankerOutput(loss=None, scores=ranker_logits)

        if isinstance(ranker_logits, list):
            # Layer-wise scores: sum the contrastive loss over layers
            # (iterated from the last layer backwards, as before).
            loss = 0
            for logits in ranker_logits[::-1]:
                grouped_logits = logits.view(self.train_batch_size, -1)
                target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
                loss += self.compute_loss(grouped_logits, target)
        else:
            grouped_logits = ranker_logits.view(self.train_batch_size, -1)
            target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
            loss = self.compute_loss(grouped_logits, target)
        if self.train_method == 'distill':
            # NOTE(review): this branch assumes ranker_logits is a Tensor;
            # with layer_wise enabled it is a list and `.device`/`.view`
            # would fail — confirm distillation is only used without
            # layer-wise scoring.
            teacher_scores = torch.tensor(teacher_scores, device=ranker_logits.device)
            teacher_scores = teacher_scores.view(self.train_batch_size, -1)
            teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)
            student_scores = ranker_logits.view(self.train_batch_size, -1)
            # Soft cross-entropy against the teacher distribution.
            loss += -torch.mean(torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
        return RerankerOutput(
            loss=loss,
            scores=ranker_logits,
        )

    def compute_loss(self, scores, target):
        """Mean cross-entropy between grouped scores and target class indices."""
        return self.cross_entropy(scores, target)

    def save(self, output_dir: str):
        """Persist weights: the full wrapped model, or only tmp_model if set."""
        if self.tmp_model is None:
            state_dict = self.model.state_dict()
            # Clone tensors to CPU so the checkpoint does not pin GPU memory.
            state_dict = type(state_dict)(
                {k: v.clone().cpu() for k, v in state_dict.items()})
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        else:
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.tmp_model.state_dict(),
                       os.path.join(output_dir, 'tmp_model.pth'))

    def save_pretrained(self, **kwargs):
        """HF-style save hook; routes to tmp_model when it exists.

        Expects ``output_dir`` in ``kwargs`` when tmp_model is set.
        """
        if self.tmp_model is None:
            return self.model.save_pretrained(**kwargs)
        os.makedirs(kwargs['output_dir'], exist_ok=True)
        torch.save(self.tmp_model.state_dict(),
                   os.path.join(kwargs['output_dir'], 'tmp_model.pth'))
        return True