# embed-bge-m3/FlagEmbedding/research/baai_general_embedding/finetune/eval_msmarco.py

import faiss
import torch
import logging
import datasets
import numpy as np
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from FlagEmbedding import FlagModel

logger = logging.getLogger(__name__)

@dataclass
class Args:
    encoder: str = field(
        default="BAAI/bge-base-en-v1.5",
        metadata={'help': 'The encoder name or path.'}
    )
    fp16: bool = field(
        default=False,
        metadata={'help': 'Use fp16 in inference?'}
    )
    add_instruction: bool = field(
        default=False,
        metadata={'help': 'Add query-side instruction?'}
    )
    corpus_data: str = field(
        default="namespace-Pt/msmarco-corpus",
        metadata={'help': 'Candidate passages.'}
    )
    query_data: str = field(
        default="namespace-Pt/msmarco",
        metadata={'help': 'Queries and their positive passages for evaluation.'}
    )
    max_query_length: int = field(
        default=32,
        metadata={'help': 'Max query length.'}
    )
    max_passage_length: int = field(
        default=128,
        metadata={'help': 'Max passage length.'}
    )
    batch_size: int = field(
        default=256,
        metadata={'help': 'Inference batch size.'}
    )
    index_factory: str = field(
        default="Flat",
        metadata={'help': 'Faiss index factory.'}
    )
    k: int = field(
        default=100,
        metadata={'help': 'How many neighbors to retrieve?'}
    )
    save_embedding: bool = field(
        default=False,
        metadata={'help': 'Save embeddings in memmap at save_path?'}
    )
    load_embedding: bool = field(
        default=False,
        metadata={'help': 'Load embeddings from save_path?'}
    )
    save_path: str = field(
        default="embeddings.memmap",
        metadata={'help': 'Path to save embeddings.'}
    )
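
# Example invocation (illustrative; the file name and flag values are
# assumptions, not from the original repo). HfArgumentParser exposes each
# Args field above as a command-line flag of the same name:
#
#   python eval_msmarco.py \
#       --encoder BAAI/bge-base-en-v1.5 \
#       --fp16 \
#       --k 100 \
#       --save_embedding \
#       --save_path embeddings.memmap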


def index(
    model: FlagModel,
    corpus: datasets.Dataset,
    batch_size: int = 256,
    max_length: int = 512,
    index_factory: str = "Flat",
    save_path: Optional[str] = None,
    save_embedding: bool = False,
    load_embedding: bool = False
):
    """
    1. Encode the entire corpus into dense embeddings;
    2. Create the faiss index;
    3. Optionally save the embeddings.
    """
    if load_embedding:
        # probe the embedding dtype and dimension with a dummy input,
        # then map the saved corpus embeddings directly from disk
        test = model.encode("test")
        dtype = test.dtype
        dim = len(test)

        corpus_embeddings = np.memmap(
            save_path,
            mode="r",
            dtype=dtype
        ).reshape(-1, dim)
    else:
        corpus_embeddings = model.encode_corpus(corpus["content"], batch_size=batch_size, max_length=max_length)
        dim = corpus_embeddings.shape[-1]

        if save_embedding:
            logger.info(f"saving embeddings at {save_path}...")
            memmap = np.memmap(
                save_path,
                shape=corpus_embeddings.shape,
                mode="w+",
                dtype=corpus_embeddings.dtype
            )

            # write to the memmap in batches to bound peak memory
            length = corpus_embeddings.shape[0]
            save_batch_size = 10000
            if length > save_batch_size:
                for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                    j = min(i + save_batch_size, length)
                    memmap[i: j] = corpus_embeddings[i: j]
            else:
                memmap[:] = corpus_embeddings

    # create the faiss index
    faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)

    if model.device == torch.device("cuda"):
        # co = faiss.GpuClonerOptions()
        co = faiss.GpuMultipleClonerOptions()
        co.useFloat16 = True
        # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
        faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)

    # NOTE: faiss only accepts float32
    logger.info("Adding embeddings...")
    corpus_embeddings = corpus_embeddings.astype(np.float32)
    faiss_index.train(corpus_embeddings)
    faiss_index.add(corpus_embeddings)
    return faiss_index
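
# Illustrative note, not part of the original script: "Flat" performs exact
# inner-product search, so train() is a no-op. An approximate index only
# needs a different factory string; a hypothetical IVF setup (the parameter
# values below are assumptions) would look like:
#
#   faiss_index = faiss.index_factory(dim, "IVF4096,Flat", faiss.METRIC_INNER_PRODUCT)
#   faiss_index.train(corpus_embeddings)  # IVF training fits the coarse quantizer
#   faiss_index.nprobe = 32               # probes per query: accuracy/speed trade-off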


def search(
    model: FlagModel,
    queries: datasets.Dataset,
    faiss_index: faiss.Index,
    k: int = 100,
    batch_size: int = 256,
    max_length: int = 512
):
    """
    1. Encode queries into dense embeddings;
    2. Search through the faiss index.
    """
    query_embeddings = model.encode_queries(queries["query"], batch_size=batch_size, max_length=max_length)
    query_size = len(query_embeddings)

    all_scores = []
    all_indices = []

    # search in batches; faiss expects float32 inputs
    for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
        j = min(i + batch_size, query_size)
        query_embedding = query_embeddings[i: j]
        score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)
    return all_scores, all_indices
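
# Shape note (added for clarity): all_scores and all_indices are both
# (num_queries, k) arrays. Entries of all_indices are row positions in the
# corpus; faiss pads with -1 when an index returns fewer than k results.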


def evaluate(preds, preds_scores, labels, cutoffs=[1, 10, 100]):
    """
    Evaluate MRR, Recall, AUC and nDCG at the given cutoffs.
    """
    metrics = {}

    # MRR: reciprocal rank of the first relevant passage,
    # credited at every cutoff the rank falls within
    mrrs = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        jump = False
        for i, x in enumerate(pred, 1):
            if x in label:
                for k, cutoff in enumerate(cutoffs):
                    if i <= cutoff:
                        mrrs[k] += 1 / i
                jump = True
            if jump:
                break
    mrrs /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        mrr = mrrs[i]
        metrics[f"MRR@{cutoff}"] = mrr

    # Recall: fraction of positives retrieved within each cutoff,
    # normalized by min(cutoff, number of positives)
    recalls = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        for k, cutoff in enumerate(cutoffs):
            hits = np.intersect1d(label, pred[:cutoff])
            recalls[k] += len(hits) / max(min(cutoff, len(label)), 1)
    recalls /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        recall = recalls[i]
        metrics[f"Recall@{cutoff}"] = recall

    # AUC and nDCG: binarize each retrieved list (1 if the passage is a
    # positive, 0 otherwise) and score it against the retrieval scores
    pred_hard_encodings = []
    for pred, label in zip(preds, labels):
        pred_hard_encoding = np.isin(pred, label).astype(int).tolist()
        pred_hard_encodings.append(pred_hard_encoding)

    from sklearn.metrics import roc_auc_score, ndcg_score

    pred_hard_encodings1d = np.asarray(pred_hard_encodings).flatten()
    preds_scores1d = preds_scores.flatten()
    auc = roc_auc_score(pred_hard_encodings1d, preds_scores1d)
    metrics['AUC@100'] = auc

    for k, cutoff in enumerate(cutoffs):
        nDCG = ndcg_score(pred_hard_encodings, preds_scores, k=cutoff)
        metrics[f"nDCG@{cutoff}"] = nDCG

    return metrics
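

def _demo_evaluate():
    """A tiny worked example of evaluate(), added for illustration.

    Not part of the original script; all passage ids and scores below are
    made up. With cutoffs=[3]: the relevant passage sits at rank 1 for the
    first query and rank 2 for the second, so MRR@3 = (1/1 + 1/2) / 2 = 0.75
    and Recall@3 = (1 + 1) / 2 = 1.0.
    """
    preds = [
        ["p1", "p9", "p3"],  # relevant passage retrieved at rank 1
        ["p7", "p2", "p8"],  # relevant passage retrieved at rank 2
    ]
    preds_scores = np.array([
        [0.9, 0.5, 0.1],
        [0.8, 0.6, 0.2],
    ])
    labels = [["p1"], ["p2"]]
    return evaluate(preds, preds_scores, labels, cutoffs=[3])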


def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    if args.query_data == 'namespace-Pt/msmarco':
        assert args.corpus_data == 'namespace-Pt/msmarco-corpus'
        eval_data = datasets.load_dataset("namespace-Pt/msmarco", split="dev")
        corpus = datasets.load_dataset("namespace-Pt/msmarco-corpus", split="train")
    else:
        eval_data = datasets.load_dataset('json', data_files=args.query_data, split='train')
        corpus = datasets.load_dataset('json', data_files=args.corpus_data, split='train')

    model = FlagModel(
        args.encoder,
        query_instruction_for_retrieval="Represent this sentence for searching relevant passages: " if args.add_instruction else None,
        use_fp16=args.fp16
    )

    faiss_index = index(
        model=model,
        corpus=corpus,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        queries=eval_data,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results = []
    for indice in indices:
        # filter invalid indices (faiss pads with -1 when fewer than k hits exist)
        indice = indice[indice != -1].tolist()
        retrieval_results.append(corpus[indice]["content"])

    ground_truths = []
    for sample in eval_data:
        ground_truths.append(sample["positive"])

    metrics = evaluate(retrieval_results, scores, ground_truths)
    print(metrics)


if __name__ == "__main__":
    main()