import faiss
import torch
import logging
import datasets
import numpy as np
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from FlagEmbedding import FlagModel

logger = logging.getLogger(__name__)


@dataclass
class Args:
    encoder: str = field(
        default="BAAI/bge-base-en-v1.5",
        metadata={'help': 'The encoder name or path.'}
    )
    fp16: bool = field(
        default=False,
        metadata={'help': 'Use fp16 in inference?'}
    )
    add_instruction: bool = field(
        default=False,
        metadata={'help': 'Add query-side instruction?'}
    )

    corpus_data: str = field(
        default="namespace-Pt/msmarco-corpus",
        metadata={'help': 'Candidate passages.'}
    )
    query_data: str = field(
        default="namespace-Pt/msmarco",
        metadata={'help': 'Queries and their positive passages for evaluation.'}
    )

    max_query_length: int = field(
        default=32,
        metadata={'help': 'Max query length.'}
    )
    max_passage_length: int = field(
        default=128,
        metadata={'help': 'Max passage length.'}
    )
    batch_size: int = field(
        default=256,
        metadata={'help': 'Inference batch size.'}
    )
    index_factory: str = field(
        default="Flat",
        metadata={'help': 'Faiss index factory string.'}
    )
    k: int = field(
        default=100,
        metadata={'help': 'How many neighbors to retrieve?'}
    )

    save_embedding: bool = field(
        default=False,
        metadata={'help': 'Save embeddings as a memmap at save_path?'}
    )
    load_embedding: bool = field(
        default=False,
        metadata={'help': 'Load embeddings from save_path?'}
    )
    save_path: str = field(
        default="embeddings.memmap",
        metadata={'help': 'Path to save/load embeddings.'}
    )

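
# Example invocation (illustrative; "eval_msmarco.py" is a placeholder for
# whatever this script is saved as). With the defaults above it evaluates
# BAAI/bge-base-en-v1.5 on the MS MARCO dev queries:
#
#   python eval_msmarco.py --encoder BAAI/bge-base-en-v1.5 --fp16 --add_instruction
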
def index(
    model: FlagModel,
    corpus: datasets.Dataset,
    batch_size: int = 256,
    max_length: int = 512,
    index_factory: str = "Flat",
    save_path: Optional[str] = None,
    save_embedding: bool = False,
    load_embedding: bool = False,
):
    """
    1. Encode the entire corpus into dense embeddings;
    2. Create faiss index;
    3. Optionally save embeddings.
    """
    if load_embedding:
        # Probe the encoder once to recover the dtype and dimension of the
        # saved embeddings, since np.memmap stores neither.
        test = model.encode("test")
        dtype = test.dtype
        dim = len(test)

        corpus_embeddings = np.memmap(
            save_path,
            mode="r",
            dtype=dtype
        ).reshape(-1, dim)
    else:
        corpus_embeddings = model.encode_corpus(corpus["content"], batch_size=batch_size, max_length=max_length)
        dim = corpus_embeddings.shape[-1]

        if save_embedding:
            logger.info(f"saving embeddings at {save_path}...")
            memmap = np.memmap(
                save_path,
                shape=corpus_embeddings.shape,
                mode="w+",
                dtype=corpus_embeddings.dtype
            )

            length = corpus_embeddings.shape[0]
            # write in chunks to keep peak memory bounded
            save_batch_size = 10000
            if length > save_batch_size:
                for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                    j = min(i + save_batch_size, length)
                    memmap[i: j] = corpus_embeddings[i: j]
            else:
                memmap[:] = corpus_embeddings

    # create faiss index
    faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)

    if model.device == torch.device("cuda"):
        # co = faiss.GpuClonerOptions()
        co = faiss.GpuMultipleClonerOptions()
        co.useFloat16 = True
        # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
        faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)

    # NOTE: faiss only accepts float32
    logger.info("Adding embeddings...")
    corpus_embeddings = corpus_embeddings.astype(np.float32)
    faiss_index.train(corpus_embeddings)
    faiss_index.add(corpus_embeddings)
    return faiss_index

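
# Note: with the default "Flat" factory the train() call above is a no-op and
# search is exact. For large corpora an approximate factory string such as
# "IVF4096,Flat" (a scale assumption, not a tuned recommendation) can be passed
# via --index_factory; the same train()/add() path handles it.

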
def search(model: FlagModel, queries: datasets.Dataset, faiss_index: faiss.Index, k: int = 100, batch_size: int = 256, max_length: int = 512):
    """
    1. Encode queries into dense embeddings;
    2. Search through faiss index.
    """
    query_embeddings = model.encode_queries(queries["query"], batch_size=batch_size, max_length=max_length)
    query_size = len(query_embeddings)

    all_scores = []
    all_indices = []

    for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
        j = min(i + batch_size, query_size)
        query_embedding = query_embeddings[i: j]
        # faiss only accepts float32 queries
        score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)
    return all_scores, all_indices

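
# The scores returned above are inner products; BGE embeddings are
# L2-normalized by FlagModel by default, so they double as cosine
# similarities, and higher means more relevant.

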
def evaluate(preds,
             preds_scores,
             labels,
             cutoffs=[1, 10, 100]):
    """
    Evaluate MRR, Recall, AUC and nDCG at cutoffs.
    """
    metrics = {}

    # MRR: reciprocal rank of the first relevant passage, per cutoff
    mrrs = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        jump = False
        for i, x in enumerate(pred, 1):
            if x in label:
                for k, cutoff in enumerate(cutoffs):
                    if i <= cutoff:
                        mrrs[k] += 1 / i
                jump = True
            if jump:
                break
    mrrs /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        mrr = mrrs[i]
        metrics[f"MRR@{cutoff}"] = mrr

    # Recall: fraction of relevant passages retrieved in the top-cutoff
    # results, with the denominator capped at the cutoff
    recalls = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        for k, cutoff in enumerate(cutoffs):
            recall = np.intersect1d(label, pred[:cutoff])
            recalls[k] += len(recall) / max(min(cutoff, len(label)), 1)
    recalls /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        recall = recalls[i]
        metrics[f"Recall@{cutoff}"] = recall

    # AUC: binarize the retrieved lists (1 if the passage is relevant) and
    # score the labels against the retrieval scores
    pred_hard_encodings = []
    for pred, label in zip(preds, labels):
        pred_hard_encoding = np.isin(pred, label).astype(int).tolist()
        pred_hard_encodings.append(pred_hard_encoding)

    from sklearn.metrics import roc_auc_score, ndcg_score
    pred_hard_encodings1d = np.asarray(pred_hard_encodings).flatten()
    preds_scores1d = preds_scores.flatten()
    auc = roc_auc_score(pred_hard_encodings1d, preds_scores1d)

    metrics['AUC@100'] = auc

    # nDCG
    for cutoff in cutoffs:
        nDCG = ndcg_score(pred_hard_encodings, preds_scores, k=cutoff)
        metrics[f"nDCG@{cutoff}"] = nDCG

    return metrics

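
# Worked example (toy numbers, for illustration only): for a single query
# whose one positive passage appears at rank 2, cutoffs=[1, 10, 100] yield
# MRR@1 = 0, MRR@10 = 1/2 = 0.5, Recall@1 = 0, and Recall@10 = 1.0, since the
# recall denominator is min(cutoff, number of positives).

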
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    if args.query_data == 'namespace-Pt/msmarco':
        assert args.corpus_data == 'namespace-Pt/msmarco-corpus'
        eval_data = datasets.load_dataset("namespace-Pt/msmarco", split="dev")
        corpus = datasets.load_dataset("namespace-Pt/msmarco-corpus", split="train")
    else:
        eval_data = datasets.load_dataset('json', data_files=args.query_data, split='train')
        corpus = datasets.load_dataset('json', data_files=args.corpus_data, split='train')

    model = FlagModel(
        args.encoder,
        query_instruction_for_retrieval="Represent this sentence for searching relevant passages: " if args.add_instruction else None,
        use_fp16=args.fp16
    )

    faiss_index = index(
        model=model,
        corpus=corpus,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        queries=eval_data,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results = []
    for indice in indices:
        # filter invalid indices (faiss pads with -1 when fewer than k hits)
        indice = indice[indice != -1].tolist()
        retrieval_results.append(corpus[indice]["content"])

    ground_truths = []
    for sample in eval_data:
        ground_truths.append(sample["positive"])

    metrics = evaluate(retrieval_results, scores, ground_truths)
    print(metrics)


if __name__ == "__main__":
    main()
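

# Custom data schema (inferred from the field accesses above): the corpus file
# should be JSON/JSONL with a "content" string per passage, and the query file
# should carry a "query" string plus a "positive" list of relevant passage
# texts per row, e.g.
#
#   {"content": "Aspirin is a nonsteroidal anti-inflammatory drug..."}
#   {"query": "what is aspirin", "positive": ["Aspirin is a nonsteroidal..."]}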