evalscope_v0.17.0/evalscope.0.17.0/evalscope/benchmarks/general_mcq/general_mcq_adapter.py

119 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from collections import defaultdict
from evalscope.benchmarks import Benchmark, DataAdapter
from evalscope.constants import EvalType, OutputType
from evalscope.metrics import exact_match
from evalscope.metrics.completion_parsers import ResponseParser
from evalscope.utils.io_utils import csv_to_list, jsonl_to_list
from evalscope.utils.logger import get_logger
# flake8: noqa
logger = get_logger()
@Benchmark.register(
    name='general_mcq',
    pretty_name='General-MCQ',
    description='A general multiple-choice question answering dataset.',
    tags=['MCQ', 'Custom'],
    dataset_id='general_mcq',
    model_adapter=OutputType.GENERATION,
    output_types=[OutputType.MULTIPLE_CHOICE, OutputType.GENERATION],
    subset_list=['default'],
    metric_list=['AverageAccuracy'],
    few_shot_num=0,
    train_split='dev',
    eval_split='val',
    prompt_template='请回答问题并选出其中的正确答案。你的回答的最后一行应该是这样的格式“答案是LETTER”不带引号其中 LETTER 是 A、B、C、D 中的一个。\n{query}',
    query_template='问题:{question}\n{choices}\n答案: {answer}\n\n')
class GeneralMCQAdapter(DataAdapter):
    """Adapter for custom (user-supplied) multiple-choice QA datasets.

    For every subset/split pair it loads a file named
    ``{subset}_{split}.jsonl`` or ``{subset}_{split}.csv``. Each record is
    expected to provide a ``question`` field, option texts keyed by the
    letters ``A``..``J`` (only the keys actually present are rendered as
    choices) and an ``answer`` letter.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Candidate option keys; a sample may use any subset of these letters.
        self.choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']

    def load_from_disk(self, dataset_name_or_path, subset_list, work_dir, **kwargs) -> dict:
        """Load ``{subset}_{split}`` files from disk (.jsonl preferred over .csv).

        Args:
            dataset_name_or_path: Directory containing the data files, either an
                existing path used as-is or a name resolved under ``work_dir``.
            subset_list: Subset names to look for.
            work_dir: Base directory used when ``dataset_name_or_path`` does not
                exist by itself.

        Returns:
            dict: ``{subset_name: {split_name: list_of_records}}``. Splits with
            no matching file are simply absent from the result.
        """
        # Resolve the dataset root once — it does not vary per subset/split.
        if os.path.exists(dataset_name_or_path):
            base_dir = dataset_name_or_path
        else:
            base_dir = os.path.join(work_dir, dataset_name_or_path)
        data_dict = defaultdict(dict)
        for subset_name in subset_list:
            for split_name in [self.train_split, self.eval_split]:
                # Check supported extensions in preference order; first hit wins.
                for ext, loader in [('.jsonl', jsonl_to_list), ('.csv', csv_to_list)]:
                    file_path = os.path.join(base_dir, f'{subset_name}_{split_name}{ext}')
                    if os.path.exists(file_path):
                        data_dict[subset_name][split_name] = loader(file_path)
                        break  # Stop checking other extensions once a file is found
        return dict(data_dict)

    def gen_prompt(self, input_d: dict, subset_name: str, few_shot_list: list, **kwargs) -> dict:
        """Build the model prompt for one sample, with optional few-shot examples.

        Args:
            input_d (dict): One raw record, e.g.::

                {'id': 0,
                 'question': '下列关于税法基本原则的表述中不正确的是____。',
                 'A': '...', 'B': '...', 'C': '...', 'D': '...',
                 'answer': 'D'}
            subset_name (str): Subset the sample belongs to (unused here).
            few_shot_list (list): Raw records rendered as in-context examples.

        Returns:
            dict: Prompt payload produced by ``self.gen_prompt_data``.
        """
        query = self._format_example(input_d=input_d, include_answer=False)
        few_shot_prompts = [self._format_example(input_d=sample, include_answer=True) for sample in few_shot_list]
        if few_shot_prompts:
            # BUGFIX: the previous code stripped the joined examples and
            # concatenated the query directly, gluing the last few-shot answer
            # onto the next question ("...答案: D问题:...").  Keep a blank
            # line between the examples block and the actual question.
            context = '\n'.join(few_shot_prompts).strip() + '\n\n' + query
        else:
            context = query
        full_prompt = self.prompt_template.format(query=context)
        return self.gen_prompt_data(full_prompt)

    def get_gold_answer(self, input_d: dict) -> str:
        # The reference answer is the option letter stored under 'answer';
        # missing key degrades to an empty string rather than raising.
        return input_d.get('answer', '')

    def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> str:
        """Extract the predicted option letter from the model output.

        Args:
            result: Raw model output. Usually a string for chat.
            raw_input_d (dict): The raw input record (unused here).
            eval_type: `checkpoint` or `service` or `custom` (unused here).

        Returns:
            str: The parsed option letter.
        """
        if self.model_adapter == OutputType.MULTIPLE_CHOICE:
            # The multiple-choice adapter already yields a bare option letter.
            return result
        # Free-form generation: scan the text for the first option mention.
        return ResponseParser.parse_first_option_with_choices(text=result, options=self.choices)

    def match(self, gold: str, pred: str) -> float:
        """Score one sample: 1.0 on exact letter match, else 0.0."""
        return exact_match(gold=gold, pred=pred)

    def _format_example(self, input_d: dict, include_answer=True):
        """Render one record through ``query_template``.

        With ``include_answer=False`` the answer slot is left empty and
        trailing whitespace is removed, so the model is prompted to complete
        the answer itself.
        """
        choices_str = '\n'.join([f'{choice}. {input_d[choice]}' for choice in self.choices if choice in input_d])
        if include_answer:
            return self.query_template.format(
                question=input_d['question'], choices=choices_str, answer=input_d['answer'])
        return self.query_template.format(question=input_d['question'], choices=choices_str, answer='').rstrip()