import os
import datasets
import time
import torch
from typing import Optional
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from accelerate import Accelerator
from transformers import HfArgumentParser
from torch.utils.data import DataLoader

from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, split_file_dir_name_ext, evaluate_perplexity


@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="long-llm:lm/pg19.json",
        metadata={'help': 'Path to the evaluation json data.'}
    )
    output_dir: str = field(
        default="data/results/lm/",
        metadata={'help': 'Output directory for results and logs.'}
    )

    retokenize: bool = field(
        default=False,
        metadata={'help': 'Whether to re-tokenize the corpus instead of reusing the cached tokenization.'}
    )

    padding_side: str = field(
        default="right",
        metadata={'help': 'Which side to pad the inputs on.'}
    )
    stride: int = field(
        default=2048,
        metadata={'help': 'Sliding-window stride when evaluating perplexity.'}
    )

    max_sample_num: int = field(
        default=100,
        metadata={'help': 'Maximum number of samples to evaluate from eval_data.'}
    )
    min_length: Optional[int] = field(
        default=None,
        metadata={'help': 'Minimum length for input_ids; shorter texts are skipped.'}
    )


def process_lm_pre(tokenizer, tokenize_max_char=None):
    # Tokenize the raw corpus once up front so subsequent runs can reuse the
    # cached token ids instead of re-tokenizing.
    def _process(data):
        outputs = {'input_ids': []}
        for text in data['text']:
            if tokenize_max_char is not None:
                # truncate very long documents at the character level before tokenizing
                text = text[:tokenize_max_char]
            outputs['input_ids'].append(tokenizer.encode(text, add_special_tokens=False))
        return outputs
    return _process
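
# For illustration only: the closure above is meant to be passed to
# datasets.Dataset.map with batched=True (mirroring how process_lm is used in
# main below), e.g.
#   dataset = dataset.map(process_lm_pre(tokenizer), batched=True)
# so that token ids are computed once and cached by the datasets library.
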
def process_lm(tokenizer, max_length=4096, stride=1024, min_length=None):
    # stride=0 means a single forward pass of at most max_length tokens per text
    if stride == 0:
        stride = max_length
        jump = True
    else:
        jump = False

    # probe the tokenizer to see whether it automatically prepends a bos token
    test = tokenizer.encode("test")
    has_bos = False
    if test[0] == tokenizer.bos_token_id:
        # NOTE: subtract 1 because one position will be occupied by the bos token
        max_length -= 1
        has_bos = True

    def _process(data, indices, **kwds):
        outputs = defaultdict(list)

        for text, index in zip(data["text"], indices):
            input_ids = tokenizer.encode(text, add_special_tokens=False)

            seq_len = len(input_ids)
            prev_end_loc = 0

            if min_length is not None and seq_len < min_length:
                continue

            for start_loc in range(0, seq_len, stride):
                end_loc = min(start_loc + max_length, seq_len)
                sub_seq_len = end_loc - start_loc
                sub_trg_len = end_loc - prev_end_loc  # may differ from stride on the last loop

                sub_input_ids = input_ids[start_loc: end_loc]
                sub_attention_mask = [1 for _ in range(sub_seq_len)]
                if has_bos:
                    sub_input_ids.insert(0, tokenizer.bos_token_id)
                    sub_attention_mask.insert(0, 1)
                    sub_seq_len += 1

                # mask out tokens already scored in previous windows so each token
                # contributes to the loss exactly once
                sub_labels = sub_input_ids.copy()
                sub_labels[:-sub_trg_len] = [-100 for _ in range(sub_seq_len - sub_trg_len)]

                sub_inputs = {
                    "index": index,
                    "input_ids": sub_input_ids,
                    "attention_mask": sub_attention_mask,
                    "labels": sub_labels,
                }

                for k, v in sub_inputs.items():
                    outputs[k].append(v)

                prev_end_loc = end_loc
                # NOTE: stop once end_loc reaches seq_len (or after one pass when stride=0)
                if end_loc == seq_len or jump:
                    break

        return outputs

    return _process
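
# Worked example of the sliding window above (hypothetical numbers): with
# max_length=8, stride=4, no bos token, and a 14-token text, _process emits
# windows [0:8], [4:12], [8:14]. In window [4:12], sub_trg_len = end_loc -
# prev_end_loc = 4, so the 4 overlapping prefix tokens are masked with -100
# and only the 4 new tokens are scored; every token is counted exactly once.
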
@torch.no_grad()
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    _, dataset_name, _ = split_file_dir_name_ext(args.eval_data)

    process_fn = process_lm(tokenizer, max_length=args.max_length, stride=args.stride, min_length=args.min_length)

    dataset = datasets.load_dataset("json", data_files=args.eval_data, cache_dir=args.dataset_cache_dir, split="train")
    if len(dataset) > args.max_sample_num:
        # keep only the first max_sample_num samples
        dataset = dataset.select(range(args.max_sample_num))
    dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, keep_in_memory=True, with_indices=True)

    data_collator = DefaultDataCollator(tokenizer=tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        # pin host memory only when a GPU is in use
        pin_memory=not args.cpu,
    )
    accelerator.wait_for_everyone()

    # NOTE: prepare the dataloader so batches are moved to the right device automatically
    dataloader = accelerator.prepare(dataloader)

    t1 = time.time()
    perplexity = evaluate_perplexity(model, dataloader, accelerator)
    t2 = time.time()
    # peak GPU memory in MiB (0 when no CUDA device is available)
    memory = torch.cuda.max_memory_allocated() / 1024 ** 2 if torch.cuda.is_available() else 0
    # time is the average seconds per dataset row (one row per sliding window)
    metrics = {"perplexity": perplexity, "time": round((t2 - t1) / len(dataset), 4), "memory": memory}

    if accelerator.process_index == 0:
        log_path = os.path.join(args.output_dir, f"{dataset_name}.log")

        file_logger = FileLogger(makedirs(log_path))
        file_logger.log(metrics, Args=asdict(args))


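# Example launch, assuming this file is saved as eval_lm.py (the flag names
# mirror the Args fields above):
#   accelerate launch eval_lm.py --eval_data long-llm:lm/pg19.json --stride 2048 --max_sample_num 100
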
if __name__ == "__main__":
    main()