embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/main/pretrain_data.py

import os
import json
import random
import math
import datasets
from tqdm import tqdm
from typing import List
from datetime import timedelta
from accelerate import Accelerator, InitProcessGroupKwargs
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from transformers.utils import logging
from transformers.tokenization_utils import PreTrainedTokenizer
from src import split_file_dir_name_ext, get_model_and_tokenizer, format_numel_str, ModelArgs
logger = logging.get_logger(__name__)

@dataclass
class Args(ModelArgs):
    config: str = field(
        default="data/config/slimpajama.json",
        metadata={'help': 'Configuration json path for standard pretraining (concatenating multiple documents to form instances of equal lengths).'}
    )
    train_data: str = field(
        default="long-llm:slimpajama",
        metadata={'help': 'Directory of training data (multiple json files whose names correspond to the ones in config).'}
    )
    output_dir: str = field(
        default="data/pretrain/llama-8K_2B",
        metadata={'help': 'Output directory for results and logs.'}
    )
    num_token: List[str] = field(
        default_factory=lambda: ["8192:2B"],
        metadata={'help': 'How many tokens to use for a specified length? (T/t for trillion, B/b for billion, M/m for million)'}
    )
    add_bos: bool = field(
        default=True,
        metadata={'help': 'Add bos at the beginning of each document?'}
    )
    add_eos: bool = field(
        default=True,
        metadata={'help': 'Add eos at the end of each document?'}
    )
    seed: int = field(
        default=123,
        metadata={'help': 'Random seed.'}
    )
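
# Illustrative sketch (not taken from the repository) of the structure that
# `prepare_pretrain_data` below expects from the config json: "mixture" maps each
# source json file name to its percentage of the token budget (the code divides it
# by 100), and "num_tokens_avg" appears to record the average token count per source
# (loaded but not otherwise used here). The source names are hypothetical placeholders.
#
# {
#     "num_tokens_avg": {"source_a": 512, "source_b": 2048},
#     "mixture": {"source_a": 60, "source_b": 40}
# }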

def prepare_pretrain_data(data_files, tokenizer: PreTrainedTokenizer, config: dict, length_2_num_token: dict, add_bos: bool = True, add_eos: bool = True, seed=42, cache_dir=None, load_from_cache_file=None):
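    """Pack tokenized documents into fixed-length pretraining instances.

    For every json file under `data_files` whose name appears in `config["mixture"]`,
    documents are tokenized, optionally wrapped with bos/eos, concatenated, and sliced
    into instances whose lengths are sampled from the keys of `length_2_num_token`.
    Each source receives a share of the token budget proportional to its mixture
    percentage. Returns a `datasets.Dataset` with `input_ids`, `attention_mask`,
    `labels` (a copy of `input_ids`), and `length` columns.
    """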
    random.seed(seed)

    if isinstance(data_files, list):
        data_files = data_files[0]
    assert os.path.isdir(data_files), f"Make sure the data_files parameter is a directory containing the pretraining data json files! Found {data_files}."

    def _process(data):
        input_ids = tokenizer(data["text"], add_special_tokens=False)["input_ids"]
        return {"input_ids": input_ids}

    # loaded for reference; not used below
    num_token_avg_per_source = config["num_tokens_avg"]
    mixture = config["mixture"]

    # concatenate all input_ids and partition them according to num_instances
    outputs = {"input_ids": [], "attention_mask": [], "labels": [], "length": []}

    for file_name in os.listdir(data_files):
        file_path = os.path.join(data_files, file_name)
        dataset_name = split_file_dir_name_ext(file_path)[1]

        if dataset_name not in mixture:
            continue
        mix_portion = mixture[dataset_name] / 100
        if mix_portion == 0:
            continue

        # per-source token budget and the corresponding number of instances per length
        num_token_this_dataset = {k: math.ceil(v * mix_portion) for k, v in length_2_num_token.items()}
        num_instances_this_dataset = {k: math.ceil(v / k) for k, v in num_token_this_dataset.items()}

        info = {k: format_numel_str(v) for k, v in num_token_this_dataset.items()}
        logger.info(f"processing {dataset_name} dataset, generating {info} tokens...")

        # tokenize all records
        dataset = datasets.load_dataset("json", data_files=file_path, split="train", cache_dir=cache_dir)
        dataset = dataset.map(_process, batched=True, num_proc=32, remove_columns=dataset.column_names, batch_size=100, load_from_cache_file=load_from_cache_file)

        tqdm_bar = tqdm(total=sum(num_instances_this_dataset.values()))

        max_length_candidates = [k for k, v in num_instances_this_dataset.items() if v > 0]
        max_length = random.choice(max_length_candidates)

        input_ids = []
        for x in dataset:
            sample_input_ids = x["input_ids"]

            if add_bos:
                assert tokenizer.bos_token_id is not None, "Make sure the bos_token_id exists when enabling add_bos."
                sample_input_ids = [tokenizer.bos_token_id] + sample_input_ids
            if add_eos:
                assert tokenizer.eos_token_id is not None, "Make sure the eos_token_id exists when enabling add_eos."
                sample_input_ids = sample_input_ids + [tokenizer.eos_token_id]

            # add input_ids
            input_ids.extend(sample_input_ids)

            if len(input_ids) >= max_length:
                cursor = 0
                while cursor + max_length <= len(input_ids):
                    instance_input_ids = input_ids[cursor: cursor + max_length].copy()
                    instance_attention_mask = [1 for _ in instance_input_ids]
                    instance_labels = instance_input_ids.copy()

                    # move the cursor
                    cursor += max_length

                    # add to final data
                    outputs["input_ids"].append(instance_input_ids)
                    outputs["attention_mask"].append(instance_attention_mask)
                    outputs["labels"].append(instance_labels)
                    outputs["length"].append(max_length)

                    # update num_instances
                    num_instances_this_dataset[max_length] -= 1
                    tqdm_bar.update(1)

                    # sample new max_length
                    max_length_candidates = [k for k, v in num_instances_this_dataset.items() if v > 0]
                    if len(max_length_candidates) == 0:
                        # all needed data have been collected
                        break
                    elif len(max_length_candidates) == 1:
                        max_length = max_length_candidates[0]
                    else:
                        max_length = random.choice(max_length_candidates)

                # remove input_ids that have been saved in outputs
                input_ids = input_ids[cursor:]

            # all needed data have been collected
            if len(max_length_candidates) == 0:
                break

        tqdm_bar.close()

        if len(max_length_candidates) > 0:
            logger.warning(f"There is not enough data! {num_instances_this_dataset} instances remain to be collected for the {dataset_name} dataset. Consider increasing the corresponding data in {data_files}.")

    dataset = datasets.Dataset.from_dict(outputs)
    return dataset
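
# Worked example of the per-source budget arithmetic above (illustrative numbers,
# assuming a hypothetical source with a 30% mixture share and the default "8192:2B" spec):
#   num_token_this_dataset     = {8192: ceil(2e9 * 0.30)}          = {8192: 600_000_000}
#   num_instances_this_dataset = {8192: ceil(600_000_000 / 8192)}  = {8192: 73_243}
# i.e. that source contributes 73,243 packed instances of 8,192 tokens each.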

if __name__ == "__main__":
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu, kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(days=10))])

    # this script may be executed in DDP, so we make sure the dataset is created only on the main process
    if accelerator.process_index == 0:
        tokenizer = get_model_and_tokenizer(args, return_tokenizer_only=True)

        if args.add_eos:
            assert tokenizer.eos_token_id is not None, "Make sure the eos_token_id is not None when enabling add_eos!"

        with open(args.config, encoding="utf-8") as f:
            config = json.load(f)

        # parse specs like "8192:2B" into a {sequence length: token budget} mapping
        length_2_num_token = {}
        for x in args.num_token:
            length, ntok = x.split(":")
            length = int(length)
            if ntok.lower().endswith("t"):
                ntok = float(ntok[:-1]) * 1e12
            elif ntok.lower().endswith("b"):
                ntok = float(ntok[:-1]) * 1e9
            elif ntok.lower().endswith("m"):
                ntok = float(ntok[:-1]) * 1e6
            else:
                raise ValueError("Make sure num_token ends with T/t/B/b/M/m!")
            length_2_num_token[length] = ntok
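
        # For example, the default "8192:2B" produces {8192: 2e9}: a budget of
        # 2 billion tokens packed into 8192-token instances, i.e. roughly 244k
        # instances in total before the per-source mixture split.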
        pretrain_dataset = prepare_pretrain_data(
            args.train_data,
            tokenizer=tokenizer,
            config=config,
            length_2_num_token=length_2_num_token,
            add_bos=args.add_bos,
            add_eos=args.add_eos,
            seed=args.seed,
            cache_dir=args.dataset_cache_dir,
        )

        logger.info(f"Saving dataset to {args.output_dir}...")
        pretrain_dataset.save_to_disk(args.output_dir)

    # barrier for all ranks, so non-main processes do not proceed before the dataset exists
    accelerator.wait_for_everyone()
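
# Usage note (illustrative, not from the original script): the saved dataset can be
# reloaded for training with the standard `datasets` API, e.g.
#
#   import datasets
#   train_dataset = datasets.load_from_disk("data/pretrain/llama-8K_2B")
#   print(train_dataset[0]["length"])  # one of the lengths given via --num_token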