56 lines
2.2 KiB
Python
56 lines
2.2 KiB
Python
from dataclasses import dataclass, field
|
|
from functools import partial
|
|
from typing import Callable, Dict
|
|
|
|
from evalscope.metrics.metrics import mean, pass_at_k, weighted_mean
|
|
from evalscope.metrics.t2v_metrics import (blip2_score, clip_flant5_score, clip_score, fga_blip2_score, hpsv2_1_score,
|
|
hpsv2_score, image_reward_score, mps_score, pick_score)
|
|
|
|
|
|
@dataclass
class Metric:
    """A named metric entry as stored by the metric registry.

    Attributes:
        name: Key under which the metric is registered and looked up.
        object: Callable implementing the metric. When omitted it defaults
            to ``mean``, supplied through a ``default_factory``.
    """

    # Registry key; defaults to a generic placeholder name.
    name: str = field(default='default_metric')
    # Metric implementation; the factory returns ``mean`` for each new instance.
    object: Callable = field(default_factory=lambda: mean)
|
|
|
|
|
|
class MetricRegistry:
    """Name-keyed collection of :class:`Metric` objects.

    Entries are kept in a dict keyed by ``Metric.name``, so iteration follows
    registration order, and registering a second metric with an existing name
    replaces the earlier entry.
    """

    def __init__(self):
        # Maps Metric.name -> Metric.
        self.metrics: Dict[str, Metric] = {}

    def register(self, metric: Metric):
        """Store ``metric`` under ``metric.name``.

        An existing entry with the same name is silently overwritten.
        """
        self.metrics[metric.name] = metric

    def get(self, name: str) -> Metric:
        """Return the metric registered under ``name``.

        Raises:
            KeyError: if no metric with that name has been registered. The
                message lists the names that are available.
        """
        try:
            return self.metrics[name]
        except KeyError:
            # ``from None`` suppresses the implicit exception context, so the user
            # sees this informative error instead of a chained double traceback.
            raise KeyError(f'Metric {name} not found in the registry. Available metrics: {self.list_metrics()}') from None

    def list_metrics(self) -> list:
        """Return the registered metric names in registration order."""
        return list(self.metrics.keys())
|
|
|
|
|
|
metric_registry = MetricRegistry()

# Register metrics: (registry name, callable) pairs, registered in this order.
for _metric_name, _metric_fn in (
    ('AverageAccuracy', mean),
    ('WeightedAverageAccuracy', weighted_mean),
    ('AverageBLEU', mean),
    ('AverageRouge', mean),
    ('WeightedAverageBLEU', weighted_mean),
    ('AveragePass@1', mean),
):
    metric_registry.register(Metric(name=_metric_name, object=_metric_fn))
# Drop the loop temporaries so the module namespace is unchanged.
del _metric_name, _metric_fn

# Pass@1 .. Pass@16: each entry binds its own k into pass_at_k via functools.partial.
for k in range(1, 17):
    metric_registry.register(Metric(name=f'Pass@{k}', object=partial(pass_at_k, k=k)))
|
|
|
|
# t2v_metrics: scorers imported from evalscope.metrics.t2v_metrics, registered in this order.
for _metric_name, _metric_fn in (
    ('VQAScore', clip_flant5_score),
    ('PickScore', pick_score),
    ('CLIPScore', clip_score),
    ('BLIPv2Score', blip2_score),
    ('HPSv2Score', hpsv2_score),
    ('HPSv2.1Score', hpsv2_1_score),
    ('ImageRewardScore', image_reward_score),
    ('FGA_BLIP2Score', fga_blip2_score),
    ('MPS', mps_score),
):
    metric_registry.register(Metric(name=_metric_name, object=_metric_fn))
# Drop the loop temporaries so the module namespace is unchanged.
del _metric_name, _metric_fn
|