embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/main/eval_lm.py
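"""Evaluate language-modeling perplexity on a long-document corpus (e.g. PG19).

The evaluation data is a json file with one `text` field per document. Each
document is tokenized and split into sliding-window chunks (see `process_lm`);
perplexity, time per chunk, and peak GPU memory are then logged under
`output_dir`.

Example invocation (the exact entry point and any extra ModelArgs flags depend
on the repo layout and are assumptions here):

    python eval_lm.py --eval_data long-llm:lm/pg19.json --stride 2048 --max_sample_num 100
"""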

import os
import datasets
import time
import torch
from typing import Optional
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from accelerate import Accelerator
from transformers import HfArgumentParser
from torch.utils.data import DataLoader
from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, split_file_dir_name_ext, evaluate_perplexity

@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="long-llm:lm/pg19.json",
        metadata={'help': 'The evaluation json data path.'}
    )
    output_dir: str = field(
        default="data/results/lm/",
        metadata={'help': 'Output directory for results and logs.'}
    )
    retokenize: bool = field(
        default=False,
        metadata={'help': 'Retokenize the corpus?'}
    )
    padding_side: str = field(
        default="right",
        metadata={'help': 'Which side to pad?'}
    )
    stride: int = field(
        default=2048,
        metadata={'help': 'Streaming stride when evaluating perplexity.'}
    )
    max_sample_num: int = field(
        default=100,
        metadata={'help': 'How many samples to evaluate in eval_data?'}
    )
    min_length: Optional[int] = field(
        default=None,
        metadata={'help': 'Minimum length for input_ids.'}
    )


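# Pre-tokenize raw texts into input_ids, optionally truncating each text to
# tokenize_max_char characters first. Not used by main() below; presumably a
# helper for caching tokenized corpora elsewhere in the repo.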
def process_lm_pre(tokenizer, tokenize_max_char=None):
    def _process(data):
        outputs = {'input_ids': []}
        for text in data['text']:
            if tokenize_max_char is not None:
                text = text[:tokenize_max_char]
            outputs['input_ids'].append(tokenizer.encode(text, add_special_tokens=False))
        return outputs
    return _process


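# Build sliding-window evaluation chunks: each window covers up to max_length
# tokens and advances by `stride`; the overlapping prefix is masked with -100
# so only the `sub_trg_len` new tokens of a window contribute to the loss.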
def process_lm(tokenizer, max_length=4096, stride=1024, min_length=None):
    # stride=0 indicates we just use one forward pass with max_length for each text
    if stride == 0:
        stride = max_length
        jump = True
    else:
        jump = False

    test = tokenizer.encode("test")
    has_bos = False
    if test[0] == tokenizer.bos_token_id:
        # NOTE: subtract 1 because it will be occupied by the bos token
        max_length -= 1
        has_bos = True

    def _process(data, indices, **kwds):
        outputs = defaultdict(list)
        for text, index in zip(data["text"], indices):
            input_ids = tokenizer.encode(text, add_special_tokens=False)
            seq_len = len(input_ids)
            prev_end_loc = 0

            if min_length is not None and seq_len < min_length:
                continue

            for start_loc in range(0, seq_len, stride):
                end_loc = min(start_loc + max_length, seq_len)
                sub_seq_len = end_loc - start_loc
                sub_trg_len = end_loc - prev_end_loc  # may be different from stride on last loop

                sub_input_ids = input_ids[start_loc: end_loc]
                sub_attention_mask = [1 for _ in range(sub_seq_len)]

                if has_bos:
                    sub_input_ids.insert(0, tokenizer.bos_token_id)
                    sub_attention_mask.insert(0, 1)
                    sub_seq_len += 1

                sub_labels = sub_input_ids.copy()
                sub_labels[:-sub_trg_len] = [-100 for _ in range(sub_seq_len - sub_trg_len)]

                sub_inputs = {
                    "index": index,
                    "input_ids": sub_input_ids,
                    "attention_mask": sub_attention_mask,
                    "labels": sub_labels,
                }
                for k, v in sub_inputs.items():
                    outputs[k].append(v)

                prev_end_loc = end_loc

                # NOTE: when end_loc is just the same as seq_len, jump out
                if end_loc == seq_len or jump:
                    break
        return outputs
    return _process


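# Load the model and tokenizer, chunk the evaluation corpus, run
# evaluate_perplexity under torch.no_grad(), and log perplexity, time per
# chunk, and peak GPU memory (MiB) from the main process.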
@torch.no_grad()
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    # increase timeout to avoid error
    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    _, dataset_name, _ = split_file_dir_name_ext(args.eval_data)

    process_fn = process_lm(tokenizer, max_length=args.max_length, stride=args.stride, min_length=args.min_length)

    dataset = datasets.load_dataset("json", data_files=args.eval_data, cache_dir=args.dataset_cache_dir, split="train")
    if len(dataset) > args.max_sample_num:
        # keep only max_sample_num samples (with shuffle=False, train_test_split takes the test split from the end of the dataset)
        dataset = dataset.train_test_split(args.max_sample_num, shuffle=False)["test"]
    dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, keep_in_memory=True, with_indices=True)

    data_collator = DefaultDataCollator(tokenizer=tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        # only pin memory when running on gpu
        pin_memory=not args.cpu,
    )

    accelerator.wait_for_everyone()
    # NOTE: prepare the dataloader so batches are moved to the GPU automatically
    dataloader = accelerator.prepare(dataloader)

    t1 = time.time()
    perplexity = evaluate_perplexity(model, dataloader, accelerator)
    t2 = time.time()

    memory = torch.cuda.max_memory_allocated() / 1024**2
    metrics = {"perplexity": perplexity, "time": round((t2 - t1) / len(dataset), 4), "memory": memory}

    if accelerator.process_index == 0:
        log_path = os.path.join(args.output_dir, f"{dataset_name}.log")
        file_logger = FileLogger(makedirs(log_path))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()