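"""Evaluate a model on InfiniteBench (InfBench) long-context tasks.

For every requested task the script loads the raw JSON data, builds chat-formatted
prompts (optionally truncated from the middle), generates predictions under
`accelerate`, scores them with the InfBench metrics, and writes per-task results
plus an aggregated metrics log under `output_dir`.

Launch sketch (the module path is a placeholder: this file uses a relative import,
so it has to be run as a module; the flags are generated by `HfArgumentParser`
from the `Args` dataclass below):

    accelerate launch -m <package>.<this_module> --tasks longbook_qa_eng --result_dir my-run
"""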
import os
import datasets
import json
import torch
import pandas as pd
from tqdm import tqdm
from functools import partial
from typing import Optional, Dict, List
from dataclasses import dataclass, field, asdict
from accelerate import Accelerator
from transformers import HfArgumentParser, AutoTokenizer
from transformers.utils import logging
from torch.utils.data import DataLoader

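# `src` and `.infbench_utils` are repository-local modules, not third-party packages:
# `src` provides the shared argument class, data collator, file logger, model loading
# and chat-template helpers, while `infbench_utils` holds the InfBench task registry
# (TASK_TO_PATH, TASK_TO_MAX_NEW_TOKENS) and the prompt / answer / scoring helpers.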
from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, apply_chat_template
from .infbench_utils import TASK_TO_PATH, TASK_TO_MAX_NEW_TOKENS, get_score_one, create_prompt, get_answer


logger = logging.get_logger(__name__)

@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="long-llm:infbench",
        metadata={'help': 'The directory of all infbench evaluation data.'}
    )
    output_dir: str = field(
        default="data/results/infbench/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    result_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'The directory relative to output_dir for saving results.'}
    )

    tasks: List[str] = field(
        default_factory=lambda: ['longbook_qa_eng', 'longbook_sum_eng'],
        metadata={'help': 'Which tasks to evaluate? Pass "all" to run every task in TASK_TO_PATH.'}
    )
    prompt_template: str = field(
        default="mistral",
        metadata={'help': 'Which prompt template to use? (See infbench_utils.py for reference.)'}
    )

    max_length: int = field(
        default=128000,
        metadata={'help': 'Max input length.'}
    )
    truncate_from_middle: bool = field(
        default=True,
        metadata={'help': 'Truncate inputs from the middle.'}
    )
    load_result: bool = field(
        default=False,
        metadata={'help': 'Load result from saved files?'}
    )

    do_sample: bool = False

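# `process_infbench` is used as a batched map function over each task's raw dataset
# (see `raw_dataset.map(...)` in `main`): it builds the task prompt for every sample,
# optionally truncates it from the middle so both head and tail of the context survive,
# applies the chat template, and returns tokenized inputs plus sample index and answer.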
def process_infbench(data, indices, tokenizer, chat_template, task: str, prompt_template: str = "mistral", max_length=100000, truncate_from_middle=True):
    outputs = {'input_ids': [], 'attention_mask': [], "index": [], "answer": []}

    # NOTE: recent versions of `datasets` wrap batches in LazyBatch, which cannot be
    # converted back to a list of dicts directly, so convert to a plain dict first
    data = pd.DataFrame(dict(data)).to_dict(orient="records")

    for sample, index in zip(data, indices):
        prompt = create_prompt(sample, task, prompt_template)
        answer = get_answer(sample, task)

        if truncate_from_middle:
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokenized_prompt) > max_length:
                # keep the head and the tail of the prompt, dropping the middle
                half = int(max_length / 2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        else:
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            prompt = tokenizer.decode(tokenized_prompt[-max_length:], skip_special_tokens=True)

        encoded = apply_chat_template(
            chat_template,
            messages=[{'role': 'user', 'content': prompt}],
            tokenizer=tokenizer,
            add_generation_prompt=True,
        ).encoded

        outputs["input_ids"].append(encoded["input_ids"])
        outputs["attention_mask"].append(encoded["attention_mask"])
        outputs["index"].append(index)
        outputs["answer"].append(answer)

    return outputs

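# Evaluation flow: preprocess every requested task once (inside `main_process_first`,
# so the other ranks can reuse the cached map), generate predictions batch by batch
# with an accelerate-prepared dataloader, gather outputs across processes, score on
# rank 0, and write per-task JSON results plus an aggregated metrics log.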
@torch.no_grad()
def main():
    parser = HfArgumentParser([Args])
    args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    if args.tasks == ["all"]:
        tasks = list(TASK_TO_PATH.keys())
    else:
        tasks = args.tasks

    with accelerator.main_process_first():
        all_datasets = {}

        for task in tasks:
            process_fn = partial(
                process_infbench,
                tokenizer=tokenizer,
                chat_template=args.chat_template,
                max_length=args.max_length,
                task=task,
                prompt_template=args.prompt_template,
                truncate_from_middle=args.truncate_from_middle,
            )

            path = os.path.join(args.eval_data, TASK_TO_PATH[task])
            raw_dataset = datasets.load_dataset("json", data_files=path, cache_dir=args.dataset_cache_dir, split="train")
            dataset = raw_dataset.map(process_fn, batched=True, num_proc=32, batch_size=10, with_indices=True, remove_columns=raw_dataset.column_names)

            all_datasets[task] = dataset

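    # NOTE: result_dir has no usable default (None), so it must be passed on the
    # command line; os.path.join would raise a TypeError otherwise.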
    result_dir = os.path.join(args.output_dir, args.result_dir)

    metrics = {}

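    # Per-task evaluation loop; each task writes its own <task>.json under result_dir.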
    for i, (task, dataset) in enumerate(all_datasets.items()):
        if accelerator.process_index == 0:
            logger.info(f"Evaluating {task} ({i + 1} / {len(all_datasets)})...")

        result_path = os.path.join(result_dir, f"{task}.json")

        # get answers in advance
        labels = dataset["answer"]
        dataset = dataset.remove_columns(["answer"])

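        # Generate fresh predictions unless --load_result is set and a previous
        # result file already exists for this task.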
        if not (args.load_result and os.path.exists(result_path)):
            data_collator = DefaultDataCollator(tokenizer=tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=args.batch_size,
                collate_fn=data_collator,
                # only pin memory when running on GPU (i.e. not args.cpu)
                pin_memory=not args.cpu,
            )

            # NOTE: prepare the dataloader so the data moves to the GPU automatically
            dataloader = accelerator.prepare(dataloader)

            indices = []
            preds = []
            max_new_tokens = TASK_TO_MAX_NEW_TOKENS[task]

            for j, x in enumerate(tqdm(dataloader, desc="Generating")):
                index = x.pop("index").tolist()
                input_length = x["input_ids"].shape[1]

                # NOTE: important to reset memory for every batch
                if hasattr(model, "memory"):
                    model.memory.reset()

                output = model.generate(
                    **x,
                    max_new_tokens=max_new_tokens,
                )

                if isinstance(output, torch.Tensor):
                    # output has shape (batch_size, input_length + new_tokens); strip the prompt tokens
                    output = output[:, input_length:]
                    output = tokenizer.batch_decode(output, skip_special_tokens=True)
                elif isinstance(output, list):
                    # assume the model already returned decoded strings
                    pass

                if accelerator.num_processes > 1:
                    output = accelerator.gather_for_metrics(output)
                    index = accelerator.gather_for_metrics(index)

                if accelerator.process_index == 0:
                    preds.extend(output)
                    indices.extend(index)
        else:
            if accelerator.process_index == 0:
                preds = []
                indices = []

                with open(result_path, "r", encoding="utf-8") as f:
                    # the first line stores the aggregate metric
                    f.readline()

                    for line in f:
                        item = json.loads(line)
                        preds.append(item["pred"])
                        indices.append(len(indices))

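        # Score on the main process only: average the per-sample scores, then write the
        # result file (first line = aggregate score, one JSON object per sample after it).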
        if accelerator.process_index == 0:
            scores = []
            for label, pred in tqdm(zip(labels, preds)):
                # NOTE: model_name is explicitly passed as None here
                score = get_score_one(pred, label, task, None)
                scores.append(score)
            score = round(sum(scores) / len(scores), 4)

            logger.info(f"{task}: {score}")
            metrics[task] = score

            with open(makedirs(result_path), "w", encoding="utf-8") as f:
                f.write(json.dumps(score, ensure_ascii=False) + "\n")
                for index, pred, label in zip(indices, preds, labels):
                    item = {
                        "index": index,
                        "pred": pred,
                        "label": label,
                    }
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")

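    # After all tasks finish: save the run configuration and append the averaged
    # metrics (plus the full argument dump) to the shared metrics log.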
    if accelerator.process_index == 0:
        # save config
        args.save(os.path.join(result_dir, "config.json"))

        avg = round(sum(metrics.values()) / len(metrics), 4)
        metrics["avg"] = avg

        file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()