# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path
from collections import defaultdict
from typing import List, Union

from evalscope.benchmarks import Benchmark
from evalscope.constants import OutputType
from evalscope.metrics import mean
from evalscope.utils.io_utils import jsonl_to_list
from evalscope.utils.logger import get_logger

from .base import T2IBaseAdapter

logger = get_logger()

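# Benchmark.register wires this adapter into evalscope's benchmark registry
# under the name 'evalmuse', so an evaluation run can select it by that name.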
@Benchmark.register(
    name='evalmuse',
    dataset_id='AI-ModelScope/T2V-Eval-Prompts',
    model_adapter=OutputType.IMAGE_GENERATION,
    output_types=[OutputType.IMAGE_GENERATION],
    subset_list=['EvalMuse'],
    metric_list=['FGA_BLIP2Score'],
    few_shot_num=0,
    train_split=None,
    eval_split='test',
)
class EvalMuseAdapter(T2IBaseAdapter):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load(self, **kwargs) -> dict:
        # Allow loading from a local JSONL file: if dataset_id points to an
        # existing file, read it directly instead of fetching the hub dataset.
        if os.path.isfile(self.dataset_id):
            data_list = jsonl_to_list(self.dataset_id)
            data_dict = {self.subset_list[0]: {'test': data_list}}
            return data_dict
        else:
            return super().load(**kwargs)

    def get_gold_answer(self, input_d: dict) -> dict:
        # Return the prompt together with its fine-grained element tags.
        return {'prompt': input_d.get('prompt'), 'tags': input_d.get('tags', {})}

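    # Illustrative record shape (an assumption for documentation purposes; the
    # exact tag-key format comes from the dataset, not from this adapter):
    #   {'prompt': 'a dog chasing a ball on the beach',
    #    'tags': {'dog (animal-1)': ..., 'ball (object-1)': ..., 'beach (scene-1)': ...}}
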
    def match(self, gold: dict, pred: str) -> dict:
        # pred is the path of the generated image; gold carries the prompt and
        # its element tags (see get_gold_answer).
        res = {}
        for metric_name, metric_func in self.metrics.items():
            if metric_name == 'FGA_BLIP2Score':
                # FGA_BLIP2Score takes the full gold dict so it can score each
                # tagged element in addition to the overall prompt.
                score = metric_func(images=[pred], texts=[gold])[0][0]
            else:
                score = metric_func(images=[pred], texts=[gold['prompt']])[0][0]
            if isinstance(score, dict):
                # Element-level result: flatten into '<metric>:<element>' keys.
                for k, v in score.items():
                    res[f'{metric_name}:{k}'] = v.cpu().item()
            else:
                res[metric_name] = score.cpu().item()
        return res

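    # For illustration (assumed key format, derived from the flattening above):
    # match() returns one overall score plus one entry per tagged element, e.g.
    #   {'FGA_BLIP2Score': 3.41, 'FGA_BLIP2Score:dog (animal-1)': 0.87, ...}
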
    def compute_metric(self, review_res_list: Union[List[dict], List[List[dict]]], **kwargs) -> List[dict]:
        """
        Aggregate per-sample scores: regroup FGA_BLIP2Score element scores by
        their category, then report the mean of every metric over all samples.
        """
        items = super().compute_dict_metric(review_res_list, **kwargs)
        # Regroup element-level scores under their category.
        new_items = defaultdict(list)
        for metric_name, value_list in items.items():
            if 'FGA_BLIP2Score' in metric_name and '(' in metric_name:  # FGA_BLIP2Score element score
                metrics_prefix = metric_name.split(':')[0]
                category = metric_name.rpartition('(')[-1].split(')')[0]
                category = category.split('-')[0].lower()  # keep only the part before any '-' suffix
                new_items[f'{metrics_prefix}:{category}'].extend(value_list)
            else:
                new_items[metric_name].extend(value_list)

        # Report the mean score and sample count for each metric.
        return [{'metric_name': k, 'score': mean(v), 'num': len(v)} for k, v in new_items.items()]
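
    # A worked trace of the regrouping above (the key format is an assumption
    # consistent with the parsing logic, not documented dataset output):
    #   name   = 'FGA_BLIP2Score:dog (animal-1)'
    #   prefix = name.split(':')[0]                       # 'FGA_BLIP2Score'
    #   cat    = name.rpartition('(')[-1].split(')')[0]   # 'animal-1'
    #   cat    = cat.split('-')[0].lower()                # 'animal'
    #   -> scores grouped under 'FGA_BLIP2Score:animal'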
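

# Minimal usage sketch. Assumptions: `TaskConfig` and `run_task` are the
# standard evalscope entry points; the exact TaskConfig fields for
# text-to-image runs may differ across versions, and the model id below is
# purely hypothetical.
#
#   from evalscope import TaskConfig, run_task
#
#   task_cfg = TaskConfig(
#       model='stabilityai/stable-diffusion-xl-base-1.0',  # hypothetical
#       datasets=['evalmuse'],
#   )
#   run_task(task_cfg=task_cfg)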