embed-bge-m3/FlagEmbedding/research/Reinforced_IR/finetune/retriever/dataset.py

import os
import math
import random
import logging
import datasets
import numpy as np
import torch.distributed as dist
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import (
    PreTrainedTokenizer,
    DataCollatorWithPadding,
    TrainerCallback,
    TrainerState,
    TrainerControl
)

from FlagEmbedding.abc.finetune.embedder import AbsEmbedderDataArguments, AbsEmbedderTrainingArguments
from FlagEmbedding.abc.finetune.embedder import (
    AbsEmbedderTrainDataset,
    AbsEmbedderCollator,
    AbsEmbedderSameDatasetTrainDataset,
    AbsEmbedderSameDatasetCollator,
    EmbedderTrainerCallbackForDataRefresh
)

logger = logging.getLogger(__name__)


class IREmbedderTrainDataset(AbsEmbedderTrainDataset):
    """Training dataset for the IR embedder.

    Args:
        args (AbsEmbedderDataArguments): Data arguments.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
    """
    def __init__(
        self,
        args: AbsEmbedderDataArguments,
        tokenizer: PreTrainedTokenizer
    ):
        super().__init__(
            args,
            tokenizer,
        )
    def __getitem__(self, item):
        data = self.dataset[item]
        train_group_size = self.args.train_group_size

        query = data['query']
        answer = data.get('answer', None)
        if self.args.query_instruction_for_retrieval is not None:
            query = self.args.query_instruction_format.format(
                data['prompt'] if 'prompt' in data else self.args.query_instruction_for_retrieval,
                query
            )

        passages = []
        teacher_scores = []

        assert isinstance(data['pos'], list) and isinstance(data['neg'], list)

        # Sample one positive passage (optionally chunk-shuffled by the parent class).
        pos_idx = random.choice(list(range(len(data['pos']))))
        passages.append(self._shuffle_text(data['pos'][pos_idx]))

        # Sample train_group_size - 1 negatives; tile the index list when there
        # are not enough negatives, so that sampling can repeat indices.
        neg_all_idx = list(range(len(data['neg'])))
        if len(data['neg']) < train_group_size - 1:
            num = math.ceil((train_group_size - 1) / len(data['neg']))
            neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
        else:
            neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
        for neg_idx in neg_idxs:
            passages.append(data['neg'][neg_idx])

        if self.args.knowledge_distillation:
            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
            teacher_scores.append(data['pos_scores'][pos_idx])
            for neg_idx in neg_idxs:
                teacher_scores.append(data['neg_scores'][neg_idx])
            if not all(isinstance(score, (int, float)) for score in teacher_scores):
                raise ValueError("pos_scores and neg_scores must be numeric")
        else:
            teacher_scores = None

        if self.args.passage_instruction_for_retrieval is not None:
            passages = [
                self.args.passage_instruction_format.format(
                    self.args.passage_instruction_for_retrieval, p
                )
                for p in passages
            ]

        return query, answer, passages, teacher_scores
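

# A minimal standalone sketch (not part of the original module) of the
# negative-sampling rule used in __getitem__ above: when fewer negatives exist
# than train_group_size - 1, the index list is tiled so that sampling without
# replacement from the tiled list can return repeated indices.
# `_demo_sample_neg_idxs` is a hypothetical helper added for illustration.
def _demo_sample_neg_idxs(num_negs: int, train_group_size: int) -> list:
    neg_all_idx = list(range(num_negs))
    if num_negs < train_group_size - 1:
        # e.g. num_negs=3, train_group_size=8: tile the 3 indices x3 and draw 7
        num = math.ceil((train_group_size - 1) / num_negs)
        return random.sample(neg_all_idx * num, train_group_size - 1)
    return random.sample(neg_all_idx, train_group_size - 1)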


@dataclass
class IREmbedderCollator(AbsEmbedderCollator):
    """
    Collator for IREmbedderTrainDataset: tokenizes and pads queries, answers
    and passages, optionally splitting each field into sub-batches.
    """
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1
    def __call__(self, features):
        queries = [f[0] for f in features]
        answers = [f[1] for f in features]
        passages = [f[2] for f in features]
        teacher_scores = [f[3] for f in features]

        if teacher_scores[0] is None:
            teacher_scores = None
        elif isinstance(teacher_scores[0], list):
            teacher_scores = sum(teacher_scores, [])

        # Flatten nested lists into one flat list per field.
        if isinstance(queries[0], list):
            queries = sum(queries, [])
        if isinstance(answers[0], list):
            answers = sum(answers, [])
        if isinstance(passages[0], list):
            passages = sum(passages, [])

        queries_inputs = self.tokenizer(
            queries,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors=None
        )
        passages_inputs = self.tokenizer(
            passages,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors=None
        )
        # Only tokenize answers when they are actually present in the batch.
        if answers[0] is not None and answers[-1] is not None:
            answers_inputs = self.tokenizer(
                answers,
                truncation=True,
                max_length=self.query_max_len,
                return_tensors=None
            )
        else:
            answers_inputs = None

        if self.sub_batch_size is None or self.sub_batch_size <= 0:
            q_collated = self.tokenizer.pad(
                queries_inputs,
                padding=self.padding,
                max_length=self.query_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors
            )
            d_collated = self.tokenizer.pad(
                passages_inputs,
                padding=self.padding,
                max_length=self.passage_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors
            )
            if answers_inputs is None:
                a_collated = None
            else:
                a_collated = self.tokenizer.pad(
                    answers_inputs,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors
                )
        else:
            # Pad in chunks of sub_batch_size to cap peak memory.
            batch_size = self.sub_batch_size

            q_collated = []
            for i in range(0, len(queries_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(queries_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in queries_inputs.items():
                    sub_features[k] = v[start:end]
                q_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors
                ))

            d_collated = []
            for i in range(0, len(passages_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(passages_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in passages_inputs.items():
                    sub_features[k] = v[start:end]
                d_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.passage_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors
                ))

            if answers_inputs is None:
                a_collated = None
            else:
                a_collated = []
                for i in range(0, len(answers_inputs['attention_mask']), batch_size):
                    start = i
                    end = min(len(answers_inputs['attention_mask']), i + batch_size)
                    sub_features = {}
                    for k, v in answers_inputs.items():
                        sub_features[k] = v[start:end]
                    a_collated.append(self.tokenizer.pad(
                        sub_features,
                        padding=self.padding,
                        max_length=self.query_max_len,
                        pad_to_multiple_of=self.pad_to_multiple_of,
                        return_tensors=self.return_tensors
                    ))

        return {
            "queries": q_collated,
            "answers": a_collated,
            "passages": d_collated,
            "teacher_scores": teacher_scores,
            "no_in_batch_neg_flag": False
        }
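

# A minimal sketch (illustrative, not part of the original module) of the
# sub-batch padding scheme above: tokenized features are split into chunks of
# `sub_batch_size` and each chunk is padded independently, so padding length
# adapts per chunk instead of to the longest sequence in the whole batch.
# `_demo_chunk_features` is a hypothetical helper added for illustration.
def _demo_chunk_features(tokenized: dict, sub_batch_size: int) -> list:
    n = len(tokenized['attention_mask'])
    chunks = []
    for start in range(0, n, sub_batch_size):
        end = min(n, start + sub_batch_size)
        chunks.append({k: v[start:end] for k, v in tokenized.items()})
    return chunks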


class IREmbedderSameDatasetTrainDataset(AbsEmbedderSameDatasetTrainDataset):
    """Training dataset that samples each batch from a single sub-dataset.

    Args:
        args (AbsEmbedderDataArguments): Data arguments.
        default_batch_size (int): The default batch size for training.
        seed (int): Random seed.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
        process_index (int, optional): Current process index. Defaults to 0.
        num_processes (int, optional): Total number of processes. Defaults to 1.
    """
    def __init__(
        self,
        args: AbsEmbedderDataArguments,
        default_batch_size: int,
        seed: int,
        tokenizer: PreTrainedTokenizer,
        process_index: int = 0,
        num_processes: int = 1
    ):
        super().__init__(
            args,
            default_batch_size,
            seed,
            tokenizer,
            process_index,
            num_processes
        )

    def _shuffle_answer(self, text):
        """Shuffle the input text by splitting it into three roughly equal
        chunks and joining the chunks in random order.

        Args:
            text (str): Input text.

        Returns:
            str: Shuffled text.
        """
        split_text = []
        chunk_size = len(text) // 3 + 1
        for i in range(0, len(text), chunk_size):
            split_text.append(text[i:i + chunk_size])
        random.shuffle(split_text)
        return " ".join(split_text)

    def __getitem__(self, _):
        batch_indices, no_in_batch_neg_flag = self.batch_datas[self.step]  # extend here
        # Each process takes its own contiguous slice of the global batch.
        cur_batch_size = int(len(batch_indices) / self.num_processes)
        batch_indices = batch_indices[self.process_index * cur_batch_size: (self.process_index + 1) * cur_batch_size]
        batch_data = self.dataset[batch_indices]
        self.step += 1
        queries, answers, passages, teacher_scores, teacher_scores_answers = self._create_batch_data(batch_raw_data=batch_data)
        return queries, answers, passages, teacher_scores, teacher_scores_answers, no_in_batch_neg_flag
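
    # Slicing example (hypothetical numbers): with 8 global batch indices and
    # num_processes = 2, cur_batch_size = int(8 / 2) = 4; process 0 takes
    # indices [0:4] and process 1 takes indices [4:8], so each rank builds a
    # disjoint sub-batch of the same global batch.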

    def _create_batch_data(self, batch_raw_data):
        """Create a complete batch of data with queries, answers, documents and teacher scores.

        Args:
            batch_raw_data (datasets.Dataset): One batch of raw data.

        Returns:
            List[str]: Queries with instruction format.
            List[str]: Answers with instruction format, or ``None`` placeholders.
            List[str]: Documents with instruction format.
            List[float]: Teacher scores for model distillation.
            List[float]: Teacher scores for answer distillation (currently always empty).
        """
        queries, answers, passages, teacher_scores, teacher_scores_answers = [], [], [], [], []

        train_group_size, data_type = self._get_train_group_size(batch_raw_data)

        for i in range(len(batch_raw_data['query'])):
            if data_type is not None:
                assert batch_raw_data['type'][i] == data_type, "Data type is not consistent in the same batch"

            queries.append(
                self.args.query_instruction_format.format(
                    batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
                    batch_raw_data['query'][i]
                )
            )

            if 'answer' in batch_raw_data.keys():
                answers.append(
                    self.args.query_instruction_format.format(
                        batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
                        batch_raw_data['answer'][i]
                    )
                )
                # answers[-1] = self._shuffle_answer(answers[-1])
            else:
                answers.append(None)

            tmp_passages = []
            pos_idx = random.choice(list(range(len(batch_raw_data['pos'][i]))))
            pos = self._shuffle_text(batch_raw_data['pos'][i][pos_idx])
            # pos = self._shuffle_answer(batch_raw_data['answer'][i])
            # pos = batch_raw_data['answer'][i]
            tmp_passages.append(pos)

            if train_group_size > 1:
                # Sample train_group_size - 1 negatives; tile the index list
                # when there are not enough negatives to sample from.
                neg_all_idx = list(range(len(batch_raw_data['neg'][i])))
                if len(batch_raw_data['neg'][i]) < train_group_size - 1:
                    num = math.ceil((train_group_size - 1) / len(batch_raw_data['neg'][i]))
                    neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
                else:
                    neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
                if self.args.knowledge_distillation:
                    # Order the sampled negatives by descending teacher score.
                    tmp_scores = [batch_raw_data['neg_scores'][i][neg_idx] for neg_idx in neg_idxs]
                    tmp_data = sorted([(x, y) for x, y in zip(neg_idxs, tmp_scores)], reverse=True, key=lambda x: x[1])
                    neg_idxs = [x[0] for x in tmp_data]
                for neg_idx in neg_idxs:
                    tmp_passages.append(batch_raw_data['neg'][i][neg_idx])
                    # answers.append(batch_raw_data['neg'][i][neg_idx])

            if self.args.knowledge_distillation:
                if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
                    # Cap the positive teacher score: if it already reaches the
                    # best negative score, pull it 20% of the way toward it.
                    if batch_raw_data['pos_scores'][i][pos_idx] < max(batch_raw_data['neg_scores'][i]):
                        teacher_scores.append(batch_raw_data['pos_scores'][i][pos_idx])
                    else:
                        teacher_scores.append(
                            batch_raw_data['pos_scores'][i][pos_idx] +
                            (max(batch_raw_data['neg_scores'][i]) - batch_raw_data['pos_scores'][i][pos_idx]) * 0.2
                        )
                for neg_idx in neg_idxs:
                    if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
                        teacher_scores.append(batch_raw_data['neg_scores'][i][neg_idx])
            else:
                teacher_scores = None

            ### add answer knowledge distillation
            if self.args.answer_inbatch:
                if train_group_size > 1:
                    neg_all_idx = list(range(len(batch_raw_data['neg_answer'][i])))
                    if len(batch_raw_data['neg_answer'][i]) < train_group_size - 1:
                        num = math.ceil((train_group_size - 1) / len(batch_raw_data['neg_answer'][i]))
                        neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
                    else:
                        neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
                    for neg_idx in neg_idxs:
                        answers.append(batch_raw_data['neg_answer'][i][neg_idx])
            # if self.args.knowledge_distillation:
            #     if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
            #         teacher_scores_answers.append(batch_raw_data['pos_scores'][i][pos_idx])
            #     for neg_idx in neg_idxs:
            #         if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
            #             teacher_scores_answers.append(batch_raw_data['neg_scores'][i][neg_idx])
            #     else:
            #         teacher_scores_answers = None

            if data_type is not None and data_type in ['symmetric_sts', 'symmetric_clustering']:
                tmp_passages = [
                    self.args.query_instruction_format.format(
                        batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
                        p
                    ) for p in tmp_passages
                ]
            else:
                if self.args.passage_instruction_for_retrieval is not None:
                    tmp_passages = [
                        self.args.passage_instruction_format.format(
                            self.args.passage_instruction_for_retrieval, p
                        ) for p in tmp_passages
                    ]

            passages.extend(tmp_passages)

        if teacher_scores is not None:
            if len(teacher_scores) > 0 and len(passages) > 0:
                assert len(teacher_scores) == len(passages)

        return queries, answers, passages, teacher_scores, teacher_scores_answers
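

# A minimal sketch (illustrative, not part of the original module) of the
# positive-score capping rule in _create_batch_data above: when the positive's
# teacher score is not below the best negative score, it is pulled 20% of the
# way toward that best negative score, shrinking an overly large margin in the
# teacher signal. `_demo_cap_pos_score` is a hypothetical helper.
def _demo_cap_pos_score(pos_score: float, neg_scores: list) -> float:
    max_neg = max(neg_scores)
    if pos_score < max_neg:
        return pos_score
    # e.g. pos_score=0.9, max_neg=0.5 -> 0.9 + (0.5 - 0.9) * 0.2 = 0.82
    return pos_score + (max_neg - pos_score) * 0.2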


@dataclass
class IREmbedderSameDatasetCollator(AbsEmbedderSameDatasetCollator):
    """
    EmbedCollator for SameDataset.
    Note that after using this collator, the training_args should be set as:

    ``training_args.per_device_train_batch_size = 1``

    ``training_args.dataloader_num_workers = 0  # avoid multi-processing``
    """
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1
    def __call__(self, features):
        # The dataset yields one fully-formed batch, so features has length 1.
        queries = features[0][0]
        answers = features[0][1]
        passages = features[0][2]
        teacher_scores = features[0][3]
        teacher_scores_answers = features[0][4]
        no_in_batch_neg_flag = features[0][5]

        # Tokenize without padding; padding happens below, either over the
        # full batch or per sub-batch.
        queries_inputs = self.tokenizer(
            queries,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors=None
        )
        if answers[0] is not None:
            answers_inputs = self.tokenizer(
                answers,
                truncation=True,
                max_length=self.query_max_len,
                return_tensors=None
            )
        else:
            answers_inputs = None
        passages_inputs = self.tokenizer(
            passages,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors=None
        )

        if self.sub_batch_size is None or self.sub_batch_size <= 0:
            q_collated = self.tokenizer.pad(
                queries_inputs,
                padding=self.padding,
                max_length=self.query_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )
            d_collated = self.tokenizer.pad(
                passages_inputs,
                padding=self.padding,
                max_length=self.passage_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )
            if answers_inputs is None:
                a_collated = None
            else:
                a_collated = self.tokenizer.pad(
                    answers_inputs,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors,
                )
        else:
            batch_size = self.sub_batch_size

            q_collated = []
            for i in range(0, len(queries_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(queries_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in queries_inputs.items():
                    sub_features[k] = v[start:end]
                q_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors,
                ))

            d_collated = []
            for i in range(0, len(passages_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(passages_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in passages_inputs.items():
                    sub_features[k] = v[start:end]
                d_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.passage_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors,
                ))

            if answers_inputs is None:
                a_collated = None
            else:
                a_collated = []
                for i in range(0, len(answers_inputs['attention_mask']), batch_size):
                    start = i
                    end = min(len(answers_inputs['attention_mask']), i + batch_size)
                    sub_features = {}
                    for k, v in answers_inputs.items():
                        sub_features[k] = v[start:end]
                    a_collated.append(self.tokenizer.pad(
                        sub_features,
                        padding=self.padding,
                        max_length=self.query_max_len,
                        pad_to_multiple_of=self.pad_to_multiple_of,
                        return_tensors=self.return_tensors,
                    ))

        # An empty teacher-score list means no distillation signal for this batch.
        if isinstance(teacher_scores, list) and len(teacher_scores) == 0:
            teacher_scores = None

        return {
            "queries": q_collated,
            "answers": a_collated,
            "passages": d_collated,
            "teacher_scores": teacher_scores,
            "teacher_scores_answers": teacher_scores_answers,
            "no_in_batch_neg_flag": no_in_batch_neg_flag
        }
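

# A minimal usage sketch (the model name and constructor wiring are assumptions
# for illustration, not part of the original module). Per the class docstring,
# this collator receives an entire pre-batched sample from the dataset, so the
# Trainer must use per_device_train_batch_size = 1 and dataloader_num_workers = 0.
def _demo_same_dataset_collator_usage():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")  # hypothetical choice
    collator = IREmbedderSameDatasetCollator(
        tokenizer=tokenizer,
        query_max_len=32,
        passage_max_len=128,
        sub_batch_size=-1,  # <= 0 pads the whole batch at once
    )
    # training_args.per_device_train_batch_size = 1
    # training_args.dataloader_num_workers = 0  # avoid multi-processing
    return collator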