evalscope_v0.17.0/evalscope.0.17.0/evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py

# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) EleutherAI Inc, and its affiliates.
import ast
import csv
import os

from evalscope.benchmarks import Benchmark
from evalscope.benchmarks.data_adapter import DataAdapter
from evalscope.constants import EvalType, OutputType
from evalscope.utils import get_logger

# flake8: noqa

logger = get_logger()
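

# TriviaQA registration: five few-shot examples (few_shot_num=5) are taken from the
# `dev` split, evaluation runs on the `test` split, and each prediction is scored by
# `TriviaQaAdapter.match` and reported as AverageAccuracy.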
@Benchmark.register(
    name='trivia_qa',
    pretty_name='TriviaQA',
    tags=['QA', 'Reading Comprehension'],
    description=
    'TriviaQA is a large-scale reading comprehension dataset consisting of question-answer pairs collected from trivia websites. It includes questions with multiple possible answers, making it suitable for evaluating the ability of models to understand and generate answers based on context.',  # noqa: E501
    dataset_id='modelscope/trivia_qa',
    subset_list=['default'],
    metric_list=['AverageAccuracy'],
    few_shot_num=5,
    train_split='dev',
    eval_split='test',
)
class TriviaQaAdapter(DataAdapter):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load_from_disk(self, dataset_name_or_path, subset_list, work_dir, **kwargs) -> dict:
        data_dict = {}
        for subset_name in subset_list:
            data_dict[subset_name] = {}
            for split in [self.train_split, self.eval_split]:
                # Resolve the split file either from an explicit local dataset path or relative to the work dir.
                if os.path.exists(dataset_name_or_path):
                    file_path = os.path.join(dataset_name_or_path, f'trivia-{split}.qa.csv')
                else:
                    file_path = os.path.join(work_dir, dataset_name_or_path, f'trivia-{split}.qa.csv')
                if os.path.exists(file_path):
                    with open(file_path, 'r', encoding='utf-8') as f:
                        reader = csv.reader(f, delimiter='\t')
                        split_data = []
                        for row in reader:
                            assert len(row) == 2
                            question = row[0]
                            # The second column stores the answer aliases as a Python list literal;
                            # parse it with ast.literal_eval rather than eval for safety.
                            answers = ast.literal_eval(row[1])
                            split_data.append({
                                'input': [{
                                    'role': 'system',
                                    'content': 'Follow the given examples and answer the question.'
                                }, {
                                    'role': 'user',
                                    'content': question
                                }],
                                'ideal': answers
                            })
                        data_dict[subset_name][split] = split_data

        return data_dict
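
    # Example of a loaded sample (one tab-separated row of `trivia-<split>.qa.csv`):
    #   {
    #       'input': [
    #           {'role': 'system', 'content': 'Follow the given examples and answer the question.'},
    #           {'role': 'user', 'content': 'Which Lloyd Webber musical premiered in the US on 10th December 1993?'}
    #       ],
    #       'ideal': ['Sunset Blvd', 'Sunset Boulevard', ...]
    #   }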

    def gen_prompt(self, input_d: dict, subset_name: str, few_shot_list: list, **kwargs) -> dict:
        """
        Generate the model prompt from raw input, unifying the prompt format for the TriviaQA benchmark.

        Args:
            input_d (dict): The raw input. A single TriviaQA sample:
                {
                    "input": [
                        {"role": "system", "content": "Follow the given examples and answer the question."},
                        {"role": "user", "content": "Which Lloyd Webber musical premiered in the US on 10th December 1993?"}
                    ],
                    "ideal": [
                        "Sunset Blvd",
                        "West Sunset Boulevard",
                        "Sunset Boulevard",
                        "Sunset Bulevard",
                        "Sunset Blvd.",
                        "sunset boulevard",
                        "sunset bulevard",
                        "west sunset boulevard",
                        "sunset blvd"
                    ]
                }

        Returns:
            The prompt data dict built by `self.gen_prompt_data` from the assembled few-shot context.
        """

        def get_sys_prompt(inp: dict) -> str:
            return inp['input'][0]['content']

        # Attach the system prompt only when few-shot examples are used.
        if self.few_shot_num > 0:
            sys_prompt = get_sys_prompt(input_d)
        else:
            sys_prompt = None
        few_shot_prompts = [self._generate_prompt(input_d=sample, include_answer=True) for sample in few_shot_list]
        context = '\n'.join(few_shot_prompts) + '\n'
        context += self._generate_prompt(input_d=input_d, include_answer=False)

        return self.gen_prompt_data(context, system_prompt=sys_prompt)
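
    # With the default 5-shot setting, the assembled prompt looks like:
    #   Question: <dev question 1>
    #   Answer: <its first gold alias>
    #   ...
    #   Question: <test question>
    #   Answer: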

    def get_gold_answer(self, input_d: dict) -> list:
        # Return the list of acceptable gold answers (answer aliases).
        ans: list = input_d.get('ideal', [])
        return ans

    def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> str:
        """
        Parse the model output to get the answer.

        Args:
            result: The text generated by the model. It is returned unchanged; matching against
                the gold aliases happens in `match`.
            raw_input_d: The raw input, a single TriviaQA sample (see `gen_prompt`).
            eval_type: The type of evaluation, e.g. 'checkpoint', 'service' or 'custom'.

        Returns:
            The predicted answer string.
        """
        return result

    def match(self, gold: list, pred: str) -> float:
        # Case-insensitive containment check: the prediction counts as correct if any
        # gold alias appears as a substring of the model output.
        lower_pred = pred.lower()
        gold = [g.lower() for g in gold]
        is_correct = any(cand in lower_pred for cand in gold)
        return 1.0 if is_correct else 0.0
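
    # Scoring example: with gold ['Sunset Boulevard', 'Sunset Blvd'] and the prediction
    # 'The musical was Sunset Boulevard.', the containment check matches and the score is 1.0.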

    @classmethod
    def _generate_prompt(cls, input_d: dict, include_answer=True) -> str:
        # Format a single QA example; few-shot examples append the first gold alias as the answer.
        example: str = f"Question: {input_d['input'][1]['content']}\nAnswer:"
        if include_answer:
            example += f" {input_d['ideal'][0]}\n\n"
        return example
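

# Example usage (a minimal sketch, not part of the adapter): once this module is imported,
# the registered benchmark can be run through evalscope's standard task entry point.
# The model id and `limit` value below are illustrative placeholders.
#
#   from evalscope import TaskConfig, run_task
#
#   task_cfg = TaskConfig(
#       model='Qwen/Qwen2.5-0.5B-Instruct',
#       datasets=['trivia_qa'],
#       limit=10,
#   )
#   run_task(task_cfg=task_cfg)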