embed-bge-m3/FlagEmbedding/research/Long_LLM/longllm_qlora/src/metrics.py

import os
import json
import inspect
import numpy as np
from functools import partial
from rouge import Rouge
from tqdm import tqdm
from transformers.utils import logging

from .utils import makedirs, split_file_dir_name_ext, normalize_text

logger = logging.get_logger(__name__)


class Metric:
    """Class for computing metrics and some post-processings."""
    @classmethod
    def get_metric_fn(cls, metrics, **kwds):
        assert isinstance(metrics, (list, tuple)), "You must pass metric_names in a list or tuple!"
        return_metrics = {}
        # collect all metric methods defined on the class (names starting with "get_" are helpers, not metrics)
        metric_fns = []
        all_metric_names = [x[0] for x in inspect.getmembers(cls, predicate=inspect.isfunction) if not x[0].startswith("get_")]
        for metric_name in metrics:
            if metric_name in all_metric_names:
                metric_fns.append(partial(getattr(cls, metric_name), **kwds))
            else:
                raise NotImplementedError(f"Metric {metric_name} not implemented!")

        def compute_metrics(*args, **kwargs):
            for metric_fn in metric_fns:
                # call the corresponding method
                metric = metric_fn(*args, **kwargs)
                # NOTE: some metric_fns are only used for post-processing and saving results; they return None by default
                if metric is not None:
                    return_metrics.update(metric)
            return return_metrics
        return compute_metrics

    def get_save_path(eval_data, output_dir=None, field="result", save_name=None):
        """
        if output_dir is None:
            -> {eval_data_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext}
        else:
            -> {output_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext}
        """
        eval_data_dir, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data)
        if output_dir is None:
            output_dir = eval_data_dir
        fields = [eval_data_name, field]
        if save_name is not None:
            fields.append(save_name)
        save_path = os.path.join(output_dir, ".".join(fields) + eval_data_ext)
        makedirs(save_path)
        return save_path
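
    # Illustrative example of the resulting path (assumes split_file_dir_name_ext
    # returns the extension with its leading dot):
    #   get_save_path("/data/eval.jsonl", field="result", save_name="llama2-7b")
    #   -> "/data/eval.result.llama2-7b.jsonl"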

    def save_result(preds, labels, save_path, indices=None, **kwargs):
        if len(preds) != len(labels):
            logger.warning(f"There are {len(preds)} samples in predictions while {len(labels)} samples in labels!")
            labels = labels[:min(len(preds), len(labels))]
            preds = preds[:min(len(preds), len(labels))]

        with open(save_path, "w", encoding="utf-8") as f:
            for i, (pred, label) in enumerate(zip(preds, labels)):
                item = {
                    "prediction": pred,
                    "target": label,
                }
                if indices is not None:
                    item["index"] = indices[i]
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
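
    # Each line written by save_result is a standalone JSON object, e.g.
    #   {"prediction": "...", "target": "...", "index": 3}
    # where the "index" key only appears when `indices` is provided.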

    def rouge(preds, labels, **kwargs):
        rouge = Rouge()

        if len(preds) != len(labels):
            logger.warning(f"There are {len(preds)} samples in predictions while {len(labels)} samples in labels!")
            labels = labels[:min(len(preds), len(labels))]
            preds = preds[:min(len(preds), len(labels))]

        preds = normalize_text(preds)
        labels = normalize_text(labels)

        # filter empty predictions (the rouge package rejects empty hypotheses) by substituting a dummy token
        preds = [":)" if len(pred) == 0 else pred for pred in preds]

        score = rouge.get_scores(preds, labels, avg=True)
        metric = {
            "rouge-1": score["rouge-1"]["f"],
            "rouge-2": score["rouge-2"]["f"],
            "rouge-l": score["rouge-l"]["f"],
        }
        return metric

    # def acc(eval_data=None, **kwds):
    #     if eval_data is not None:
    #         data_labels = Metric._prepare_label(eval_data)

    #     def compute_metric(indices, preds, labels=None, **kwargs):
    #         if labels is None:
    #             labels = data_labels
    #         if len(preds) != len(labels):
    #             logger.warning(f"There are {len(preds)} queries in predictions while {len(labels)} queries in labels!")
    #         labels = [labels[query_id] for query_id in indices]
    #         preds = normalize_text(preds)
    #         labels = normalize_text(labels)
    #         overlap = 0
    #         for pred, label in zip(preds, labels):
    #             if pred == label:
    #                 overlap += 1
    #         metric = {
    #             "acc": overlap / len(preds),
    #         }
    #         return metric
    #     return compute_metric
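

# A minimal usage sketch, not part of the original module: it assumes the file is
# executed as part of its package (e.g. via `python -m ...`) so the relative
# `.utils` import resolves, that `normalize_text` accepts a list of strings, and
# that the `rouge` package is installed. The metric names come from the methods
# above; the file names and example strings are illustrative only.
if __name__ == "__main__":
    preds = ["the cat sat on the mat", "paris is the capital of france"]
    labels = ["a cat sat on the mat", "the capital of france is paris"]

    # Compose ROUGE computation with result saving; `save_path` is forwarded to
    # every selected metric function through **kwds (rouge ignores it via **kwargs).
    compute_metrics = Metric.get_metric_fn(
        ["rouge", "save_result"],
        save_path=Metric.get_save_path("eval.jsonl", output_dir=".", save_name="demo"),
    )
    # Prints the averaged ROUGE F1 scores and writes predictions/targets as JSONL.
    print(compute_metrics(preds, labels))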