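"""LongBench evaluation script.

Loads each LongBench task as a JSON-lines dataset, builds (and, for long
inputs, middle-truncates) the prompts, generates answers with the model
returned by ``get_model_and_tokenizer``, scores them with
``longbench_utils.scorer``, and writes per-task predictions plus aggregate
metrics under ``output_dir``.

Presumably launched with HuggingFace Accelerate (``accelerate launch``) for
multi-GPU evaluation; a plain single-process ``python`` run works as well.
"""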
import os
import datasets
import json
import torch
from tqdm import tqdm
from typing import Optional, Dict, List
from functools import partial
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from accelerate import Accelerator
from transformers import HfArgumentParser
from transformers.utils import logging
from torch.utils.data import DataLoader

from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, apply_chat_template
from .longbench_utils import DATASET2PROMPT, DATASET2MAXNEWTOKENS, DATASET2CATEGORY, scorer

logger = logging.get_logger(__name__)


@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="long-llm:longbench/",
        metadata={'help': 'The evaluation json data path.'}
    )
    output_dir: str = field(
        default="data/results/longbench/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    result_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'The directory relative to output_dir for saving results.'}
    )

    tasks: List[str] = field(
        default_factory=lambda: ['narrativeqa', 'qasper', 'multifieldqa_en', 'hotpotqa', '2wikimqa', 'musique', 'gov_report', 'qmsum', 'multi_news', 'trec', 'triviaqa', 'samsum', 'lcc', 'repobench-p'],
        metadata={'help': 'Which datasets to evaluate?'}
    )
    newline_as_eos: bool = field(
        default=True,
        metadata={'help': 'Whether to use newline as an eos token (for QA tasks only).'}
    )

    max_length: int = field(
        default=31500,
        metadata={'help': 'Max input length.'}
    )
    truncate_from_middle: bool = field(
        default=True,
        metadata={'help': 'Truncate inputs from the middle.'}
    )
    load_result: bool = field(
        default=False,
        metadata={'help': 'Load results from saved files?'}
    )

    do_sample: bool = False


def process_longbench(data, indices, tokenizer, chat_template, task, max_length=3500, truncate_from_middle=True):
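    """Build model inputs for one batch of LongBench samples.

    The prompt is formatted from DATASET2PROMPT, then truncated when it exceeds
    ``max_length`` tokens: by default the first and last ``max_length // 2``
    tokens are kept (middle truncation), otherwise only the last ``max_length``
    tokens are kept. The chat template is applied for all tasks except
    few-shot learning and code completion.
    """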
    outputs = {'input_ids': [], 'attention_mask': [], "index": []}

    for input, context, index in zip(data['input'], data['context'], indices):
        prompt_template = DATASET2PROMPT[task]
        prompt = prompt_template.format(input=input, context=context)

        if truncate_from_middle:
            tokenized_prompt = tokenizer.encode(prompt)
            if len(tokenized_prompt) > max_length:
                half = int(max_length / 2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        else:
            tokenized_prompt = tokenizer.encode(prompt)
            prompt = tokenizer.decode(tokenized_prompt[-max_length:], skip_special_tokens=True)

        # in few-shot learning and code completion we do not need the chat template
        if not any(x in DATASET2CATEGORY[task] for x in ["Few-Shot Learning", "Code Completion"]):
            encoded = apply_chat_template(
                chat_template,
                messages=[{'role': 'user', 'content': prompt}],
                tokenizer=tokenizer,
                add_generation_prompt=True,
            ).encoded
        else:
            encoded = tokenizer(prompt)

        outputs["input_ids"].append(encoded["input_ids"])
        outputs["attention_mask"].append(encoded["attention_mask"])
        outputs["index"].append(index)

    return outputs


@torch.no_grad()
def main():
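    """Evaluate the model on the selected LongBench tasks and save results."""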
    parser = HfArgumentParser([Args])
    args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

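    # resolve eos token id(s): prefer the model's generation config, fall back to the tokenizer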
    if hasattr(model, "generation_config"):
        eos_token_id = model.generation_config.eos_token_id
    else:
        eos_token_id = tokenizer.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    # stop generation for QA tasks when \n appears
    if args.newline_as_eos:
        eos_token_id.append(tokenizer.encode("\n", add_special_tokens=False)[-1])

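    # --tasks all expands to every task defined in DATASET2PROMPT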
    if args.tasks == ["all"]:
        tasks = list(DATASET2PROMPT.keys())
    else:
        tasks = args.tasks

    with accelerator.main_process_first():
        all_datasets = {}

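        # tokenize every task once up front; main_process_first lets the other ranks reuse the cached map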
        for task in tasks:
            process_fn = partial(
                process_longbench,
                tokenizer=tokenizer,
                chat_template=args.chat_template,
                task=task,
                max_length=args.max_length,
                truncate_from_middle=args.truncate_from_middle,
            )

            path = os.path.join(args.eval_data, f"{task}.jsonl")
            raw_dataset = datasets.load_dataset("json", data_files=path, cache_dir=args.dataset_cache_dir, split="train")
            dataset = raw_dataset.map(process_fn, batched=True, num_proc=32, batch_size=10, with_indices=True, remove_columns=raw_dataset.column_names)

            all_datasets[task] = (raw_dataset, dataset)

    # fall back to output_dir when result_dir is not given (it defaults to None)
    result_dir = os.path.join(args.output_dir, args.result_dir) if args.result_dir is not None else args.output_dir

    metrics = {}

    for i, task in enumerate(all_datasets.keys()):
        if accelerator.process_index == 0:
            logger.info(f"Evaluating {task} ({i + 1} / {len(all_datasets)})...")

        result_path = os.path.join(result_dir, f"{task}.json")

        raw_dataset, dataset = all_datasets[task]

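        # either generate fresh predictions, or (with --load_result) reload cached ones from result_path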
        if not (args.load_result and os.path.exists(result_path)):
            data_collator = DefaultDataCollator(tokenizer=tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=args.batch_size,
                collate_fn=data_collator,
                # only pin memory when running on gpu
                pin_memory=not args.cpu,
            )

            dataloader = accelerator.prepare(dataloader)

            indices = []
            preds = []
            max_new_tokens = DATASET2MAXNEWTOKENS[task]

            for x in tqdm(dataloader, desc="Generating"):
                index = x.pop("index").tolist()
                input_length = x["input_ids"].shape[1]

                # NOTE: important to reset memory for every batch
                if hasattr(model, "memory"):
                    model.memory.reset()

                kwargs = {"max_new_tokens": max_new_tokens}
                if task in ["2wikimqa", "hotpotqa", "musique", "multifieldqa_en", "qasper", "narrativeqa", "samsum"]:
                    kwargs["eos_token_id"] = eos_token_id

                # NOTE: very important to include \n as an eos token for QA tasks, otherwise the F1 score drops drastically
                output = model.generate(
                    **x,
                    **kwargs
                )
                if isinstance(output, torch.Tensor):
                    # (batch_size, num_new_tokens)
                    output = output[:, input_length:]
                    output = tokenizer.batch_decode(output, skip_special_tokens=True)
                elif isinstance(output, list):
                    pass

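                # gather predictions and indices from all ranks so that rank 0 holds the full set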
                if accelerator.num_processes > 1:
                    output = accelerator.gather_for_metrics(output)
                    index = accelerator.gather_for_metrics(index)

                if accelerator.process_index == 0:
                    preds.extend(output)
                    indices.extend(index)
        else:
            if accelerator.process_index == 0:
                preds = []
                indices = []

                with open(result_path, "r", encoding="utf-8") as f:
                    # the first line is the metric score
                    f.readline()

                    for line in f:
                        item = json.loads(line)
                        preds.append(item["pred"])
                        indices.append(len(indices))

        if accelerator.process_index == 0:
            answers = raw_dataset["answers"]
            lengths = raw_dataset["length"]
            all_classes = raw_dataset["all_classes"][0]
            score = scorer(task, preds, answers, all_classes)

            logger.info(f"{task}: {score}")
            metrics[task] = score

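            # result file layout: first line is the task score, followed by one JSON record per sample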
            with open(makedirs(result_path), "w", encoding="utf-8") as f:
                f.write(json.dumps(score, ensure_ascii=False) + "\n")
                for index, pred in zip(indices, preds):
                    sample = raw_dataset[index]
                    del sample["all_classes"]
                    del sample["context"]
                    del sample["language"]
                    del sample["_id"]
                    sample["pred"] = pred
                    f.write(json.dumps(sample, ensure_ascii=False) + "\n")

    if accelerator.process_index == 0:
        # save config
        args.save(os.path.join(result_dir, "config.json"))

        # compute category scores
        category_metrics = defaultdict(list)
        for dataset, metric in metrics.items():
            category = DATASET2CATEGORY[dataset]
            category_metrics[category].append(metric)
        for k, v in category_metrics.items():
            # when evaluating on longbench_e, each metric is a dict of floats
            if isinstance(v[0], dict):
                category_metric = {}
                for kk in v[0].keys():
                    vv = [v[j][kk] for j in range(len(v))]
                    category_metric[kk] = round(sum(vv) / len(vv), 2)
                category_metrics[k] = category_metric
            else:
                category_metrics[k] = round(sum(v) / len(v), 2)

        # compute average score
        if isinstance(next(iter(metrics.values())), dict):
            avg = defaultdict(list)
            for k, v in metrics.items():
                for kk, vv in v.items():
                    avg[kk].append(vv)
            for k, v in avg.items():
                avg[k] = round(sum(v) / len(v), 2)
        else:
            avg = round(sum(metrics.values()) / len(metrics), 2)
        metrics["avg"] = avg

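        # record per-task scores, category averages, and the run config in output_dir/metrics.log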
        file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
        file_logger.log(metrics, Args=asdict(args), Category_Metrics=category_metrics)


if __name__ == "__main__":
    main()