# evalscope_v0.17.0/evalscope.0.17.0/evalscope/metrics/named_metrics.py
#
# 56 lines
# 2.2 KiB
# Python

from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Dict
from evalscope.metrics.metrics import mean, pass_at_k, weighted_mean
from evalscope.metrics.t2v_metrics import (blip2_score, clip_flant5_score, clip_score, fga_blip2_score, hpsv2_1_score,
hpsv2_score, image_reward_score, mps_score, pick_score)
@dataclass
class Metric:
    """A named metric: a registry key paired with the callable that computes it.

    Attributes:
        name: Key under which the metric is stored in a registry.
            Defaults to ``'default_metric'``.
        object: Callable bound to the metric. Defaults to ``mean`` (imported
            from ``evalscope.metrics.metrics``).
    """

    name: str = 'default_metric'
    # The default is supplied through a factory, so ``mean`` is looked up when
    # an instance is created without ``object`` rather than at class-definition
    # time, and no class-level ``object`` attribute is left behind.
    object: Callable = field(default_factory=lambda: mean)
class MetricRegistry:
    """Name-keyed registry of ``Metric`` entries.

    Entries live in a plain dict keyed by ``Metric.name``. Registering a
    metric under a name that is already present replaces the earlier entry,
    and ``list_metrics`` reports names in first-registration order.

    ``Metric`` annotations are written as quoted forward references so that
    defining this class does not require ``Metric`` to be bound yet.
    """

    def __init__(self):
        # Maps metric name -> Metric.
        self.metrics: Dict[str, 'Metric'] = {}

    def register(self, metric: 'Metric'):
        """Store ``metric`` under ``metric.name``, replacing any existing entry."""
        self.metrics[metric.name] = metric

    def get(self, name: str) -> 'Metric':
        """Return the metric registered under ``name``.

        Raises:
            KeyError: If ``name`` is not registered. The message lists the
                names that are currently available.
        """
        try:
            return self.metrics[name]
        except KeyError:
            # ``from None`` suppresses the bare inner KeyError so the traceback
            # shows only this descriptive error instead of a chained
            # "During handling of the above exception..." pair.
            raise KeyError(
                f'Metric {name} not found in the registry. Available metrics: {self.list_metrics()}') from None

    def list_metrics(self) -> list:
        """Return the registered metric names as a new list."""
        return list(self.metrics)
# Module-level registry instance populated by the registrations below.
metric_registry = MetricRegistry()

# Register metrics
# Aggregate metrics, listed as (registry name, aggregation callable) in
# registration order.
for _metric_name, _metric_fn in (
    ('AverageAccuracy', mean),
    ('WeightedAverageAccuracy', weighted_mean),
    ('AverageBLEU', mean),
    ('AverageRouge', mean),
    ('WeightedAverageBLEU', weighted_mean),
    ('AveragePass@1', mean),
):
    metric_registry.register(Metric(name=_metric_name, object=_metric_fn))

# Pass@1 through Pass@16: ``pass_at_k`` with ``k`` pre-bound via ``partial``.
for k in range(1, 17):
    metric_registry.register(Metric(name=f'Pass@{k}', object=partial(pass_at_k, k=k)))

# t2v_metrics
for _metric_name, _metric_fn in (
    ('VQAScore', clip_flant5_score),
    ('PickScore', pick_score),
    ('CLIPScore', clip_score),
    ('BLIPv2Score', blip2_score),
    ('HPSv2Score', hpsv2_score),
    ('HPSv2.1Score', hpsv2_1_score),
    ('ImageRewardScore', image_reward_score),
    ('FGA_BLIP2Score', fga_blip2_score),
    ('MPS', mps_score),
):
    metric_registry.register(Metric(name=_metric_name, object=_metric_fn))

# Drop the table-loop temporaries so they do not linger in the module namespace.
del _metric_name, _metric_fn