import copy
import pickle
import sys
import math
import os.path
import random
from dataclasses import dataclass
from typing import List, Tuple
import json

import numpy as np
import datasets
from numpy import mean
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding, DataCollatorForSeq2Seq
from transformers import PreTrainedTokenizer, BatchEncoding
import torch.distributed as dist

from arguments import DataArguments


def get_query_prompt(query, prompt, use_special_tokens):
    """Combine a task instruction and a text into a single model input string.

    Args:
        query: the text to embed (a query, or a passage for symmetric tasks).
        prompt: the task instruction.
        use_special_tokens: when true, join instruction and text with a bare
            newline; otherwise use the ``Instruct: ...\\nQuery: ...`` template.

    Returns:
        The formatted string.
    """
    if use_special_tokens:
        return f'{prompt}\n{query}'
    return f'Instruct: {prompt}\nQuery: {query}'


def add_prompt(example, prompt):
    """Set ``example['prompt']`` to *prompt* in place and return *example*."""
    example['prompt'] = prompt
    return example


def traverse_directory_using_os(root_folder):
    """Return the paths of all files under *root_folder*, recursively.

    If *root_folder* is not a directory, it is returned as the only element.
    """
    if not os.path.isdir(root_folder):
        return [root_folder]
    file_list = []
    for dirpath, _dirnames, filenames in os.walk(root_folder):
        for filename in filenames:
            file_list.append(os.path.join(dirpath, filename))
    return file_list


