import os
import datasets
import json
import torch
import pandas as pd
from tqdm import tqdm
from functools import partial
from typing import Optional, List
from dataclasses import dataclass, field, asdict
from accelerate import Accelerator
from transformers import HfArgumentParser
from transformers.utils import logging
from torch.utils.data import DataLoader

from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, apply_chat_template
from .infbench_utils import TASK_TO_PATH, TASK_TO_MAX_NEW_TOKENS, get_score_one, create_prompt, get_answer

logger = logging.get_logger(__name__)


@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="long-llm:infbench",
        metadata={'help': 'The directory of all infbench evaluation data.'}
    )
    output_dir: str = field(
        default="data/results/infbench/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    result_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'The directory relative to output_dir for saving results.'}
    )
    tasks: List[str] = field(
        default_factory=lambda: ['longbook_qa_eng', 'longbook_sum_eng'],
        metadata={'help': 'Which datasets to evaluate?'}
    )
    prompt_template: str = field(
        default="mistral",
        metadata={'help': 'Which prompt template to use? (See infbench_utils.py for reference.)'}
    )
    max_length: int = field(
        default=128000,
        metadata={'help': 'Max input length.'}
    )
    truncate_from_middle: bool = field(
        default=True,
        metadata={'help': 'Truncate inputs from the middle.'}
    )
    load_result: bool = field(
        default=False,
        metadata={'help': 'Load results from saved files?'}
    )

    do_sample: bool = False


def process_infbench(data, indices, tokenizer, chat_template, task: str, prompt_template: str = "mistral", max_length: int = 100000, truncate_from_middle: bool = True):
    """Build chat-formatted inputs (and keep the reference answers) for one InfiniteBench task."""
    outputs = {'input_ids': [], 'attention_mask': [], "index": [], "answer": []}

    # NOTE: newer versions of `datasets` pass a LazyBatch here, which cannot be converted
    # to a list of dicts directly, so turn it into a plain dict first
    data = pd.DataFrame(dict(data)).to_dict(orient="records")

    for sample, index in zip(data, indices):
        prompt = create_prompt(sample, task, prompt_template)
        answer = get_answer(sample, task)

        if truncate_from_middle:
            # keep the head and tail of the prompt, dropping tokens from the middle
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokenized_prompt) > max_length:
                half = int(max_length / 2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        else:
            # keep only the last max_length tokens
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            prompt = tokenizer.decode(tokenized_prompt[-max_length:], skip_special_tokens=True)

        encoded = apply_chat_template(
            chat_template,
            messages=[{'role': 'user', 'content': prompt}],
            tokenizer=tokenizer,
            add_generation_prompt=True,
        ).encoded

        outputs["input_ids"].append(encoded["input_ids"])
        outputs["attention_mask"].append(encoded["attention_mask"])
        outputs["index"].append(index)
        outputs["answer"].append(answer)

    return outputs


@torch.no_grad()
def main():
    parser = HfArgumentParser([Args])
    args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    if args.tasks == ["all"]:
        tasks = list(TASK_TO_PATH.keys())
    else:
        tasks = args.tasks

    with accelerator.main_process_first():
        all_datasets = {}
        for task in tasks:
            process_fn = partial(
                process_infbench,
                tokenizer=tokenizer,
                chat_template=args.chat_template,
                max_length=args.max_length,
                task=task,
                prompt_template=args.prompt_template,
                truncate_from_middle=args.truncate_from_middle,
            )

            path = os.path.join(args.eval_data, TASK_TO_PATH[task])
            raw_dataset = datasets.load_dataset("json", data_files=path, cache_dir=args.dataset_cache_dir, split="train")
            dataset = raw_dataset.map(process_fn, batched=True, num_proc=32, batch_size=10, with_indices=True, remove_columns=raw_dataset.column_names)
            all_datasets[task] = dataset

    # fall back to output_dir when no result_dir is given (the default is None)
    if args.result_dir is not None:
        result_dir = os.path.join(args.output_dir, args.result_dir)
    else:
        result_dir = args.output_dir

    metrics = {}

    for i, (task, dataset) in enumerate(all_datasets.items()):
        if accelerator.process_index == 0:
            logger.info(f"Evaluating {task} ({i + 1} / {len(all_datasets)})...")

        result_path = os.path.join(result_dir, f"{task}.json")

        # get answers in advance
        labels = dataset["answer"]
        dataset = dataset.remove_columns(["answer"])

        if not (args.load_result and os.path.exists(result_path)):
            data_collator = DefaultDataCollator(tokenizer=tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=args.batch_size,
                collate_fn=data_collator,
                # only pin memory when running on GPU
                pin_memory=not args.cpu,
            )

            # NOTE: prepare the dataloader so the data moves to GPU automatically
            dataloader = accelerator.prepare(dataloader)

            indices = []
            preds = []
            max_new_tokens = TASK_TO_MAX_NEW_TOKENS[task]

            for j, x in enumerate(tqdm(dataloader, desc="Generating")):
                index = x.pop("index").tolist()
                input_length = x["input_ids"].shape[1]

                # NOTE: important to reset memory for every batch
                if hasattr(model, "memory"):
                    model.memory.reset()

                output = model.generate(
                    **x,
                    max_new_tokens=max_new_tokens,
                )

                if isinstance(output, torch.Tensor):
                    # shape (batch_size, input_length + max_new_tokens); strip the prompt before decoding
                    output = output[:, input_length:]
                    output = tokenizer.batch_decode(output, skip_special_tokens=True)
                elif isinstance(output, list):
                    # already decoded into a list of strings
                    pass

                if accelerator.num_processes > 1:
                    output = accelerator.gather_for_metrics(output)
                    index = accelerator.gather_for_metrics(index)

                if accelerator.process_index == 0:
                    preds.extend(output)
                    indices.extend(index)

        else:
            if accelerator.process_index == 0:
                preds = []
                indices = []
                with open(result_path, "r", encoding="utf-8") as f:
                    # the first line stores the aggregate metric
                    f.readline()
                    for line in f:
                        item = json.loads(line)
                        preds.append(item["pred"])
                        indices.append(item["index"])

        if accelerator.process_index == 0:
            scores = []
            for label, pred in tqdm(zip(labels, preds)):
                # NOTE: here we explicitly pass model_name=None
                score = get_score_one(pred, label, task, None)
                scores.append(score)
            score = round(sum(scores) / len(scores), 4)

            logger.info(f"{task}: {score}")
            metrics[task] = score

            with open(makedirs(result_path), "w", encoding="utf-8") as f:
                # first line: aggregate score for the task; following lines: one prediction per sample
                f.write(json.dumps(score, ensure_ascii=False) + "\n")
                for index, pred, label in zip(indices, preds, labels):
                    item = {
                        "index": index,
                        "pred": pred,
                        "label": label,
                    }
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")

    if accelerator.process_index == 0:
        # save config
        args.save(os.path.join(result_dir, "config.json"))

        avg = round(sum(metrics.values()) / len(metrics), 4)
        metrics["avg"] = avg

        file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()
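
# Example invocation (a hedged sketch, not taken from the repo's docs): because of the
# relative import of `.infbench_utils`, the script has to be run as a module from the
# package root; the package name `main` below is hypothetical, adjust it to your layout.
# All flags map to fields of the `Args` dataclass via HfArgumentParser.
#
#   accelerate launch -m main.eval_infbench \
#       --eval_data /path/to/infbench-data \
#       --tasks longbook_qa_eng longbook_sum_eng \
#       --max_length 128000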