import os
import math
import random
import logging
import datasets
import numpy as np
import torch.distributed as dist
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import (
    PreTrainedTokenizer,
    DataCollatorWithPadding,
    TrainerCallback,
    TrainerState,
    TrainerControl
)

from FlagEmbedding.abc.finetune.embedder import AbsEmbedderDataArguments, AbsEmbedderTrainingArguments
from FlagEmbedding.abc.finetune.embedder import AbsEmbedderTrainDataset, AbsEmbedderCollator, AbsEmbedderSameDatasetTrainDataset, AbsEmbedderSameDatasetCollator, EmbedderTrainerCallbackForDataRefresh

logger = logging.getLogger(__name__)


class IREmbedderTrainDataset(AbsEmbedderTrainDataset):
    """Training dataset for the IR embedder.

    Args:
        args (AbsEmbedderDataArguments): Data arguments.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
    """
    def __init__(
        self,
        args: AbsEmbedderDataArguments,
        tokenizer: PreTrainedTokenizer
    ):
        super().__init__(
            args,
            tokenizer,
        )

    def __getitem__(self, item):
        data = self.dataset[item]
        train_group_size = self.args.train_group_size

        query = data['query']
        answer = data.get('answer', None)
        if self.args.query_instruction_for_retrieval is not None:
            query = self.args.query_instruction_format.format(
                data['prompt'] if 'prompt' in data else self.args.query_instruction_for_retrieval,
                query
            )

        passages = []
        teacher_scores = []

        assert isinstance(data['pos'], list) and isinstance(data['neg'], list)

        # Sample one positive passage and (train_group_size - 1) negatives.
        pos_idx = random.choice(list(range(len(data['pos']))))
        passages.append(self._shuffle_text(data['pos'][pos_idx]))

        neg_all_idx = list(range(len(data['neg'])))
        if len(data['neg']) < train_group_size - 1:
            # Not enough negatives: tile the index list so we can sample with replacement.
            num = math.ceil((train_group_size - 1) / len(data['neg']))
            neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
        else:
            neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
        for neg_idx in neg_idxs:
            passages.append(data['neg'][neg_idx])

        if self.args.knowledge_distillation:
            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
            teacher_scores.append(data['pos_scores'][pos_idx])
            for neg_idx in neg_idxs:
                teacher_scores.append(data['neg_scores'][neg_idx])
            if not all(isinstance(score, (int, float)) for score in teacher_scores):
                raise ValueError("pos_scores and neg_scores must be numeric")
        else:
            teacher_scores = None

        if self.args.passage_instruction_for_retrieval is not None:
            passages = [
                self.args.passage_instruction_format.format(
                    self.args.passage_instruction_for_retrieval, p
                )
                for p in passages
            ]

        return query, answer, passages, teacher_scores

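# A minimal sketch (not part of the original code) of the training record
# shape that __getitem__ above expects. The field names follow the accesses
# in the method; the concrete values here are hypothetical:
#
#   example_record = {
#       "query": "what is dense retrieval?",
#       "answer": "Dense retrieval encodes queries and passages ...",   # optional
#       "prompt": "Represent this question for retrieving passages:",   # optional
#       "pos": ["relevant passage text ...", "..."],
#       "neg": ["hard negative passage text ...", "..."],
#       "pos_scores": [0.95],          # required only when knowledge_distillation is on
#       "neg_scores": [0.31, 0.22],    # one score per negative
#   }
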
@dataclass
class IREmbedderCollator(AbsEmbedderCollator):
    """
    Collator for IREmbedderTrainDataset.
    """
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1

    def __call__(self, features):
        queries = [f[0] for f in features]
        answers = [f[1] for f in features]
        passages = [f[2] for f in features]
        teacher_scores = [f[3] for f in features]
        if teacher_scores[0] is None:
            teacher_scores = None
        elif isinstance(teacher_scores[0], list):
            teacher_scores = sum(teacher_scores, [])

        # Flatten nested lists produced by multi-example features.
        if isinstance(queries[0], list):
            queries = sum(queries, [])
        if isinstance(answers[0], list):
            answers = sum(answers, [])
        if isinstance(passages[0], list):
            passages = sum(passages, [])

        queries_inputs = self.tokenizer(
            queries,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors=None
        )
        passages_inputs = self.tokenizer(
            passages,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors=None
        )
        # Tokenize answers only when the batch actually contains them.
        if answers[0] is not None and answers[-1] is not None:
            answers_inputs = self.tokenizer(
                answers,
                truncation=True,
                max_length=self.query_max_len,
                return_tensors=None
            )
        else:
            answers_inputs = None

        if self.sub_batch_size is None or self.sub_batch_size <= 0:
            q_collated = self.tokenizer.pad(
                queries_inputs,
                padding=self.padding,
                max_length=self.query_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors
            )
            d_collated = self.tokenizer.pad(
                passages_inputs,
                padding=self.padding,
                max_length=self.passage_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors
            )
            if answers_inputs is None:
                a_collated = None
            else:
                a_collated = self.tokenizer.pad(
                    answers_inputs,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors
                )
        else:
            # Pad in fixed-size sub-batches so large groups can be encoded
            # chunk by chunk.
            batch_size = self.sub_batch_size

            q_collated = []
            for i in range(0, len(queries_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(queries_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in queries_inputs.items():
                    sub_features[k] = v[start:end]
                q_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors
                ))

            d_collated = []
            for i in range(0, len(passages_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(passages_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in passages_inputs.items():
                    sub_features[k] = v[start:end]
                d_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.passage_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors
                ))

            if answers_inputs is None:
                a_collated = None
            else:
                a_collated = []
                for i in range(0, len(answers_inputs['attention_mask']), batch_size):
                    start = i
                    end = min(len(answers_inputs['attention_mask']), i + batch_size)
                    sub_features = {}
                    for k, v in answers_inputs.items():
                        sub_features[k] = v[start:end]
                    a_collated.append(self.tokenizer.pad(
                        sub_features,
                        padding=self.padding,
                        max_length=self.query_max_len,
                        pad_to_multiple_of=self.pad_to_multiple_of,
                        return_tensors=self.return_tensors
                    ))

        return {
            "queries": q_collated,
            "answers": a_collated,
            "passages": d_collated,
            "teacher_scores": teacher_scores,
            "no_in_batch_neg_flag": False
        }

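# Minimal usage sketch (illustrative only, not the project's training entry
# point; `data_args` and `tokenizer` are assumed to exist, and the keyword
# values are hypothetical): pairing the dataset and collator above with a
# standard PyTorch DataLoader.
#
#   from torch.utils.data import DataLoader
#
#   dataset = IREmbedderTrainDataset(data_args, tokenizer)
#   collator = IREmbedderCollator(
#       tokenizer=tokenizer, query_max_len=32, passage_max_len=128
#   )
#   loader = DataLoader(dataset, batch_size=8, collate_fn=collator)
#   batch = next(iter(loader))
#   # batch is a dict with "queries", "answers", "passages", "teacher_scores".
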
class IREmbedderSameDatasetTrainDataset(AbsEmbedderSameDatasetTrainDataset):
    """Training dataset that samples each batch from a single dataset.

    Args:
        args (AbsEmbedderDataArguments): Data arguments.
        default_batch_size (int): The default batch size for training.
        seed (int): Random seed.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
        process_index (int, optional): Current process index. Defaults to 0.
        num_processes (int, optional): Total number of processes. Defaults to 1.
    """
    def __init__(
        self,
        args: AbsEmbedderDataArguments,
        default_batch_size: int,
        seed: int,
        tokenizer: PreTrainedTokenizer,
        process_index: int = 0,
        num_processes: int = 1
    ):
        super().__init__(
            args,
            default_batch_size,
            seed,
            tokenizer,
            process_index,
            num_processes
        )

    def _shuffle_answer(self, text):
        """Shuffle the input text by splitting it into up to three chunks
        and rejoining them in random order.

        Args:
            text (str): Input text.

        Returns:
            str: Shuffled text.
        """
        split_text = []
        chunk_size = len(text) // 3 + 1
        for i in range(0, len(text), chunk_size):
            split_text.append(text[i:i + chunk_size])
        random.shuffle(split_text)
        return " ".join(split_text)

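    # Worked example for _shuffle_answer (hypothetical input): for
    # text = "abcdefgh", chunk_size = 8 // 3 + 1 = 3, giving the chunks
    # ["abc", "def", "gh"]; one possible output is "def gh abc".
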
    def __getitem__(self, _):
        batch_indices, no_in_batch_neg_flag = self.batch_datas[self.step]  # extend here
        # Keep only this process's contiguous slice of the global batch.
        cur_batch_size = int(len(batch_indices) / self.num_processes)
        batch_indices = batch_indices[self.process_index * cur_batch_size: (self.process_index + 1) * cur_batch_size]
        batch_data = self.dataset[batch_indices]
        self.step += 1
        queries, answers, passages, teacher_scores, teacher_scores_answers = self._create_batch_data(batch_raw_data=batch_data)
        return queries, answers, passages, teacher_scores, teacher_scores_answers, no_in_batch_neg_flag

    def _create_batch_data(self, batch_raw_data):
        """Create a complete batch of data with queries, answers, documents and teacher scores.

        Args:
            batch_raw_data (datasets.Dataset): One batch of raw data.

        Returns:
            List[str]: Queries with instruction format.
            List[str]: Answers with instruction format (None entries when absent).
            List[str]: Documents with instruction format.
            List[float]: Teacher scores for model distillation.
            List[float]: Teacher scores for the answers.
        """
        queries, answers, passages, teacher_scores, teacher_scores_answers = [], [], [], [], []

        train_group_size, data_type = self._get_train_group_size(batch_raw_data)

        for i in range(len(batch_raw_data['query'])):
            if data_type is not None:
                assert batch_raw_data['type'][i] == data_type, "Data type is not consistent in the same batch"

            queries.append(
                self.args.query_instruction_format.format(
                    batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
                    batch_raw_data['query'][i]
                )
            )
            if 'answer' in batch_raw_data.keys():
                answers.append(
                    self.args.query_instruction_format.format(
                        batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
                        batch_raw_data['answer'][i]
                    )
                )
                # answers[-1] = self._shuffle_answer(answers[-1])
            else:
                answers.append(None)

            tmp_passages = []
            pos_idx = random.choice(list(range(len(batch_raw_data['pos'][i]))))
            pos = self._shuffle_text(batch_raw_data['pos'][i][pos_idx])
            # pos = self._shuffle_answer(batch_raw_data['answer'][i])
            # pos = batch_raw_data['answer'][i]
            tmp_passages.append(pos)

            if train_group_size == 1:
                pass
            else:
                neg_all_idx = list(range(len(batch_raw_data['neg'][i])))
                if len(batch_raw_data['neg'][i]) < train_group_size - 1:
                    # Not enough negatives: tile the index list so we can sample with replacement.
                    num = math.ceil((train_group_size - 1) / len(batch_raw_data['neg'][i]))
                    neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
                else:
                    neg_idxs = random.sample(neg_all_idx, train_group_size - 1)

                if self.args.knowledge_distillation:
                    # Order the sampled negatives by descending teacher score (hardest first).
                    tmp_scores = [batch_raw_data['neg_scores'][i][neg_idx] for neg_idx in neg_idxs]
                    tmp_data = sorted([(x, y) for x, y in zip(neg_idxs, tmp_scores)], reverse=True, key=lambda x: x[1])
                    neg_idxs = [x[0] for x in tmp_data]

                for neg_idx in neg_idxs:
                    tmp_passages.append(batch_raw_data['neg'][i][neg_idx])
                    # answers.append(batch_raw_data['neg'][i][neg_idx])

            if self.args.knowledge_distillation:
                if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
                    if batch_raw_data['pos_scores'][i][pos_idx] < max(batch_raw_data['neg_scores'][i]):
                        teacher_scores.append(batch_raw_data['pos_scores'][i][pos_idx])
                    else:
                        # Pull an over-confident positive score 20% of the way toward the
                        # best negative score (e.g. pos 0.98, max neg 0.90 -> 0.964).
                        teacher_scores.append(
                            batch_raw_data['pos_scores'][i][pos_idx] +
                            (max(batch_raw_data['neg_scores'][i]) - batch_raw_data['pos_scores'][i][pos_idx]) * 0.2
                        )
                for neg_idx in neg_idxs:
                    if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
                        teacher_scores.append(batch_raw_data['neg_scores'][i][neg_idx])
            else:
                teacher_scores = None

            # Optionally add in-batch negative answers for answer-side training.
            if self.args.answer_inbatch:
                if train_group_size == 1:
                    pass
                else:
                    neg_all_idx = list(range(len(batch_raw_data['neg_answer'][i])))
                    if len(batch_raw_data['neg_answer'][i]) < train_group_size - 1:
                        num = math.ceil((train_group_size - 1) / len(batch_raw_data['neg_answer'][i]))
                        neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
                    else:
                        neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
                    for neg_idx in neg_idxs:
                        answers.append(batch_raw_data['neg_answer'][i][neg_idx])
                    # if self.args.knowledge_distillation:
                    #     if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
                    #         teacher_scores_answers.append(batch_raw_data['pos_scores'][i][pos_idx])
                    #     for neg_idx in neg_idxs:
                    #         if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
                    #             teacher_scores_answers.append(batch_raw_data['neg_scores'][i][neg_idx])
                    # else:
                    #     teacher_scores_answers = None

            if data_type is not None and data_type in ['symmetric_sts', 'symmetric_clustering']:
                # Symmetric tasks use the query-side instruction on both sides.
                tmp_passages = [
                    self.args.query_instruction_format.format(
                        batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
                        p
                    ) for p in tmp_passages
                ]
            else:
                if self.args.passage_instruction_for_retrieval is not None:
                    tmp_passages = [
                        self.args.passage_instruction_format.format(
                            self.args.passage_instruction_for_retrieval, p
                        ) for p in tmp_passages
                    ]

            passages.extend(tmp_passages)

        if teacher_scores is not None:
            if len(teacher_scores) > 0 and len(passages) > 0:
                assert len(teacher_scores) == len(passages)

        return queries, answers, passages, teacher_scores, teacher_scores_answers

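# Sketch of the per-process slicing done in __getitem__ above (hypothetical
# numbers): with num_processes = 4 and a global batch of 16 indices,
# cur_batch_size = 16 // 4 = 4, so rank 0 keeps indices[0:4], rank 1 keeps
# indices[4:8], and so on. Each rank thus builds its own quarter of the
# global batch, which is why the trainer-side batch size stays at 1.
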
@dataclass
class IREmbedderSameDatasetCollator(AbsEmbedderSameDatasetCollator):
    """
    EmbedCollator for SameDataset.
    Note that after using this collator, the training_args should be set as:

    ``training_args.per_device_train_batch_size = 1``

    ``training_args.dataloader_num_workers = 0    # avoid multi-processing``
    """
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1

    def __call__(self, features):
        # The dataset emits one full batch per feature, so unpack features[0].
        queries = features[0][0]
        answers = features[0][1]
        passages = features[0][2]
        teacher_scores = features[0][3]
        teacher_scores_answers = features[0][4]
        no_in_batch_neg_flag = features[0][5]

        queries_inputs = self.tokenizer(
            queries,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors=None
        )
        if answers[0] is not None:
            answers_inputs = self.tokenizer(
                answers,
                truncation=True,
                max_length=self.query_max_len,
                return_tensors=None
            )
        else:
            answers_inputs = None
        passages_inputs = self.tokenizer(
            passages,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors=None
        )

        if self.sub_batch_size is None or self.sub_batch_size <= 0:
            q_collated = self.tokenizer.pad(
                queries_inputs,
                padding=self.padding,
                max_length=self.query_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )

            d_collated = self.tokenizer.pad(
                passages_inputs,
                padding=self.padding,
                max_length=self.passage_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )

            if answers_inputs is None:
                a_collated = None
            else:
                a_collated = self.tokenizer.pad(
                    answers_inputs,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors,
                )
        else:
            # Pad in fixed-size sub-batches so the (potentially large)
            # same-dataset batch can be encoded chunk by chunk.
            batch_size = self.sub_batch_size

            q_collated = []
            for i in range(0, len(queries_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(queries_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in queries_inputs.items():
                    sub_features[k] = v[start:end]
                q_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors,
                ))

            d_collated = []
            for i in range(0, len(passages_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(passages_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in passages_inputs.items():
                    sub_features[k] = v[start:end]
                d_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.passage_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors,
                ))

            if answers_inputs is None:
                a_collated = None
            else:
                a_collated = []
                for i in range(0, len(answers_inputs['attention_mask']), batch_size):
                    start = i
                    end = min(len(answers_inputs['attention_mask']), i + batch_size)
                    sub_features = {}
                    for k, v in answers_inputs.items():
                        sub_features[k] = v[start:end]
                    a_collated.append(self.tokenizer.pad(
                        sub_features,
                        padding=self.padding,
                        max_length=self.query_max_len,
                        pad_to_multiple_of=self.pad_to_multiple_of,
                        return_tensors=self.return_tensors,
                    ))

        if isinstance(teacher_scores, list) and len(teacher_scores) == 0:
            teacher_scores = None

        return {
            "queries": q_collated,
            "answers": a_collated,
            "passages": d_collated,
            "teacher_scores": teacher_scores,
            "teacher_scores_answers": teacher_scores_answers,
            "no_in_batch_neg_flag": no_in_batch_neg_flag
        }
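

# Minimal end-to-end sketch (illustrative only; names like `data_args`,
# `training_args`, `tokenizer`, and the keyword values are assumed, not
# defined in this file):
#
#   dataset = IREmbedderSameDatasetTrainDataset(
#       data_args, default_batch_size=8, seed=42, tokenizer=tokenizer
#   )
#   collator = IREmbedderSameDatasetCollator(
#       tokenizer=tokenizer, query_max_len=32, passage_max_len=128
#   )
#   # As noted in the collator docstring, the dataset already emits whole
#   # batches, so the trainer-side settings must be:
#   training_args.per_device_train_batch_size = 1
#   training_args.dataloader_num_workers = 0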