evalscope_v0.17.0/evalscope.0.17.0/evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path
from collections import defaultdict
from typing import List, Optional, Union

from evalscope.benchmarks import Benchmark
from evalscope.constants import OutputType
from evalscope.metrics import mean
from evalscope.utils.io_utils import jsonl_to_list
from evalscope.utils.logger import get_logger

from .base import T2IBaseAdapter

logger = get_logger()


@Benchmark.register(
    name='evalmuse',
    dataset_id='AI-ModelScope/T2V-Eval-Prompts',
    model_adapter=OutputType.IMAGE_GENERATION,
    output_types=[OutputType.IMAGE_GENERATION],
    subset_list=['EvalMuse'],
    metric_list=['FGA_BLIP2Score'],
    few_shot_num=0,
    train_split=None,
    eval_split='test',
)
class EvalMuseAdapter(T2IBaseAdapter):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load(self, **kwargs) -> dict:
        # If dataset_id points to a local JSONL file, load it directly as the
        # single registered subset; otherwise fall back to the default loading.
        if os.path.isfile(self.dataset_id):
            data_list = jsonl_to_list(self.dataset_id)
            data_dict = {self.subset_list[0]: {'test': data_list}}
            return data_dict
        else:
            return super().load(**kwargs)

    def get_gold_answer(self, input_d: dict) -> dict:
        # return prompt and elements dict
        return {'prompt': input_d.get('prompt'), 'tags': input_d.get('tags', {})}

    def match(self, gold: dict, pred: str) -> dict:
        # pred is the generated image path, gold holds the prompt and its tags
        res = {}
        for metric_name, metric_func in self.metrics.items():
            if metric_name == 'FGA_BLIP2Score':
                # FGA_BLIP2Score scores the prompt elements, so pass the full gold dict
                score = metric_func(images=[pred], texts=[gold])[0][0]
            else:
                score = metric_func(images=[pred], texts=[gold['prompt']])[0][0]
            if isinstance(score, dict):
                for k, v in score.items():
                    res[f'{metric_name}:{k}'] = v.cpu().item()
            else:
                res[metric_name] = score.cpu().item()
        return res
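
    # Illustration (hypothetical keys and values, not taken from the dataset):
    # for a gold entry whose tags annotate an element such as 'a red car (object)',
    # FGA_BLIP2Score returns a dict of per-element scores, which `match` flattens
    # into entries like {'FGA_BLIP2Score:a red car (object)': 0.82}. compute_metric
    # below then aggregates those element scores by the category in parentheses.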

    def compute_metric(self, review_res_list: Union[List[dict], List[List[dict]]], **kwargs) -> List[dict]:
        """
        Compute the mean of each metric over all samples, aggregating
        FGA_BLIP2Score element scores by their category.
        """
        items = super().compute_dict_metric(review_res_list, **kwargs)

        # merge per-element FGA_BLIP2Score entries, e.g. 'FGA_BLIP2Score:xxx(category)',
        # into one bucket per category
        new_items = defaultdict(list)
        for metric_name, value_list in items.items():
            if 'FGA_BLIP2Score' in metric_name and '(' in metric_name:  # FGA_BLIP2Score element score
                metrics_prefix = metric_name.split(':')[0]
                category = metric_name.rpartition('(')[-1].split(')')[0]
                category = category.split('-')[0].lower()  # drop the suffix if one exists
                new_items[f'{metrics_prefix}:{category}'].extend(value_list)
            else:
                new_items[metric_name].extend(value_list)

        # calculate mean for each metric
        return [{'metric_name': k, 'score': mean(v), 'num': len(v)} for k, v in new_items.items()]
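

# ---------------------------------------------------------------------------
# Usage sketch (assumption, kept commented out): how this benchmark could be
# run through evalscope's task runner. The text-to-image model id and the
# `model_task` value are placeholders and may differ from your setup; the
# FGA_BLIP2Score metric also needs its FGA-BLIP2 checkpoint available locally.
#
#   from evalscope import TaskConfig, run_task
#
#   task_cfg = TaskConfig(
#       model='stabilityai/stable-diffusion-xl-base-1.0',  # placeholder t2i model
#       model_task='image_generation',                     # assumed task flag
#       datasets=['evalmuse'],                             # name registered above
#   )
#   run_task(task_cfg)
# ---------------------------------------------------------------------------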