# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) EleutherAI Inc, and its affiliates.
import ast
import csv
import os

from evalscope.benchmarks import Benchmark
from evalscope.benchmarks.data_adapter import DataAdapter
from evalscope.constants import EvalType, OutputType
from evalscope.utils import get_logger

# flake8: noqa

logger = get_logger()


@Benchmark.register(
    name='trivia_qa',
    pretty_name='TriviaQA',
    tags=['QA', 'Reading Comprehension'],
    description=
    'TriviaQA is a large-scale reading comprehension dataset consisting of question-answer pairs collected from trivia websites. It includes questions with multiple possible answers, making it suitable for evaluating the ability of models to understand and generate answers based on context.',  # noqa: E501
    dataset_id='modelscope/trivia_qa',
    subset_list=['default'],
    metric_list=['AverageAccuracy'],
    few_shot_num=5,
    train_split='dev',
    eval_split='test',
)
class TriviaQaAdapter(DataAdapter):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load_from_disk(self, dataset_name_or_path, subset_list, work_dir, **kwargs) -> dict:
        data_dict = {}
        for subset_name in subset_list:
            data_dict[subset_name] = {}
            for split in [self.train_split, self.eval_split]:
                # Resolve the split file either from a local dataset directory
                # or from the dataset cached under the work dir.
                if os.path.exists(dataset_name_or_path):
                    file_path = os.path.join(dataset_name_or_path, f'trivia-{split}.qa.csv')
                else:
                    file_path = os.path.join(work_dir, dataset_name_or_path, f'trivia-{split}.qa.csv')
                if os.path.exists(file_path):
                    with open(file_path, 'r', encoding='utf-8') as f:
                        reader = csv.reader(f, delimiter='\t')
                        split_data = []
                        for row in reader:
                            assert len(row) == 2
                            question = row[0]
                            # The second column holds a Python list literal with all
                            # acceptable answer aliases; parse it safely.
                            answers = ast.literal_eval(row[1])
                            split_data.append({
                                'input': [{
                                    'role': 'system',
                                    'content': 'Follow the given examples and answer the question.'
                                }, {
                                    'role': 'user',
                                    'content': question
                                }],
                                'ideal': answers
                            })
                        data_dict[subset_name][split] = split_data

        return data_dict

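    # Expected on-disk layout (illustrative): each `trivia-{split}.qa.csv` file is
    # tab-separated with two columns per row -- the question text and a Python list
    # literal of acceptable answer aliases, e.g.
    #   Which Lloyd Webber musical premiered in the US on 10th December 1993?\t['Sunset Boulevard', 'Sunset Blvd']
    # (alias list shortened here for illustration).
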
    def gen_prompt(self, input_d: dict, subset_name: str, few_shot_list: list, **kwargs) -> dict:
        """
        Generate the model prompt from the raw input, unifying the prompt format for the TriviaQA benchmark.

        Args:
            input_d (dict): The raw input. A single TriviaQA sample has the format:

            {
                "input": [
                    {"role": "system", "content": "Follow the given examples and answer the question."},
                    {"role": "user", "content": "Which Lloyd Webber musical premiered in the US on 10th December 1993?"}
                ],
                "ideal": [
                    "Sunset Blvd",
                    "West Sunset Boulevard",
                    "Sunset Boulevard",
                    "Sunset Bulevard",
                    "Sunset Blvd.",
                    "sunset boulevard",
                    "sunset bulevard",
                    "west sunset boulevard",
                    "sunset blvd"
                ]
            }

        Returns:
            {'data': [(context, continuation), ...]}
        """

        def get_sys_prompt(inp: dict) -> str:
            return inp['input'][0]['content']

        if self.few_shot_num > 0:
            sys_prompt = get_sys_prompt(input_d)
        else:
            sys_prompt = None
        few_shot_prompts = [self._generate_prompt(input_d=sample, include_answer=True) for sample in few_shot_list]
        context = '\n'.join(few_shot_prompts) + '\n'
        context += self._generate_prompt(input_d=input_d, include_answer=False)
        full_prompt = context

        return self.gen_prompt_data(full_prompt, system_prompt=sys_prompt)

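    # With few_shot_num > 0, the rendered `full_prompt` is a plain-text block of
    # the form (each few-shot answer is the first alias in that sample's 'ideal'):
    #
    #   Question: <few-shot question 1>
    #   Answer: <few-shot answer 1>
    #
    #   ...
    #
    #   Question: <target question>
    #   Answer:
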
    def get_gold_answer(self, input_d: dict) -> list:
        # Return the list of acceptable gold answer aliases.
        ans: list = input_d.get('ideal', [])
        return ans

    def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> str:
        """
        Parse the model output to get the answer.

        Args:
            result: The predicted answer from the model, i.e. the generated text.
            raw_input_d: The raw input, a single TriviaQA sample (see `gen_prompt`).
            eval_type: The type of evaluation, e.g. 'checkpoint', 'service' or 'custom'.

        Returns:
            The predicted answer.
        """
        # The generated text is matched directly against the gold aliases, so no
        # further parsing is needed here.
        return result

    def match(self, gold: list, pred: str) -> float:
        # Case-insensitive substring match: the prediction is correct if any gold
        # alias appears in the model output.
        lower_pred = pred.lower()
        gold = [g.lower() for g in gold]
        is_correct = any(cand in lower_pred for cand in gold)
        return 1.0 if is_correct else 0.0

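    # Example: match(['Sunset Boulevard', 'Sunset Blvd'], 'It was Sunset Boulevard.')
    # returns 1.0, since 'sunset boulevard' is contained in the lowercased output.
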
    @classmethod
    def _generate_prompt(cls, input_d: dict, include_answer=True) -> str:
        example: str = f"Question: {input_d['input'][1]['content']}\nAnswer:"
        if include_answer:
            example += f" {input_d['ideal'][0]}\n\n"

        return example
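

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only, not part of the adapter).
    # The sample below is hypothetical and mirrors the record layout built in
    # `load_from_disk`.
    _sample = {
        'input': [
            {'role': 'system', 'content': 'Follow the given examples and answer the question.'},
            {'role': 'user', 'content': 'Which planet is known as the Red Planet?'},
        ],
        'ideal': ['Mars'],
    }
    print(TriviaQaAdapter._generate_prompt(input_d=_sample, include_answer=True))
    print(TriviaQaAdapter._generate_prompt(input_d=_sample, include_answer=False))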