import math
import os.path
import random
from dataclasses import dataclass
from pprint import pprint

import datasets
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding

from .arguments import DataArguments

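# Expected on-disk layout (illustrative; the directory and file names below are
# hypothetical). `args.train_data` lists directories of `.json`/`.jsonl` files;
# a `parallel_` substring in a directory name marks a parallel corpus, and
# `len-*` file names select a length-bucketed batch size via
# `get_file_batch_size`:
#
#   train_data/
#       retrieval_task/
#           len-0-500.jsonl
#           len-500-1000.jsonl
#       parallel_en-zh/
#           len-0-500.jsonl
#
# Each line is a JSON object with the fields declared in `context_feat`
# (or `context_feat_kd` when teacher scores are present):
#
#   {"query": "...", "pos": ["..."], "neg": ["...", "..."]}
#   {"query": "...", "pos": ["..."], "neg": ["..."], "pos_scores": [1.0], "neg_scores": [0.2]}
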
class SameDatasetTrainDataset(Dataset):
    """Dataset that yields one full batch of data at a time. All samples in the same batch come from the same task."""

    def __init__(self, args: DataArguments, batch_size: int, seed: int, process_index: int = 0, num_processes: int = 1):
        train_datasets = []
        each_data_inxs = []
        batch_size_inxs = []
        pqloss_flag = []
        cur_all_num = 0

        SMALL_THRESHOLD = args.small_threshold
        DROP_THRESHOLD = args.drop_threshold

        # Schema for plain contrastive data, and for knowledge-distillation data with teacher scores.
        context_feat = datasets.Features({
            'query': datasets.Value('string'),
            'pos': datasets.Sequence(datasets.Value('string')),
            'neg': datasets.Sequence(datasets.Value('string'))
        })
        context_feat_kd = datasets.Features({
            'query': datasets.Value('string'),
            'pos': datasets.Sequence(datasets.Value('string')),
            'neg': datasets.Sequence(datasets.Value('string')),
            'pos_scores': datasets.Sequence(datasets.Value('float')),
            'neg_scores': datasets.Sequence(datasets.Value('float')),
        })
        assert isinstance(args.train_data, list) and len(args.train_data) >= 1

        if dist.get_rank() == 0:
            self.print_batch_size(batch_size=batch_size, train_group_size=args.train_group_size)

        for data_dir in args.train_data:
            if not os.path.isdir(data_dir):
                raise FileNotFoundError(f"{data_dir} is not a directory")

            small_datasets = []
            small_batch_size = math.inf

            # Including `parallel_` in `data_dir` indicates that this dataset is a parallel corpus.
            flag = 'parallel_' in data_dir
            for file in os.listdir(data_dir):
                if not (file.endswith('.json') or file.endswith('.jsonl')):
                    continue

                file_path = os.path.join(data_dir, file)
                if dist.get_rank() == 0:
                    print(f'loading data from {file_path} ...')
                try:
                    temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=args.cache_path, features=context_feat)
                except Exception:
                    # The file carries teacher scores; reload it with the KD schema.
                    temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=args.cache_path, features=context_feat_kd)
                    if not args.knowledge_distillation:
                        temp_dataset = temp_dataset.remove_columns(['pos_scores', 'neg_scores'])

                if len(temp_dataset) == 0:
                    continue
                elif len(temp_dataset) < SMALL_THRESHOLD:
                    # Collect small files and merge them into one dataset later.
                    small_datasets.append(temp_dataset)
                    small_batch_size = min(small_batch_size, self.get_file_batch_size(file, batch_size, train_group_size=args.train_group_size))
                else:
                    if args.max_example_num_per_dataset is not None and len(temp_dataset) > args.max_example_num_per_dataset:
                        temp_dataset = temp_dataset.select(
                            random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
                    train_datasets.append(temp_dataset)
                    each_data_inxs.append(np.arange(len(temp_dataset)) + cur_all_num)
                    cur_all_num += len(temp_dataset)
                    batch_size_inxs.append(self.get_file_batch_size(file, batch_size, train_group_size=args.train_group_size))
                    pqloss_flag.append(flag)

            if len(small_datasets) > 0:
                small_dataset = datasets.concatenate_datasets(small_datasets)
                if len(small_dataset) >= DROP_THRESHOLD:
                    # Keep the merged small datasets only if the result is large enough; otherwise drop it.
                    train_datasets.append(small_dataset)
                    each_data_inxs.append(np.arange(len(small_dataset)) + cur_all_num)
                    cur_all_num += len(small_dataset)
                    batch_size_inxs.append(small_batch_size)
                    pqloss_flag.append(flag)

        self.dataset = datasets.concatenate_datasets(train_datasets)
        self.each_data_inxs = each_data_inxs
        self.datasets_inxs = np.arange(len(each_data_inxs))
        self.batch_size_inxs = batch_size_inxs
        self.pqloss_flag = pqloss_flag

        self.process_index = process_index
        self.num_processes = num_processes
        self.args = args
        self.shuffle_ratio = args.shuffle_ratio

        self.deterministic_generator = np.random.default_rng(seed)
        self.step = 0
        self.refresh_epoch()
    def print_batch_size(self, batch_size: int, train_group_size: int):
        """Print the batch size used for each sequence-length bucket."""
        length_list = ['0-500', '500-1000', '1000-2000', '2000-3000', '3000-4000', '4000-5000', '5000-6000', '6000-7000', '7000-inf']
        batch_size_dict = {
            k: self.get_file_batch_size(f"len-{k}.jsonl", batch_size, train_group_size) for k in length_list
        }
        batch_size_list = [
            f'{length}: {batch_size_dict[length]}' for length in length_list
        ]
        print("=========================")
        print("Batch Size Dict:")
        pprint(batch_size_list)
        print("=========================")

    @staticmethod
    def get_file_batch_size(file: str, batch_size: int, train_group_size: int):
        """Pick a per-device batch size from the sequence-length bucket encoded in the file name.

        The hard-coded values are tuned for 80GB GPUs; files without a length bucket fall back to `batch_size`.
        """
        if train_group_size == 8:
            if 'len-0-500.jsonl' in file:
                return 48
            elif 'len-500-1000.jsonl' in file:
                return 32
            elif 'len-1000-2000.jsonl' in file:
                return 20
            elif 'len-2000-3000.jsonl' in file:
                return 18
            elif 'len-3000-4000.jsonl' in file:
                return 14
            elif 'len-4000-5000.jsonl' in file:
                return 14
            elif 'len-5000-6000.jsonl' in file:
                return 12
            elif 'len-6000-7000.jsonl' in file:
                return 10
            elif 'len-7000-inf.jsonl' in file:
                return 8
            else:
                return batch_size
        elif train_group_size == 1:
            if 'len-0-500.jsonl' in file:
                return 700
            elif 'len-500-1000.jsonl' in file:
                return 570
            elif 'len-1000-2000.jsonl' in file:
                return 388
            elif 'len-2000-3000.jsonl' in file:
                return 288
            elif 'len-3000-4000.jsonl' in file:
                return 224
            elif 'len-4000-5000.jsonl' in file:
                return 180
            elif 'len-5000-6000.jsonl' in file:
                return 157
            elif 'len-6000-7000.jsonl' in file:
                return 128
            elif 'len-7000-inf.jsonl' in file:
                return 104
            else:
                return batch_size
        else:
            return batch_size
    def refresh_epoch(self):
        print(f'---------------------------*Rank {self.process_index}: refresh data---------------------------')
        self.deterministic_generator.shuffle(self.datasets_inxs)

        # Rebuild the batch list for the new epoch, using each sub-dataset's own batch size.
        batch_datas = []
        for dataset_inx in self.datasets_inxs:
            self.deterministic_generator.shuffle(self.each_data_inxs[dataset_inx])
            cur_batch_size = self.batch_size_inxs[dataset_inx] * self.num_processes
            flag = self.pqloss_flag[dataset_inx]
            for start_index in range(0, len(self.each_data_inxs[dataset_inx]), cur_batch_size):
                # Drop a trailing batch that is too small to give every process at least two samples.
                if len(self.each_data_inxs[dataset_inx]) - start_index < 2 * self.num_processes:
                    break
                batch_datas.append((self.each_data_inxs[dataset_inx][start_index:start_index + cur_batch_size], flag))
        self.deterministic_generator.shuffle(batch_datas)
        self.batch_datas = batch_datas
        self.step = 0
    def __getitem__(self, _):
        # Each call returns one full batch; the index argument is ignored.
        batch_indices, pqloss_flag = self.batch_datas[self.step]
        cur_batch_size = int(len(batch_indices) / self.num_processes)
        batch_indices = batch_indices[self.process_index * cur_batch_size: (self.process_index + 1) * cur_batch_size]
        batch_data = self.dataset[batch_indices]
        self.step += 1
        queries, passages, teacher_scores = self.create_batch_data(batch_raw_data=batch_data)
        # print('rank, step, flag, query, passage:', dist.get_rank(), self.step, pqloss_flag, queries, passages)
        return queries, passages, teacher_scores, pqloss_flag

    def shuffle_text(self, text):
        # With probability `shuffle_ratio`, split a long passage into three chunks and shuffle their order.
        if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
            split_text = []
            chunk_size = len(text) // 3 + 1
            for i in range(0, len(text), chunk_size):
                split_text.append(text[i:i + chunk_size])
            random.shuffle(split_text)
            return " ".join(split_text)
        else:
            return text
    def create_batch_data(self, batch_raw_data):
        queries, passages = [], []
        teacher_scores = []
        for i in range(len(batch_raw_data['query'])):
            queries.append(batch_raw_data['query'][i])

            # Sample one positive passage per query.
            pos_inx = random.choice(list(range(len(batch_raw_data['pos'][i]))))
            passages.append(self.shuffle_text(batch_raw_data['pos'][i][pos_inx]))
            if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
                teacher_scores.append(batch_raw_data['pos_scores'][i][pos_inx])

            # Sample `train_group_size - 1` negatives, reusing negatives when there are not enough.
            neg_inx_set = list(range(len(batch_raw_data['neg'][i])))
            if len(batch_raw_data['neg'][i]) < self.args.train_group_size - 1:
                num = math.ceil((self.args.train_group_size - 1) / len(batch_raw_data['neg'][i]))
                neg_inxs = random.sample(neg_inx_set * num, self.args.train_group_size - 1)
            else:
                neg_inxs = random.sample(neg_inx_set, self.args.train_group_size - 1)

            if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
                # Order the sampled negatives by their teacher scores, highest first.
                neg_scores = [(x, batch_raw_data['neg_scores'][i][x]) for x in neg_inxs]
                neg_scores = sorted(neg_scores, key=lambda x: x[1], reverse=True)
                neg_inxs = [x[0] for x in neg_scores]
                teacher_scores.extend([x[1] for x in neg_scores])

            negs = [batch_raw_data['neg'][i][x] for x in neg_inxs]
            passages.extend(negs)

        if len(teacher_scores) > 0 and len(passages) > 0:
            assert len(teacher_scores) == len(passages)

        if self.args.query_instruction_for_retrieval is not None:
            queries = [self.args.query_instruction_for_retrieval + q for q in queries]
        if self.args.passage_instruction_for_retrieval is not None:
            passages = [self.args.passage_instruction_for_retrieval + p for p in passages]

        if len(teacher_scores) == 0:
            teacher_scores = None
        return queries, passages, teacher_scores
    def __len__(self):
        return len(self.batch_datas) * self.num_processes

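# Batch layout produced by `SameDatasetTrainDataset.__getitem__` and consumed by
# `EmbedCollator` below (values are illustrative): each item is a tuple
# `(queries, passages, teacher_scores, flag)`, where `passages` holds, for each
# query, one sampled positive followed by `train_group_size - 1` negatives.
# For example, with `train_group_size = 2` and two queries in a batch:
#
#   queries  = ["q0", "q1"]
#   passages = ["pos0", "neg0", "pos1", "neg1"]    # len(passages) == 2 * train_group_size
#
# The collator flattens these lists and tokenizes them into the `query` and
# `passage` tensors returned to the model.
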
@dataclass
class EmbedCollator(DataCollatorWithPadding):
    """
    Wrapper that converts a List[Tuple[encode_qry, encode_psg]] into List[qry], List[psg]
    and passes each list separately to the actual collator.
    Abstracts the data details away from the model.
    """
    query_max_len: int = 32
    passage_max_len: int = 128

    def __call__(self, features):
        query = [f[0] for f in features]
        passage = [f[1] for f in features]

        teacher_scores = None
        if len(features[0]) > 2:
            teacher_scores = [f[2] for f in features]
            if teacher_scores[0] is None:
                teacher_scores = None
            else:
                teacher_scores = torch.FloatTensor(teacher_scores)

        flag = None
        if len(features[0]) == 4:
            flag = [f[3] for f in features][0]

        # Each feature may already hold a list of texts (one full batch); flatten if so.
        if isinstance(query[0], list):
            query = sum(query, [])
        if isinstance(passage[0], list):
            passage = sum(passage, [])

        q_collated = self.tokenizer(
            query,
            # padding='max_length',  # used for adjusting the batch size in `get_file_batch_size()`
            padding=True,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors="pt",
        )
        d_collated = self.tokenizer(
            passage,
            # padding='max_length',  # used for adjusting the batch size in `get_file_batch_size()`
            padding=True,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors="pt",
        )
        if teacher_scores is not None:
            teacher_scores = teacher_scores.reshape((len(q_collated['input_ids']), -1))
        return {"query": q_collated, "passage": d_collated, "teacher_scores": teacher_scores, "bi_directions": flag}
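
# Usage sketch (illustrative). Assumes torch.distributed has already been
# initialized (e.g. via `torchrun`) and that `data_args` is a DataArguments
# instance parsed from the command line; the tokenizer checkpoint, max lengths,
# and batch sizes below are placeholders, not values from this module.
#
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("some/embedding-model")
#     train_dataset = SameDatasetTrainDataset(
#         args=data_args,
#         batch_size=32,
#         seed=42,
#         process_index=dist.get_rank(),
#         num_processes=dist.get_world_size(),
#     )
#     data_collator = EmbedCollator(
#         tokenizer=tokenizer, query_max_len=64, passage_max_len=512
#     )
#     # __getitem__ already yields a full batch, so the DataLoader batch size is 1.
#     loader = DataLoader(train_dataset, batch_size=1, collate_fn=data_collator)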