import re
import os
import json
import math
import random
import datasets
from tqdm import tqdm
from functools import partial
from glob import glob
from contextlib import nullcontext
from transformers.utils import logging
from src import apply_chat_template, add_eos, split_file_dir_name_ext

logger = logging.get_logger(__name__)

# RETRIEVAL_CAND = [(1024,1), (512,2), (256,4), (128,8), (512,1), (256,2), (128,4)]
RETRIEVAL_CAND = [(1024, 1)]


class Data:
    @staticmethod
    def _process_language_modeling(data, indices, tokenizer, min_length, max_length):
        """Tokenize plain-text samples for language modeling, keeping only those within [min_length, max_length] tokens."""
        outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}

        for i, text in enumerate(data['text']):
            encoded = tokenizer(text)
            if len(encoded["input_ids"]) < min_length:
                # too short, skip
                continue
            elif len(encoded["input_ids"]) < max_length:
                encoded = add_eos(encoded, tokenizer.eos_token_id)
            else:
                # truncate to max_length tokens
                for k, v in encoded.items():
                    encoded[k] = v[:max_length]
            encoded["labels"] = encoded["input_ids"].copy()

            for k, v in encoded.items():
                outputs[k].append(v)
            # length is required for grouping
            outputs["length"].append(len(encoded["input_ids"]))
            outputs["index"].append(indices[i])

        return outputs

    @staticmethod
    def _process_instruction_tuning(data, indices, tokenizer, chat_template, min_length, max_length, eval_mode=False):
        """Apply the chat template to conversations, keeping only samples within [min_length, max_length] tokens."""
        outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}

        for i, source in enumerate(data['conversations']):
            if source[0]["role"] != 'user':
                # skip the first turn if it is not from the user
                source = source[1:]

            # NOTE: in evaluation, we only use the first turn of the conversation
            if eval_mode:
                # the expected output from the assistant (a string), if any
                if len(source) > 1:
                    labels = source[1]['content']
                else:
                    labels = None
                source = source[:1]

            encoded = apply_chat_template(
                chat_template,
                source,
                tokenizer=tokenizer,
                # token-level labels are only returned when training;
                # in evaluation the reference answer is stored separately below
                return_labels=not eval_mode,
                add_generation_prompt=eval_mode,
            ).encoded

            # skip data that does not fall between min_length and max_length
            if len(encoded["input_ids"]) < min_length:
                continue
            if len(encoded["input_ids"]) > max_length:
                continue

            if eval_mode:
                encoded["labels"] = labels

            for k, v in encoded.items():
                outputs[k].append(v)
            outputs["length"].append(len(encoded["input_ids"]))
            outputs["index"].append(indices[i])

        return outputs

    @staticmethod
    def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", max_sample_num=None, seed=42, cache_dir=None, load_from_cache_file=None):
        """Load, tokenize, and concatenate one or more training datasets."""
        if data_files is None:
            return None

        if isinstance(data_files, list):
            logger.info(f"Loading training data from {data_files}...")
        elif isinstance(data_files, str):
            logger.info(f"Loading training data from {data_files}...")
            data_files = [data_files]
        else:
            raise ValueError(f"Invalid training data {data_files}!")

        # an optional "[N]" suffix on a data file caps its sample count at N
        data_2_num_sample = {}
        for data_file in data_files:
            match = re.search(r"\[(\d*)\]", data_file)
            if match:
                max_sample_num = int(match.group(1))
                data_file = re.sub(r"\[(\d*)\]", "", data_file)
            else:
                max_sample_num = None
            data_2_num_sample[data_file] = max_sample_num

        random.seed(seed)

        train_datasets = []
        for data_file, max_sample_num in data_2_num_sample.items():
            if os.path.isdir(data_file) and os.path.exists(os.path.join(data_file, "dataset_info.json")):
                # the dataset may have been saved with save_to_disk in advance
                dataset = datasets.load_from_disk(data_file)
            else:
                # the dataset is a json file
                dataset = datasets.load_dataset('json', data_files=data_file, split='train', cache_dir=cache_dir)

                column_names = dataset.column_names
                if "text" in column_names:
                    process_fn = partial(
                        Data._process_language_modeling,
                        tokenizer=tokenizer,
                        min_length=min_length,
                        max_length=max_length
                    )
                elif "conversations" in column_names:
                    process_fn = partial(
                        Data._process_instruction_tuning,
                        tokenizer=tokenizer,
                        chat_template=chat_template,
                        min_length=min_length,
                        max_length=max_length
                    )
                else:
                    raise ValueError(f"Found neither 'text' nor 'conversations' in the training data!")

                dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, batch_size=32, with_indices=True, load_from_cache_file=load_from_cache_file)

            if max_sample_num is not None and len(dataset) > max_sample_num:
                dataset = dataset.train_test_split(max_sample_num, seed=seed)["test"]

            # the index column is useless in training
            if "index" in dataset.column_names:
                dataset = dataset.remove_columns(["index"])

            train_datasets.append(dataset)

        dataset = datasets.concatenate_datasets(train_datasets)

        return dataset

    @staticmethod
    def prepare_eval_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", max_eval_num=None, cache_dir=None, seed=42, load_from_cache_file=None):
        """Load and tokenize an evaluation dataset."""
        if data_files is None:
            return None

        random.seed(seed)

        if max_eval_num is not None:
            dataset = datasets.load_dataset('json', data_files=data_files, split=f'train[:{max_eval_num}]', cache_dir=cache_dir)
        else:
            dataset = datasets.load_dataset('json', data_files=data_files, split='train', cache_dir=cache_dir)

        column_names = dataset.column_names
        if "text" in column_names:
            process_fn = partial(
                Data._process_language_modeling,
                tokenizer=tokenizer,
                min_length=min_length,
                max_length=max_length
            )
        elif "conversations" in column_names:
            process_fn = partial(
                Data._process_instruction_tuning,
                tokenizer=tokenizer,
                chat_template=chat_template,
                min_length=min_length,
                max_length=max_length,
                eval_mode=True,
            )
        else:
            raise ValueError(f"Found neither 'text' nor 'conversations' in the evaluation data!")

        dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, with_indices=True, load_from_cache_file=load_from_cache_file)

        return dataset
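

# A minimal usage sketch (not part of the original module): it assumes a
# Hugging Face tokenizer name and local JSON file paths, all hypothetical,
# and simply exercises prepare_train_data / prepare_eval_data as defined above.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # hypothetical model and data paths; replace with real ones
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    train_dataset = Data.prepare_train_data(
        # the optional "[N]" suffix caps the number of samples taken from this file
        data_files=["data/train.json[1000]"],
        tokenizer=tokenizer,
        max_length=4096,
        min_length=512,
        chat_template="vicuna",
    )
    eval_dataset = Data.prepare_eval_data(
        data_files="data/eval.json",
        tokenizer=tokenizer,
        max_eval_num=100,
    )

    if train_dataset is not None:
        print(f"train samples: {len(train_dataset)}")
    if eval_dataset is not None:
        print(f"eval samples: {len(eval_dataset)}")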