import json
from dataclasses import dataclass, field
from typing import List

import numpy as np
import pytrec_eval
from transformers import HfArgumentParser
from FlagEmbedding import FlagReranker


@dataclass
class Args:
    input_path: str = field(
        default="",
        metadata={'help': """
        Path to the evaluation data, a file in JSONL format.
        Each line contains `query`, `pos`, and `neg`. Here, `query` is a string (`str`),
        while both `pos` and `neg` are lists of strings (`List[str]`).
        If a line also includes `pos_label_scores`, those scores are used to compute `ndcg@k`;
        otherwise every positive passage defaults to a relevance score of `1`.
        """}
    )
    metrics: List[str] = field(
        default=None,  # usage example: recall mrr ndcg
        metadata={'help': 'The evaluation metrics; you can set recall / mrr / ndcg / map / precision.'}
    )
    k_values: List[int] = field(
        default=None,
        metadata={'help': 'Report the metrics at these top-k cutoffs.'}
    )
    cache_dir: str = field(
        default=None,
        metadata={'help': 'The path used to cache the reranker model.'}
    )
    use_fp16: bool = field(
        default=True,
        metadata={'help': 'Whether to use fp16 to accelerate inference; not suitable for CPU-only inference.'}
    )
    batch_size: int = field(
        default=512,
        metadata={'help': 'Batch size for reranker inference.'}
    )
    max_length: int = field(
        default=1024,
        metadata={'help': 'Maximum sequence length for each (query, passage) pair.'}
    )
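# Illustrative (made-up) example of one line of `input_path`, following the format
# described in the help text above; `pos_label_scores` is optional and aligns with `pos`:
# {"query": "what is deep learning",
#  "pos": ["Deep learning is a subfield of machine learning ..."],
#  "neg": ["The weather today is sunny ..."],
#  "pos_label_scores": [2]}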


def evaluate_mrr(predicts, labels, cutoffs):
    """
    Evaluate MRR (mean reciprocal rank) at each cutoff.

    `predicts` is a list of ranked pair indices per query, `labels` is a list of
    the relevant (positive) pair indices per query, and `cutoffs` is the list of
    top-k values to report.
    """
    metrics = {}

    # MRR
    mrrs = np.zeros(len(cutoffs))
    for pred, label in zip(predicts, labels):
        jump = False
        for i, x in enumerate(pred, 1):
            if x in label:
                # Credit the reciprocal rank of the first relevant passage at
                # every cutoff it falls within, then stop scanning this query.
                for k, cutoff in enumerate(cutoffs):
                    if i <= cutoff:
                        mrrs[k] += 1 / i
                jump = True
            if jump:
                break
    mrrs /= len(predicts)
    for i, cutoff in enumerate(cutoffs):
        mrr = mrrs[i]
        metrics[f"MRR@{cutoff}"] = mrr

    return metrics
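# Worked example of the computation above: with cutoffs [1, 5] and a query whose first
# relevant passage appears at rank 3, that query contributes 0 to MRR@1 (since 3 > 1)
# and 1/3 to MRR@5 (since 3 <= 5); the totals are then divided by the number of queries.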


def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]
    input_path = args.input_path
    metrics = args.metrics if args.metrics is not None else ['recall', 'mrr', 'ndcg', 'map', 'precision']
    k_values = args.k_values if args.k_values is not None else [1, 5, 10, 50, 100]
    cache_dir = args.cache_dir
    use_fp16 = args.use_fp16
    batch_size = args.batch_size
    max_length = args.max_length

    reranker = FlagReranker('BAAI/bge-reranker-v2-m3', cache_dir=cache_dir, use_fp16=use_fp16)

    # Load the JSONL evaluation data.
    data = []
    data_num = []
    with open(input_path) as f:
        for line in f:
            data.append(json.loads(line))

    # Flatten every query into (query, passage) pairs; data_num[i] records how many
    # pairs belong to query i (positives first, then negatives).
    pairs = []
    for d in data:
        data_num.append(0)
        passages = []
        passages.extend(d['pos'])
        passages.extend(d['neg'])
        for p in passages:
            pairs.append((d['query'], p))
            data_num[-1] += 1

    # Score every pair with the reranker; the flat `scores` array is aligned with `pairs`.
    scores = reranker.compute_score(pairs, batch_size=batch_size, max_length=max_length)
    scores = np.asarray(scores)
    scores = scores.reshape(-1)

    # Build the relevance judgements: for query i, the positives occupy the first
    # len(data[i]['pos']) pair indices. Use `pos_label_scores` when available,
    # otherwise fall back to a relevance of 1.
    start_num = 0
    ground_truths = {}
    labels = []
    for i in range(len(data)):
        tmp = {}
        tmp_labels = []
        for ind in range(len(data[i]['pos'])):
            try:
                tmp[str(start_num + ind)] = int(data[i]['pos_label_scores'][ind])
            except Exception:
                tmp[str(start_num + ind)] = 1
            tmp_labels.append(start_num + ind)
        ground_truths[str(i)] = tmp
        start_num += data_num[i]
        labels.append(tmp_labels)

    # Build the run: per query, the reranker score of every pair, plus the pair
    # indices ranked by descending score (used for MRR).
    start_num = 0
    rerank_results = {}
    predicts = []
    for i in range(len(data)):
        tmp = {}
        tmp_predicts = [(start_num + ind, scores[start_num + ind]) for ind in range(data_num[i])]
        tmp_predicts = [idx for (idx, _) in sorted(tmp_predicts, key=lambda x: x[1], reverse=True)]
        for ind in range(data_num[i]):
            tmp[str(start_num + ind)] = float(scores[start_num + ind])
        rerank_results[str(i)] = tmp
        start_num += data_num[i]
        predicts.append(tmp_predicts)
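    # At this point the two dicts follow pytrec_eval's qrel/run layout; the ids and
    # scores below are hypothetical, shown only to illustrate the shape:
    #   ground_truths  = {"0": {"0": 1, "1": 2}, ...}        # query id -> {pair index -> integer relevance}
    #   rerank_results = {"0": {"0": 3.17, "1": -0.42}, ...}  # query id -> {pair index -> reranker score}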

    # Accumulate pytrec_eval metrics across queries.
    ndcg = {}
    _map = {}
    recall = {}
    precision = {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0
        precision[f"Precision@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])
    precision_string = "P." + ",".join([str(k) for k in k_values])
    evaluator = pytrec_eval.RelevanceEvaluator(
        ground_truths, {map_string, ndcg_string, recall_string, precision_string})

    eval_res = evaluator.evaluate(rerank_results)

    for query_id in eval_res.keys():
        for k in k_values:
            ndcg[f"NDCG@{k}"] += eval_res[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += eval_res[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += eval_res[query_id]["recall_" + str(k)]
            precision[f"Precision@{k}"] += eval_res[query_id]["P_" + str(k)]

    # Average over queries and round.
    for k in k_values:
        ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(eval_res), 5)
        _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(eval_res), 5)
        recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(eval_res), 5)
        precision[f"Precision@{k}"] = round(precision[f"Precision@{k}"] / len(eval_res), 5)

    mrr = evaluate_mrr(predicts, labels, k_values)

    # Print only the metrics the user asked for.
    if 'mrr' in metrics:
        print(mrr)
    if 'recall' in metrics:
        print(recall)
    if 'ndcg' in metrics:
        print(ndcg)
    if 'map' in metrics:
        print(_map)
    if 'precision' in metrics:
        print(precision)


if __name__ == "__main__":
    main()
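# Example invocation (the filename `eval_reranker.py` and the data path are only
# assumptions; adjust them to wherever this file and your data live). `metrics` and
# `k_values` take space-separated values, as generated by HfArgumentParser from the
# List fields of `Args` above:
#
#   python eval_reranker.py \
#       --input_path eval_data.jsonl \
#       --metrics recall mrr ndcg \
#       --k_values 1 5 10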