class SameDatasetTrainDataset(Dataset):
    """Dataset that yields a whole batch of data per ``__getitem__`` call.

    All samples in one batch come from the same sub-dataset (task). A "global"
    batch holds ``batch_size_inxs[k] * num_processes`` rows; each process takes
    its own contiguous slice of it, selected by ``process_index``.
    """

    tokenizer: PreTrainedTokenizer
    loss_type: str

    # Marker/metadata file written next to a dataset saved with ``save_to_disk``.
    _FLAG_LOG_NAME = '.log'
    # In-context examples stop being added once the tokenized prefix alone
    # exceeds ``query_max_len`` minus this many tokens.
    _ICL_QUERY_TOKEN_RESERVE = 512

    def __init__(self, args: DataArguments, batch_size, seed, tokenizer, process_index=0, num_processes=1):
        """Load (or restore from disk) the training sub-datasets and build the first epoch.

        Args:
            args: data arguments (paths, batch/group sizes, max lengths, flags).
            batch_size: per-process batch size for non-symmetric sub-datasets.
            seed: seed of the numpy generator that shuffles sub-datasets and
                batches (presumably identical on every rank, so that all ranks
                slice the same global batch — confirm with the launcher).
            tokenizer: tokenizer used to truncate and encode queries/passages.
            process_index: rank of this process within ``num_processes``.
            num_processes: number of processes sharing each global batch.

        Raises:
            FileNotFoundError: if ``args.load_from_disk`` is set and
                ``args.load_disk_path`` is not a directory or has no ``.log`` file.
        """
        if args.load_from_disk:
            loaded = self._load_saved_datasets(args, batch_size)
        else:
            loaded = self._load_raw_datasets(args, batch_size, num_processes)
        self.dataset, each_data_inxs, batch_size_inxs, data_names = loaded
        # One array of global row indices per sub-dataset.
        self.each_data_inxs = each_data_inxs
        # Sub-dataset ids 0 .. k-1 (shuffled in place every epoch).
        self.datasets_inxs = np.arange(len(each_data_inxs))
        # Per-process batch size of each sub-dataset.
        self.batch_size_inxs = batch_size_inxs
        self.data_names = data_names

        if args.save_to_disk:
            self._save_datasets(args)

        self.process_index = process_index
        self.num_processes = num_processes
        self.args = args
        self.shuffle_ratio = args.shuffle_ratio
        self.deterministic_generator = np.random.default_rng(seed)
        self.step = 0
        self.refresh_epoch()

        self.tokenizer = tokenizer
        self.query_max_len = self.args.query_max_len
        self.passage_max_len = self.args.passage_max_len
        if args.use_special_tokens:
            self.suffix = self.tokenizer('\n', add_special_tokens=False)['input_ids']
        else:
            self.suffix = self.tokenizer('\nResponse:', add_special_tokens=False)['input_ids']
        self.prefix = self.tokenizer('', add_special_tokens=False)['input_ids']

    @staticmethod
    def _load_raw_datasets(args, batch_size, num_processes):
        """Load every split of ``args.train_data`` as one sub-dataset.

        Returns:
            (concatenated dataset, per-split arrays of global row indices,
            per-split per-process batch sizes, split names).
        """
        all_dataset = datasets.load_dataset(args.train_data, cache_dir=args.cache_path)
        train_datasets, each_data_inxs, batch_size_inxs, data_names = [], [], [], []
        offset = 0
        for name in all_dataset.keys():
            split = all_dataset[name]
            train_datasets.append(split)
            # Row positions of this split inside the concatenated dataset.
            each_data_inxs.append(np.arange(len(split)) + offset)
            offset += len(split)
            # Symmetric tasks use their own global batch size, divided across processes.
            if 'symmetric' in split[0]['type']:
                batch_size_inxs.append(args.symmetric_batch_size // num_processes)
            else:
                batch_size_inxs.append(batch_size)
            data_names.append(name)
        return datasets.concatenate_datasets(train_datasets), each_data_inxs, batch_size_inxs, data_names

    @classmethod
    def _load_saved_datasets(cls, args, batch_size):
        """Restore a dataset previously written by ``_save_datasets``.

        Side effect: overwrites ``args.train_data`` with the value recorded in
        the ``.log`` file.

        Returns:
            Same tuple layout as ``_load_raw_datasets``.
        """
        load_disk_path = args.load_disk_path
        log_path = os.path.join(load_disk_path, cls._FLAG_LOG_NAME)
        if not os.path.isdir(load_disk_path):
            raise FileNotFoundError(f"{load_disk_path} is not a directory")
        if not os.path.exists(log_path):
            raise FileNotFoundError(f"{load_disk_path} does not have {cls._FLAG_LOG_NAME}")
        with open(log_path, "r", encoding='utf-8') as f:
            log_info = json.load(f)

        each_data_inxs = [np.array(x) for x in log_info["each_data_inxs"]]
        # NOTE(review): the saved per-split batch sizes are ignored and every
        # split uses ``batch_size``, so ``symmetric_batch_size`` does not apply
        # to data restored from disk. Confirm this is intended.
        batch_size_inxs = [batch_size] * len(log_info["batch_size_inxs"])
        data_names = list(log_info["data_names"])

        print(f"start loading {log_info['train_data']} from {load_disk_path}")
        args.train_data = log_info['train_data']
        saved_dataset = datasets.load_from_disk(load_disk_path)
        return datasets.concatenate_datasets([saved_dataset]), each_data_inxs, batch_size_inxs, data_names

    def _save_datasets(self, args):
        """Write ``self.dataset`` plus a ``.log`` index file to ``args.save_disk_path``.

        Skips the write when a ``.log`` file already exists there. Terminates
        the process with status 0 afterwards if ``args.exit_after_save`` is set.
        """
        os.makedirs(args.save_disk_path, exist_ok=True)
        log_path = os.path.join(args.save_disk_path, self._FLAG_LOG_NAME)
        if os.path.exists(log_path):
            print(f"FLAG_LOG file {self._FLAG_LOG_NAME} already exists in {args.save_disk_path}!!!")
            print("args.save_to_disk deprecated.")
        else:
            if args.num_shards <= 0:
                self.dataset.save_to_disk(args.save_disk_path, max_shard_size=args.save_max_shard_size)
            else:
                self.dataset.save_to_disk(args.save_disk_path, num_shards=args.num_shards)
            # Build the metadata before opening the file so a failure here
            # cannot leave behind an empty flag file that blocks re-saving.
            log_info = {
                "train_data": args.train_data,
                "each_data_inxs": [x.tolist() for x in self.each_data_inxs],
                "batch_size_inxs": self.batch_size_inxs,
                "data_names": self.data_names,
            }
            with open(log_path, "w", encoding='utf-8') as f:
                json.dump(log_info, f, ensure_ascii=False, indent=4)
            print(f"save {args.train_data} to {args.save_disk_path}")
        if args.exit_after_save:
            print("exit after save")
            sys.exit(0)

    def refresh_epoch(self):
        """Reshuffle and re-batch every sub-dataset for a new epoch; resets ``self.step``.

        Shuffles, in place, the order of the sub-datasets and the rows within
        each one. Each sub-dataset is then cut into global batches of
        ``batch_size_inxs[k] * num_processes`` rows, and its incomplete trailing
        batch is dropped. Finally the batch order is shuffled so tasks interleave.
        """
        print(f'---------------------------*Rank {self.process_index}: refresh data---------------------------')
        rng = self.deterministic_generator
        rng.shuffle(self.datasets_inxs)
        batch_datas = []
        for dataset_inx in self.datasets_inxs:
            row_inxs = self.each_data_inxs[dataset_inx]
            rng.shuffle(row_inxs)
            global_batch_size = self.batch_size_inxs[dataset_inx] * self.num_processes
            for start in range(0, len(row_inxs), global_batch_size):
                if start + global_batch_size > len(row_inxs):
                    break  # drop the incomplete last batch
                batch_datas.append(row_inxs[start:start + global_batch_size])
        rng.shuffle(batch_datas)  # mix batches from all sub-datasets
        self.batch_datas = batch_datas
        self.step = 0

    def __getitem__(self, idx):
        """Return this process's slice of the next global batch, tokenized.

        ``idx`` is ignored. Batches are served in order from ``self.batch_datas``,
        and a new epoch is built once they run out.

        Returns:
            The ``(queries_inputs, passages_inputs, messages, scores)`` tuple
            produced by ``create_batch_data``.
        """
        if self.step >= len(self.batch_datas):
            self.refresh_epoch()
        global_batch = self.batch_datas[self.step]
        local_size = len(global_batch) // self.num_processes
        start = self.process_index * local_size
        batch_data = self.dataset[global_batch[start:start + local_size]]
        self.step += 1
        return self.create_batch_data(batch_raw_data=batch_data)

    def _train_group_size(self, finetune_type):
        """Number of passages (1 positive + negatives) sampled per query for *finetune_type*."""
        if 'symmetric' in finetune_type and ('sts' in finetune_type or 'clustering' in finetune_type):
            return self.args.symmetric_train_group_size
        if 'only_1neg' in finetune_type:
            return 2
        if 'symmetric' in finetune_type and 'class' in finetune_type:
            return self.args.max_class_neg + 1
        return self.args.train_group_size

    @staticmethod
    def _sample_neg_indexes(num_available, num_needed):
        """Draw ``num_needed`` negative indexes, repeating the pool when it is too small.

        Raises:
            ValueError: if negatives are needed but ``num_available`` is 0.
        """
        pool = list(range(num_available))
        if num_available < num_needed:
            if num_available == 0:
                raise ValueError(f"need {num_needed} negatives but the example has none")
            pool = pool * math.ceil(num_needed / num_available)
        return random.sample(pool, num_needed)

    @staticmethod
    def _collect_neg_scores(neg_scores_column, row, neg_indexes):
        """Return the non-None teacher scores of the sampled negatives of *row*.

        Raises:
            ValueError: if the row's scores cannot be indexed by *neg_indexes*.
        """
        try:
            row_scores = neg_scores_column[row]
            if row_scores is None:
                return []
            picked = [row_scores[k] for k in neg_indexes]
        except (IndexError, KeyError, TypeError) as err:
            raise ValueError(
                f"cannot read 'neg_scores' of row {row} at sampled negative indexes {neg_indexes}"
            ) from err
        return [score for score in picked if score is not None]

    def _truncate_to_tokens(self, text, max_tokens):
        """Decode the first ``max_tokens`` tokens of *text* (tokenized without special tokens)."""
        return self.tokenizer.decode(self.tokenizer(text, add_special_tokens=False)['input_ids'][:max_tokens])

    def _sample_icl_prefix(self, icl_pairs, query_index, pair_sep):
        """Build a random in-context-example prefix for query ``query_index``.

        Picks 0-5 examples from *icl_pairs*, excluding the query's own pair.
        Examples are appended one at a time, each as ``pair_sep.join(pair) + '\\n\\n'``,
        until the tokenized prefix would exceed
        ``query_max_len - _ICL_QUERY_TOKEN_RESERVE`` tokens.
        """
        num_examples = random.choice([0, 1, 2, 3, 4, 5])
        if num_examples == 0 or not icl_pairs:
            return ''
        # Draw one extra index so dropping the query's own pair still leaves
        # enough. Cap at the population size: random.sample raises ValueError
        # when asked for more items than exist (e.g. small per-process batches).
        sample_size = min(num_examples + 1, len(icl_pairs))
        prefix_ids = random.sample(range(len(icl_pairs)), sample_size)
        if query_index in prefix_ids:
            prefix_ids.remove(query_index)
        token_limit = self.query_max_len - self._ICL_QUERY_TOKEN_RESERVE
        prefix = ''
        for idx in prefix_ids[:num_examples]:
            candidate = prefix + pair_sep.join(icl_pairs[idx]) + '\n\n'
            if len(self.tokenizer(candidate)['input_ids']) > token_limit:
                break
            prefix = candidate
        return prefix

    def create_batch_data(self, batch_raw_data):
        """Turn a column-oriented batch of raw rows into tokenized queries and passages.

        For each row, the method:
        - formats the query with its prompt;
        - samples one positive and ``train_group_size - 1`` negatives,
          repeating negatives when there are too few;
        - collects any teacher scores for the sampled passages.

        Queries may then be prefixed with in-context examples built from other
        rows of the same batch.

        Args:
            batch_raw_data: mapping from column name to one value per row. Reads
                'type', 'query', 'prompt', 'pos' and 'neg'. Also reads
                'pos_scores' and 'neg_scores' when present, and 'category' for
                clustering rows.

        Returns:
            ``(queries_inputs, passage_inputs, messages, scores)``:
            - Passages are grouped per query: its positive first, then its negatives.
            - ``messages`` is ``['not in-batch']`` for symmetric class/clustering
              tasks, otherwise ``'normal'`` once per passage.
            - ``scores`` holds only the non-None teacher scores, so it may be empty.

        Raises:
            ValueError: if a row has no negatives but at least one is required,
                or its 'neg_scores' cannot be indexed by the sampled negatives.
        """
        use_special_tokens = self.args.use_special_tokens
        # All rows of a batch come from one sub-dataset, so the first type is representative.
        finetune_type = batch_raw_data['type'][0]
        num_negs = self._train_group_size(finetune_type) - 1

        queries, passages, scores = [], [], []
        icl_pairs = []  # (query, answer) pairs usable as in-context examples
        for i in range(len(batch_raw_data['query'])):
            example_type = batch_raw_data['type'][i]
            query = get_query_prompt(batch_raw_data['query'][i], batch_raw_data['prompt'][i], use_special_tokens)
            queries.append(query)

            pos_candidates = batch_raw_data['pos'][i]
            pos_index = random.choice(range(len(pos_candidates)))
            pos = pos_candidates[pos_index]
            pos_scores = batch_raw_data.get('pos_scores')
            if pos_scores is not None and pos_scores[i] is not None and pos_scores[i][pos_index] is not None:
                scores.append(pos_scores[i][pos_index])

            neg_candidates = batch_raw_data['neg'][i]
            neg_indexes = self._sample_neg_indexes(len(neg_candidates), num_negs)
            negs = [neg_candidates[k] for k in neg_indexes]
            if batch_raw_data.get('neg_scores') is not None:
                scores.extend(self._collect_neg_scores(batch_raw_data['neg_scores'], i, neg_indexes))

            group_passages = [pos, *negs]
            if self.args.retrieval_use_examples or any(t in example_type for t in ('clustering', 'sts', 'class')):
                # The example "answer" is the category label for clustering, else the positive.
                answer = batch_raw_data['category'][i] if 'clustering' in example_type else pos
                icl_pairs.append((
                    self._truncate_to_tokens(query, self.args.example_query_max_len),
                    self._truncate_to_tokens(answer, self.args.example_passage_max_len),
                ))
            else:
                icl_pairs = []

            if 'sts' in example_type or 'clustering' in example_type:
                # Symmetric tasks format passages like queries: prompt + text,
                # truncated to leave room for the suffix, then the suffix.
                prompt = batch_raw_data['prompt'][i]
                prompted = [get_query_prompt(p, prompt, use_special_tokens) for p in group_passages]
                token_ids = self.tokenizer(
                    prompted,
                    max_length=self.passage_max_len - 1 - len(self.suffix),
                    truncation=True,
                    add_special_tokens=False,
                )['input_ids']
                passage_suffix = '\n' if use_special_tokens else '\nResponse:'
                group_passages = [p + passage_suffix for p in self.tokenizer.batch_decode(token_ids)]
            passages.extend(group_passages)

        if 'symmetric' in finetune_type and ('class' in finetune_type or 'clustering' in finetune_type):
            messages = ['not in-batch']
        else:
            messages = ['normal'] * len(passages)

        pair_sep = '\n' if use_special_tokens else '\nResponse: '
        query_suffix = '\n' if use_special_tokens else '\nResponse:'
        query_budget = self.query_max_len - len(self.prefix) - len(self.suffix)
        for i, query in enumerate(queries):
            prefix = self._sample_icl_prefix(icl_pairs, i, pair_sep)
            token_ids = self.tokenizer(
                prefix + query, max_length=query_budget, truncation=True, add_special_tokens=False
            )['input_ids']
            queries[i] = self.tokenizer.decode(token_ids) + query_suffix

        queries_inputs = self.tokenizer(queries, return_tensors=None, max_length=self.query_max_len,
                                        truncation=True, add_special_tokens=True)
        passage_inputs = self.tokenizer(passages, return_tensors=None, max_length=self.passage_max_len,
                                        truncation=True, add_special_tokens=True)
        return queries_inputs, passage_inputs, messages, scores

    def __len__(self):
        """Number of global batches in the current epoch times ``num_processes``."""
        return len(self.batch_datas) * self.num_processes


@dataclass
class SameEmbedCollator(DataCollatorForSeq2Seq):
    """Pad the pre-batched output of ``SameDatasetTrainDataset`` for the model.

    Each dataset item is already a whole batch, so only ``features[0]`` is
    used: a ``(queries, passages, messages, scores)`` tuple. Queries and
    passages are padded separately. When ``sub_batch_size > 0`` they are padded
    in chunks of that many rows, and the result is a list of padded chunks.
    """

    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = 0
    train_group_size: int = 0  # not read by this collator

    def _pad_encodings(self, encodings, max_length, return_tensors):
        """Pad one mapping of token columns with this collator's padding settings."""
        return self.tokenizer.pad(
            encodings,
            padding=self.padding,
            max_length=max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

    def _pad_maybe_in_sub_batches(self, encodings, max_length, return_tensors):
        """Pad *encodings* whole, or as a list of ``sub_batch_size``-row chunks."""
        if self.sub_batch_size is None or self.sub_batch_size <= 0:
            return self._pad_encodings(encodings, max_length, return_tensors)
        total = len(encodings['attention_mask'])
        chunks = []
        for start in range(0, total, self.sub_batch_size):
            end = start + self.sub_batch_size  # slicing clips at the end
            sub_features = {k: v[start:end] for k, v in encodings.items()}
            chunks.append(self._pad_encodings(sub_features, max_length, return_tensors))
        return chunks

    def __call__(self, features, return_tensors='pt'):
        """Return a dict with padded 'query'/'passage', 'messages' and 'teacher_scores'.

        'teacher_scores' is None when the batch carried no teacher scores.
        """
        if return_tensors is None:
            return_tensors = self.return_tensors
        queries = features[0][0]
        passages = features[0][1]
        messages = features[0][2]
        scores = features[0][3]

        q_collated = self._pad_maybe_in_sub_batches(queries, self.query_max_len, return_tensors)
        d_collated = self._pad_maybe_in_sub_batches(passages, self.passage_max_len, return_tensors)

        if len(scores) == 0:
            scores = None
        return {"query": q_collated, "passage": d_collated, 'messages': messages, "teacher_scores": scores}