# FlagEmbedding/research/Long_LLM/activation_beacon/main/eval_longbench.py

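"""Evaluate a long-context model on LongBench.

For each task, this script builds prompts from DATASET2PROMPT, truncates
over-long inputs (from the middle by default, so both the instruction and the
question survive), generates answers with an optional newline EOS for QA
tasks, scores the predictions with `scorer`, and logs per-task, per-category,
and average metrics via FileLogger.
"""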

import os
import datasets
import json
import torch
from tqdm import tqdm
from typing import Optional, Dict, List
from functools import partial
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from accelerate import Accelerator
from transformers import HfArgumentParser
from transformers.utils import logging
from torch.utils.data import DataLoader
from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, apply_chat_template
from .longbench_utils import DATASET2PROMPT, DATASET2MAXNEWTOKENS, DATASET2CATEGORY, scorer

logger = logging.get_logger(__name__)


@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="long-llm:longbench/",
        metadata={'help': 'The evaluation json data path.'}
    )
    output_dir: str = field(
        default="data/results/longbench/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    result_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'The directory, relative to output_dir, for saving results.'}
    )

    tasks: List[str] = field(
        default_factory=lambda: ['narrativeqa', 'qasper', 'multifieldqa_en', 'hotpotqa', '2wikimqa', 'musique', 'gov_report', 'qmsum', 'multi_news', 'trec', 'triviaqa', 'samsum', 'lcc', 'repobench-p'],
        metadata={'help': 'Which datasets to evaluate? Pass "all" to run every LongBench task.'}
    )

    newline_as_eos: bool = field(
        default=True,
        metadata={'help': 'Whether to use a newline as an extra eos token (applied to QA tasks only).'}
    )

    max_length: int = field(
        default=31500,
        metadata={'help': 'Max input length in tokens.'}
    )
    truncate_from_middle: bool = field(
        default=True,
        metadata={'help': 'Truncate over-long inputs from the middle (otherwise keep only the tail).'}
    )
    load_result: bool = field(
        default=False,
        metadata={'help': 'Load results from previously saved files instead of re-generating?'}
    )

    do_sample: bool = False


def process_longbench(data, indices, tokenizer, chat_template, task, max_length=3500, truncate_from_middle=True):
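    """Tokenize one batch of LongBench samples into model inputs.

    With `truncate_from_middle`, a prompt longer than `max_length` keeps its
    first and last `max_length // 2` tokens and drops the middle, so both the
    task instruction (at the start) and the question (at the end) survive.
    E.g. with max_length=6, token ids [1, 2, 3, 4, 5, 6, 7, 8] become
    [1, 2, 3] + [6, 7, 8]. Otherwise only the last `max_length` tokens are kept.
    """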
    outputs = {'input_ids': [], 'attention_mask': [], "index": []}

    for input, context, index in zip(data['input'], data['context'], indices):
        prompt_template = DATASET2PROMPT[task]
        prompt = prompt_template.format(input=input, context=context)

        if truncate_from_middle:
            tokenized_prompt = tokenizer.encode(prompt)
            if len(tokenized_prompt) > max_length:
                half = int(max_length / 2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        else:
            tokenized_prompt = tokenizer.encode(prompt)
            prompt = tokenizer.decode(tokenized_prompt[-max_length:], skip_special_tokens=True)

        # few-shot learning and code completion tasks do not need the chat template
        if not any(x in DATASET2CATEGORY[task] for x in ["Few-Shot Learning", "Code Completion"]):
            encoded = apply_chat_template(
                chat_template,
                messages=[{'role': 'user', 'content': prompt}],
                tokenizer=tokenizer,
                add_generation_prompt=True,
            ).encoded
        else:
            encoded = tokenizer(prompt)

        outputs["input_ids"].append(encoded["input_ids"])
        outputs["attention_mask"].append(encoded["attention_mask"])
        outputs["index"].append(index)

    return outputs


@torch.no_grad()
def main():
    parser = HfArgumentParser([Args])
    args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    if hasattr(model, "generation_config"):
        eos_token_id = model.generation_config.eos_token_id
    else:
        eos_token_id = tokenizer.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    else:
        # copy, so that appending below does not mutate the model's
        # generation_config in place
        eos_token_id = list(eos_token_id)
    # stop generation for QA tasks when \n appears
    if args.newline_as_eos:
        eos_token_id.append(tokenizer.encode("\n", add_special_tokens=False)[-1])
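
    # Why take the last id: with add_special_tokens=False, SentencePiece-style
    # tokenizers may still split "\n" into a "▁" prefix piece plus the newline
    # piece, so [-1] picks the id that actually ends a generated line. Rough
    # illustration (ids are tokenizer-specific; a Llama-2-style tokenizer is
    # assumed here):
    #   >>> tokenizer.encode("\n", add_special_tokens=False)
    #   [29871, 13]    # ["▁", "\n"] -> [-1] == 13, the real newline token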
    if args.tasks == ["all"]:
        tasks = list(DATASET2PROMPT.keys())
    else:
        tasks = args.tasks

    with accelerator.main_process_first():
        all_datasets = {}
        for task in tasks:
            process_fn = partial(
                process_longbench,
                tokenizer=tokenizer,
                chat_template=args.chat_template,
                task=task,
                max_length=args.max_length,
                truncate_from_middle=args.truncate_from_middle,
            )
            path = os.path.join(args.eval_data, f"{task}.jsonl")
            raw_dataset = datasets.load_dataset("json", data_files=path, cache_dir=args.dataset_cache_dir, split="train")
            dataset = raw_dataset.map(process_fn, batched=True, num_proc=32, batch_size=10, with_indices=True, remove_columns=raw_dataset.column_names)
            all_datasets[task] = (raw_dataset, dataset)

    result_dir = os.path.join(args.output_dir, args.result_dir)

    metrics = {}

    for i, task in enumerate(all_datasets.keys()):
        if accelerator.process_index == 0:
            logger.info(f"Evaluating {task} ({i + 1} / {len(all_datasets)})...")

        result_path = os.path.join(result_dir, f"{task}.json")
        raw_dataset, dataset = all_datasets[task]

        if not (args.load_result and os.path.exists(result_path)):
            data_collator = DefaultDataCollator(tokenizer=tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=args.batch_size,
                collate_fn=data_collator,
                # only pin memory when a gpu is in use
                pin_memory=not args.cpu,
            )
            dataloader = accelerator.prepare(dataloader)

            indices = []
            preds = []
            max_new_tokens = DATASET2MAXNEWTOKENS[task]

            for x in tqdm(dataloader, desc="Generating"):
                index = x.pop("index").tolist()
                input_length = x["input_ids"].shape[1]

                # NOTE: important to reset memory for every batch
                if hasattr(model, "memory"):
                    model.memory.reset()
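                    # (model.memory is presumably the Activation Beacon memory
                    # module: it accumulates condensed activations across forward
                    # passes, so stale state from the previous batch must be
                    # cleared before each new generation.)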
                kwargs = {"max_new_tokens": max_new_tokens}
                # NOTE: very important to include \n as an eos token for QA tasks,
                # otherwise the F1 score degrades drastically
                if task in ["2wikimqa", "hotpotqa", "musique", "multifieldqa_en", "qasper", "narrativeqa", "samsum"]:
                    kwargs["eos_token_id"] = eos_token_id

                output = model.generate(
                    **x,
                    **kwargs
                )

                if isinstance(output, torch.Tensor):
                    # output: (batch_size, input_length + new_tokens); strip the prompt part
                    output = output[:, input_length:]
                    output = tokenizer.batch_decode(output, skip_special_tokens=True)
                elif isinstance(output, list):
                    # some models already return decoded strings
                    pass

                if accelerator.num_processes > 1:
                    output = accelerator.gather_for_metrics(output)
                    index = accelerator.gather_for_metrics(index)

                if accelerator.process_index == 0:
                    preds.extend(output)
                    indices.extend(index)
        else:
            if accelerator.process_index == 0:
                preds = []
                indices = []
                with open(result_path, "r", encoding="utf-8") as f:
                    # the first line is the metric score
                    f.readline()
                    for line in f:
                        item = json.loads(line)
                        preds.append(item["pred"])
                        indices.append(len(indices))
        if accelerator.process_index == 0:
            answers = raw_dataset["answers"]
            lengths = raw_dataset["length"]
            all_classes = raw_dataset["all_classes"][0]

            score = scorer(task, preds, answers, all_classes)
            logger.info(f"{task}: {score}")
            metrics[task] = score
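
            # File layout: the first line holds the aggregate score; every
            # following line is one JSON record per sample (exactly what the
            # load_result branch above reads back).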
            with open(makedirs(result_path), "w", encoding="utf-8") as f:
                f.write(json.dumps(score, ensure_ascii=False) + "\n")
                for index, pred in zip(indices, preds):
                    sample = raw_dataset[index]
                    del sample["all_classes"]
                    del sample["context"]
                    del sample["language"]
                    del sample["_id"]
                    sample["pred"] = pred
                    f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    if accelerator.process_index == 0:
        # save config
        args.save(os.path.join(result_dir, "config.json"))

        # compute category scores
        category_metrics = defaultdict(list)
        for dataset, metric in metrics.items():
            category = DATASET2CATEGORY[dataset]
            category_metrics[category].append(metric)
        for k, v in category_metrics.items():
            # when evaluating on longbench_e, each metric is a dict of floats
            if isinstance(v[0], dict):
                category_metric = {}
                for kk in v[0].keys():
                    vv = [v[j][kk] for j in range(len(v))]
                    category_metric[kk] = round(sum(vv) / len(vv), 2)
                category_metrics[k] = category_metric
            else:
                category_metrics[k] = round(sum(v) / len(v), 2)

        # compute the average score over all tasks
        if isinstance(next(iter(metrics.values())), dict):
            avg = defaultdict(list)
            for k, v in metrics.items():
                for kk, vv in v.items():
                    avg[kk].append(vv)
            for k, v in avg.items():
                avg[k] = round(sum(v) / len(v), 2)
        else:
            avg = round(sum(metrics.values()) / len(metrics), 2)
        metrics["avg"] = avg

        file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
        file_logger.log(metrics, Args=asdict(args), Category_Metrics=category_metrics)


if __name__ == "__main__":
    main()
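

# Example invocation (a sketch, not a verified command: the model flags come
# from ModelArgs in `src`, and the package path depends on where you launch
# from, since this file uses a relative import):
#
#   python -m main.eval_longbench \
#       --eval_data long-llm:longbench/ \
#       --max_length 31500 \
#       --tasks narrativeqa qasper
#
# HfArgumentParser parses List fields as space-separated values; passing
# `--tasks all` evaluates every task in DATASET2PROMPT.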