import os
import json
import inspect
import numpy as np
from functools import partial
from rouge import Rouge
from tqdm import tqdm
from transformers.utils import logging
from .utils import makedirs, split_file_dir_name_ext, normalize_text

logger = logging.get_logger(__name__)


class Metric:
    """Class for computing metrics and some post-processing."""
    @classmethod
    def get_metric_fn(cls, metrics, **kwds):
        assert isinstance(metrics, (list, tuple)), "You must pass metric_names in a list or tuple!"

        # collect all metric methods defined on the class (anything not prefixed with "get_")
        metric_fns = []
        all_metric_names = [x[0] for x in inspect.getmembers(cls, predicate=inspect.isfunction) if not x[0].startswith("get_")]
        for metric_name in metrics:
            if metric_name in all_metric_names:
                metric_fns.append(partial(getattr(cls, metric_name), **kwds))
            else:
                raise NotImplementedError(f"Metric {metric_name} not implemented!")

        def compute_metrics(*args, **kwargs):
            # use a fresh dict on every call so results from earlier calls do not leak in
            return_metrics = {}
            for metric_fn in metric_fns:
                # call the corresponding method
                metric = metric_fn(*args, **kwargs)
                # NOTE: some metric_fns are only used for post-processing and saving results; they return None by default
                if metric is not None:
                    return_metrics.update(metric)
            return return_metrics
        return compute_metrics
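
    # NOTE: a minimal usage sketch (the file path and variable names below are
    # hypothetical, not from this repo); `preds` and `labels` are lists of strings:
    #   save_path = Metric.get_save_path("data/eval.jsonl", save_name="run1")
    #   compute_metrics = Metric.get_metric_fn(["rouge", "save_result"], save_path=save_path)
    #   results = compute_metrics(preds, labels)
    #   # -> {"rouge-1": ..., "rouge-2": ..., "rouge-l": ...}; predictions are also dumped to save_path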

    def get_save_path(eval_data, output_dir=None, field="result", save_name=None):
        """
        Build the path where evaluation results are saved.

        if output_dir is None:
            -> {eval_data_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext}
        else:
            -> {output_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext}
        """
        eval_data_dir, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data)
        if output_dir is None:
            output_dir = eval_data_dir
        fields = [eval_data_name, field]
        if save_name is not None:
            fields.append(save_name)
        save_path = os.path.join(output_dir, ".".join(fields) + eval_data_ext)
        makedirs(save_path)
        return save_path
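
    # NOTE: a rough illustration with a hypothetical input file, assuming
    # split_file_dir_name_ext("data/eval.jsonl") returns ("data", "eval", ".jsonl"):
    #   Metric.get_save_path("data/eval.jsonl", field="result", save_name="run1")
    #   -> "data/eval.result.run1.jsonl"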

    def save_result(preds, labels, save_path, indices=None, **kwargs):
        if len(preds) != len(labels):
            logger.warning(f"There are {len(preds)} samples in predictions while {len(labels)} samples in labels!")
            # truncate both to the shorter length so they can be zipped
            num = min(len(preds), len(labels))
            preds = preds[:num]
            labels = labels[:num]

        with open(save_path, "w", encoding="utf-8") as f:
            for i, (pred, label) in enumerate(zip(preds, labels)):
                item = {
                    "prediction": pred,
                    "target": label,
                }
                if indices is not None:
                    item["index"] = indices[i]
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
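
    # Each line written by save_result is a standalone JSON object, e.g. (placeholder values):
    #   {"prediction": "model output text", "target": "reference text", "index": 3}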

    def rouge(preds, labels, **kwargs):
        scorer = Rouge()

        if len(preds) != len(labels):
            logger.warning(f"There are {len(preds)} samples in predictions while {len(labels)} samples in labels!")
            num = min(len(preds), len(labels))
            preds = preds[:num]
            labels = labels[:num]

        preds = normalize_text(preds)
        labels = normalize_text(labels)

        # replace empty predictions with a placeholder so the Rouge scorer does not fail on empty hypotheses
        preds = [":)" if len(pred) == 0 else pred for pred in preds]

        score = scorer.get_scores(preds, labels, avg=True)

        metric = {
            "rouge-1": score["rouge-1"]["f"],
            "rouge-2": score["rouge-2"]["f"],
            "rouge-l": score["rouge-l"]["f"],
        }
        return metric
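
    # A minimal sketch of calling the metric directly (toy strings, not real data):
    #   Metric.rouge(["the cat sat on the mat"], ["a cat sat on the mat"])
    #   -> {"rouge-1": ..., "rouge-2": ..., "rouge-l": ...}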

    # def acc(eval_data=None, **kwds):
    #     if eval_data is not None:
    #         data_labels = Metric._prepare_label(eval_data)

    #     def compute_metric(indices, preds, labels=None, **kwargs):
    #         if labels is None:
    #             labels = data_labels

    #         if len(preds) != len(labels):
    #             logger.warning(f"There are {len(preds)} queries in predictions while {len(labels)} queries in labels!")

    #         labels = [labels[query_id] for query_id in indices]

    #         preds = normalize_text(preds)
    #         labels = normalize_text(labels)

    #         overlap = 0
    #         for pred, label in zip(preds, labels):
    #             if pred == label:
    #                 overlap += 1

    #         metric = {
    #             "acc": overlap / len(preds),
    #         }
    #         return metric
    #     return compute_metric