189 lines
7.4 KiB
Python
189 lines
7.4 KiB
Python
import re
|
|
import os
|
|
import json
|
|
import math
|
|
import random
|
|
import datasets
|
|
from tqdm import tqdm
|
|
from functools import partial
|
|
from glob import glob
|
|
from contextlib import nullcontext
|
|
from transformers.utils import logging
|
|
from src import apply_chat_template, add_eos, split_file_dir_name_ext
|
|
|
|
# Module-level logger, created through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
# Active candidate list. The commented-out assignment below keeps a longer
# list of alternatives for reference.
# NOTE(review): the meaning of each (int, int) pair is not visible in this
# file -- confirm against the code that consumes RETRIEVAL_CAND.
# RETRIEVAL_CAND = [(1024,1), (512,2), (256,4), (128,8), (512,1), (256,2), (128,4)]
RETRIEVAL_CAND = [
    (1024, 1),
]
|
|
|
|
|
|
class Data:
    """Static helpers that load datasets and tokenize them into model inputs.

    Every function here is a ``staticmethod`` and is meant to be called
    through the class, e.g. ``Data.prepare_train_data(...)``; no instance
    state is read or written.
    """

    # Matches a ``[N]`` suffix embedded in a data file path, e.g. ``foo.json[1000]``.
    # ``\d+`` (not ``\d*``) so that an empty ``[]`` is never fed to ``int()``.
    _SAMPLE_NUM_PATTERN = re.compile(r"\[(\d+)\]")

    @staticmethod
    def _process_language_modeling(data, indices, tokenizer, min_length, max_length):
        """Tokenize a batch of raw texts for language modeling.

        Args:
            data: Batch mapping that contains a ``'text'`` column (list of strings).
            indices: Dataset indices of the rows in the batch, aligned with ``data['text']``.
            tokenizer: Callable applied to each text. It must return a mapping
                whose keys are among ``input_ids`` / ``attention_mask`` (any other
                key raises ``KeyError`` when collected) and must expose
                ``eos_token_id``.
            min_length: Texts with fewer tokens than this are dropped.
            max_length: Texts with at least this many tokens are truncated to it.

        Returns:
            A dict of parallel lists with the columns ``input_ids``,
            ``attention_mask``, ``labels``, ``length`` and ``index``.
        """
        outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}

        for index, text in zip(indices, data['text']):
            # NOTE: the full text is tokenized; truncation happens afterwards on the token level.
            encoded = tokenizer(text)
            num_tokens = len(encoded["input_ids"])

            if num_tokens < min_length:
                # too short to be useful
                continue

            if num_tokens < max_length:
                # there is still room, so the sample is passed through add_eos
                encoded = add_eos(encoded, tokenizer.eos_token_id)
            else:
                # truncate every field to max_length (no eos is appended in this case)
                for key in list(encoded.keys()):
                    encoded[key] = encoded[key][:max_length]

            encoded["labels"] = encoded["input_ids"].copy()

            for key, value in encoded.items():
                outputs[key].append(value)
            # length is required for grouping
            outputs["length"].append(len(encoded['input_ids']))
            outputs["index"].append(index)

        return outputs

    @staticmethod
    def _process_instruction_tuning(data, indices, tokenizer, chat_template, min_length, max_length, eval_mode=False):
        """Tokenize a batch of conversations with a chat template.

        Args:
            data: Batch mapping that contains a ``'conversations'`` column; each
                conversation is a list of turns, each turn a mapping with
                ``'role'`` and ``'content'``.
            indices: Dataset indices of the rows in the batch, aligned with
                ``data['conversations']``.
            tokenizer: Tokenizer forwarded to ``apply_chat_template``.
            chat_template: Template identifier forwarded to ``apply_chat_template``.
            min_length: Samples encoded to fewer tokens than this are dropped.
            max_length: Samples encoded to more tokens than this are dropped
                (they are NOT truncated).
            eval_mode: When True, only the first turn is encoded (with a
                generation prompt) and ``labels`` holds the raw content string
                of the second turn, or ``None`` if there is no second turn.

        Returns:
            A dict of parallel lists with the columns ``input_ids``,
            ``attention_mask``, ``labels``, ``length`` and ``index``.
        """
        outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}

        for index, source in zip(indices, data['conversations']):
            if not source:
                # an empty conversation has no turn to inspect; skip it instead of raising IndexError
                continue

            if source[0]["role"] != 'user':
                # Skip the first one if it is not from user
                source = source[1:]

            labels = None
            # NOTE: in evaluation, we only use the first turn in the conversation
            if eval_mode:
                # a string (the expected output from the assistant)
                if len(source) > 1:
                    labels = source[1]['content']
                source = source[:1]

            encoded = apply_chat_template(
                chat_template,
                source,
                tokenizer=tokenizer,
                # token-level labels are only requested outside evaluation mode
                return_labels=not eval_mode,
                add_generation_prompt=eval_mode,
            ).encoded

            # skip data that does not fall in between min_length and max_length
            num_tokens = len(encoded["input_ids"])
            if num_tokens < min_length or num_tokens > max_length:
                continue

            if eval_mode:
                encoded["labels"] = labels

            for key, value in encoded.items():
                outputs[key].append(value)
            outputs['length'].append(len(encoded['input_ids']))
            outputs['index'].append(index)

        return outputs

    @staticmethod
    def _parse_data_files(data_files, default_sample_num=None):
        """Split optional ``[N]`` sample caps off a list of data file paths.

        Args:
            data_files: Iterable of path strings, each optionally containing a
                ``[N]`` marker (``N`` being one or more digits).
            default_sample_num: Cap used for paths that carry no ``[N]`` marker.

        Returns:
            A dict mapping each cleaned path (markers removed) to its sample
            cap, or to ``default_sample_num`` when no marker was present.
            Duplicate paths collapse to a single entry (the last one wins).
        """
        data_2_num_sample = {}
        for data_file in data_files:
            match = Data._SAMPLE_NUM_PATTERN.search(data_file)
            if match:
                sample_num = int(match.group(1))
                data_file = Data._SAMPLE_NUM_PATTERN.sub("", data_file)
            else:
                sample_num = default_sample_num
            data_2_num_sample[data_file] = sample_num
        return data_2_num_sample

    @staticmethod
    def _build_process_fn(column_names, tokenizer, chat_template, min_length, max_length, eval_mode=False, data_kind="training"):
        """Pick the batch-processing function that matches the dataset columns.

        Args:
            column_names: Column names of the dataset to be processed.
            tokenizer: Tokenizer bound into the returned function.
            chat_template: Chat template bound in for conversation data.
            min_length: Minimum token length bound into the returned function.
            max_length: Maximum token length bound into the returned function.
            eval_mode: Forwarded to ``_process_instruction_tuning`` only.
            data_kind: Word used in the error message (``"training"`` / ``"evaluation"``).

        Returns:
            A ``functools.partial`` taking ``(batch, indices)``.

        Raises:
            ValueError: If neither a ``text`` nor a ``conversations`` column exists.
        """
        # 'text' takes precedence when both columns are present
        if "text" in column_names:
            return partial(
                Data._process_language_modeling,
                tokenizer=tokenizer,
                min_length=min_length,
                max_length=max_length,
            )
        if "conversations" in column_names:
            return partial(
                Data._process_instruction_tuning,
                tokenizer=tokenizer,
                chat_template=chat_template,
                min_length=min_length,
                max_length=max_length,
                eval_mode=eval_mode,
            )
        raise ValueError(f"Found neither 'text' nor 'conversations' in the {data_kind} data!")

    @staticmethod
    def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", max_sample_num=None, seed=42, cache_dir=None, load_from_cache_file=None):
        """Load, tokenize and concatenate the training datasets.

        Args:
            data_files: A path or a list of paths. Each path is either a
                directory containing ``dataset_info.json`` (loaded with
                ``datasets.load_from_disk``) or a json file. A path may carry
                a ``[N]`` marker to keep at most ``N`` processed samples.
            tokenizer: Tokenizer used by the processing functions.
            max_length: Maximum token length passed to the processing function.
            min_length: Minimum token length passed to the processing function.
            chat_template: Chat template used for conversation data.
            max_sample_num: Default sample cap for paths without a ``[N]``
                marker; ``None`` means no cap.
            seed: Seeds Python's global ``random`` module and the sub-sampling split.
            cache_dir: Cache directory forwarded to ``datasets.load_dataset``.
            load_from_cache_file: Forwarded to ``Dataset.map``.

        Returns:
            The concatenated dataset, or ``None`` when ``data_files`` is ``None``.

        Raises:
            ValueError: If ``data_files`` is neither a string nor a list, or if
                a dataset has neither a ``text`` nor a ``conversations`` column.
        """
        if data_files is None:
            return None

        if not isinstance(data_files, (str, list)):
            raise ValueError(f"Invalid training data {data_files}!")

        logger.info(f"Loading training data from {data_files}...")
        if isinstance(data_files, str):
            data_files = [data_files]

        # honour the caller-supplied max_sample_num for files without an explicit [N] marker
        data_2_num_sample = Data._parse_data_files(data_files, max_sample_num)

        random.seed(seed)

        train_datasets = []
        for data_file, sample_num in data_2_num_sample.items():
            if os.path.isdir(data_file) and os.path.exists(os.path.join(data_file, "dataset_info.json")):
                # the dataset may be save_to_disk in advance
                dataset = datasets.load_from_disk(data_file)
            else:
                # the dataset is a json file
                dataset = datasets.load_dataset('json', data_files=data_file, split='train', cache_dir=cache_dir)

            process_fn = Data._build_process_fn(
                dataset.column_names,
                tokenizer=tokenizer,
                chat_template=chat_template,
                min_length=min_length,
                max_length=max_length,
                eval_mode=False,
                data_kind="training",
            )

            dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, batch_size=32, with_indices=True, load_from_cache_file=load_from_cache_file)

            if sample_num is not None and len(dataset) > sample_num:
                # the "test" part of the split is a random subset of exactly sample_num rows
                dataset = dataset.train_test_split(test_size=sample_num, seed=seed)["test"]

            # index column is useless in training
            if "index" in dataset.column_names:
                dataset = dataset.remove_columns(["index"])

            train_datasets.append(dataset)

        return datasets.concatenate_datasets(train_datasets)

    @staticmethod
    def prepare_eval_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", max_eval_num=None, cache_dir=None, seed=42, load_from_cache_file=None):
        """Load and tokenize the evaluation dataset from json file(s).

        Args:
            data_files: Forwarded as ``data_files`` to ``datasets.load_dataset('json', ...)``.
            tokenizer: Tokenizer used by the processing functions.
            max_length: Maximum token length passed to the processing function.
            min_length: Minimum token length passed to the processing function.
            chat_template: Chat template used for conversation data.
            max_eval_num: When given, only the first ``max_eval_num`` rows are
                loaded (split ``train[:max_eval_num]``).
            cache_dir: Cache directory forwarded to ``datasets.load_dataset``.
            seed: Seeds Python's global ``random`` module.
            load_from_cache_file: Forwarded to ``Dataset.map``.

        Returns:
            The processed dataset (the ``index`` column is kept), or ``None``
            when ``data_files`` is ``None``.

        Raises:
            ValueError: If the data has neither a ``text`` nor a ``conversations`` column.
        """
        if data_files is None:
            return None

        random.seed(seed)

        split = 'train' if max_eval_num is None else f'train[:{max_eval_num}]'
        dataset = datasets.load_dataset('json', data_files=data_files, split=split, cache_dir=cache_dir)

        process_fn = Data._build_process_fn(
            dataset.column_names,
            tokenizer=tokenizer,
            chat_template=chat_template,
            min_length=min_length,
            max_length=max_length,
            eval_mode=True,
            data_kind="evaluation",
        )

        dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, with_indices=True, load_from_cache_file=load_from_cache_file)
        return dataset