# embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/main/eval_infbench.py


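"""Evaluate a long-context model on InfBench tasks.

For every task in ``--tasks``, this script loads the corresponding jsonl file, builds
prompts with ``create_prompt``, truncates over-long inputs (from the middle by default),
generates answers, scores them with ``get_score_one``, and writes per-sample predictions
plus an aggregate metric under ``output_dir``/``result_dir``.
"""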
import os
import datasets
import json
import torch
import pandas as pd
from tqdm import tqdm
from functools import partial
from typing import Optional, Dict, List
from dataclasses import dataclass, field, asdict
from accelerate import Accelerator
from transformers import HfArgumentParser, AutoTokenizer
from transformers.utils import logging
from torch.utils.data import DataLoader
from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, apply_chat_template
from .infbench_utils import TASK_TO_PATH, TASK_TO_MAX_NEW_TOKENS, get_score_one, create_prompt, get_answer
logger = logging.get_logger(__name__)


@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="long-llm:infbench",
        metadata={'help': 'The directory of all infbench evaluation data.'}
    )
    output_dir: str = field(
        default="data/results/infbench/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    result_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'The directory relative to output_dir for saving results.'}
    )
    tasks: List[str] = field(
        default_factory=lambda: ['longbook_qa_eng', 'longbook_sum_eng'],
        metadata={'help': 'Which tasks to evaluate?'}
    )
    prompt_template: str = field(
        default="mistral",
        metadata={'help': 'Which prompt template to use? (See infbench_utils.py for reference.)'}
    )
    max_length: int = field(
        default=128000,
        metadata={'help': 'Max input length.'}
    )
    truncate_from_middle: bool = field(
        default=True,
        metadata={'help': 'Truncate inputs from the middle.'}
    )
    load_result: bool = field(
        default=False,
        metadata={'help': 'Load results from saved files?'}
    )

    do_sample: bool = False


def process_infbench(data, indices, tokenizer, chat_template, task: str, prompt_template: str = "mistral", max_length=100000, truncate_from_middle=True):
    outputs = {'input_ids': [], 'attention_mask': [], "index": [], "answer": []}

    # NOTE: recent versions of `datasets` wrap batched inputs in LazyBatch, which cannot be
    # converted directly to a list of dicts, so convert to a plain dict first
    data = pd.DataFrame(dict(data)).to_dict(orient="records")

    for sample, index in zip(data, indices):
        prompt = create_prompt(sample, task, prompt_template)
        answer = get_answer(sample, task)

        if truncate_from_middle:
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokenized_prompt) > max_length:
                # keep the head and tail of the prompt, dropping the middle
                half = int(max_length / 2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        else:
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            prompt = tokenizer.decode(tokenized_prompt[-max_length:], skip_special_tokens=True)

        encoded = apply_chat_template(
            chat_template,
            messages=[{'role': 'user', 'content': prompt}],
            tokenizer=tokenizer,
            add_generation_prompt=True,
        ).encoded

        outputs["input_ids"].append(encoded["input_ids"])
        outputs["attention_mask"].append(encoded["attention_mask"])
        outputs["index"].append(index)
        outputs["answer"].append(answer)

    return outputs


@torch.no_grad()
def main():
    parser = HfArgumentParser([Args])
    args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    if args.tasks == ["all"]:
        tasks = list(TASK_TO_PATH.keys())
    else:
        tasks = args.tasks

    with accelerator.main_process_first():
        all_datasets = {}
        for task in tasks:
            process_fn = partial(
                process_infbench,
                tokenizer=tokenizer,
                chat_template=args.chat_template,
                max_length=args.max_length,
                task=task,
                prompt_template=args.prompt_template,
                truncate_from_middle=args.truncate_from_middle,
            )

            path = os.path.join(args.eval_data, TASK_TO_PATH[task])
            raw_dataset = datasets.load_dataset("json", data_files=path, cache_dir=args.dataset_cache_dir, split="train")
            dataset = raw_dataset.map(process_fn, batched=True, num_proc=32, batch_size=10, with_indices=True, remove_columns=raw_dataset.column_names)
            all_datasets[task] = dataset

    result_dir = os.path.join(args.output_dir, args.result_dir)

    metrics = {}

    for i, (task, dataset) in enumerate(all_datasets.items()):
        if accelerator.process_index == 0:
            logger.info(f"Evaluating {task} ({i + 1} / {len(all_datasets)})...")

        result_path = os.path.join(result_dir, f"{task}.json")

        # get answers in advance
        labels = dataset["answer"]
        dataset = dataset.remove_columns(["answer"])

        if not (args.load_result and os.path.exists(result_path)):
            data_collator = DefaultDataCollator(tokenizer=tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=args.batch_size,
                collate_fn=data_collator,
                # only pin memory when running on GPU (i.e. not in CPU mode)
                pin_memory=not args.cpu,
            )
            # NOTE: prepare the dataloader so the data is moved to GPU automatically
            dataloader = accelerator.prepare(dataloader)

            indices = []
            preds = []
            max_new_tokens = TASK_TO_MAX_NEW_TOKENS[task]

            for j, x in enumerate(tqdm(dataloader, desc="Generating")):
                index = x.pop("index").tolist()
                input_length = x["input_ids"].shape[1]

                # NOTE: important to reset memory for every batch
                if hasattr(model, "memory"):
                    model.memory.reset()

                output = model.generate(
                    **x,
                    max_new_tokens=max_new_tokens,
                )
                if isinstance(output, torch.Tensor):
                    # (batch_size, new_tokens): strip the prompt and keep only the generated tokens
                    output = output[:, input_length:]
                    output = tokenizer.batch_decode(output, skip_special_tokens=True)
                elif isinstance(output, list):
                    # already a list of decoded strings; nothing to do
                    pass

                if accelerator.num_processes > 1:
                    output = accelerator.gather_for_metrics(output)
                    index = accelerator.gather_for_metrics(index)

                if accelerator.process_index == 0:
                    preds.extend(output)
                    indices.extend(index)

        else:
            if accelerator.process_index == 0:
                preds = []
                indices = []
                with open(result_path, "r", encoding="utf-8") as f:
                    # the first line stores the aggregate metric; skip it
                    f.readline()
                    for line in f:
                        item = json.loads(line)
                        preds.append(item["pred"])
                        indices.append(len(indices))

        if accelerator.process_index == 0:
            scores = []
            for label, pred in tqdm(zip(labels, preds)):
                # NOTE: here we explicitly pass model_name=None
                score = get_score_one(pred, label, task, None)
                scores.append(score)
            score = round(sum(scores) / len(scores), 4)

            logger.info(f"{task}: {score}")
            metrics[task] = score

            with open(makedirs(result_path), "w", encoding="utf-8") as f:
                # first line: the aggregate metric; subsequent lines: one json record per sample
                f.write(json.dumps(score, ensure_ascii=False) + "\n")
                for index, pred, label in zip(indices, preds, labels):
                    item = {
                        "index": index,
                        "pred": pred,
                        "label": label,
                    }
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")

    if accelerator.process_index == 0:
        # save config
        args.save(os.path.join(result_dir, "config.json"))

        avg = round(sum(metrics.values()) / len(metrics), 4)
        metrics["avg"] = avg

        file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()
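
# Example invocation (a sketch, not the repo's documented command; it assumes this file lives
# in the `main` package of the activation_beacon project so the relative imports resolve, and
# that `long-llm:infbench` points at the downloaded InfBench data -- adjust to your setup):
#
#   torchrun --nproc_per_node 8 -m main.eval_infbench \
#       --eval_data long-llm:infbench \
#       --tasks longbook_qa_eng longbook_sum_eng \
#       --max_length 128000
#
# HfArgumentParser exposes every field of Args (and the inherited ModelArgs) as a
# command-line flag with the same name.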