embed-bge-m3/FlagEmbedding/research/BGE_VL/eval/eval_Circo.py

225 lines
6.9 KiB
Python

import os
import faiss
import torch
import logging
import datasets
import numpy as np
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from flag_mmret import Flag_mmret
import json
# Module-level logger; index() reports embedding saving and index building through it.
logger = logging.getLogger(name=__name__)
@dataclass
class Args:
    """Command-line options for the CIRCO retrieval evaluation.

    Parsed by ``HfArgumentParser`` in ``main``; each field's help text is
    taken from ``metadata['help']``.
    """

    model_name: str = field(
        default="BAAI/BGE-VL-large",
        metadata={'help': 'Model Name'}
    )
    result_save_path: str = field(
        default="./eval/mmret_large_circo.json",
        metadata={'help': 'Where to save the results.'}
    )
    # Placeholder default: must be overridden with the real image directory.
    image_dir: str = field(
        default="YOUR_COCO_IMAGE_DIRECTORY",
        metadata={'help': 'Directory where the COCO images are located.'}
    )
    fp16: bool = field(
        default=False,
        metadata={'help': 'Use fp16 in inference?'}
    )
    max_query_length: int = field(
        default=64,
        metadata={'help': 'Max query length.'}
    )
    max_passage_length: int = field(
        default=77,
        metadata={'help': 'Max passage length.'}
    )
    batch_size: int = field(
        default=256,
        metadata={'help': 'Inference batch size.'}
    )
    index_factory: str = field(
        default="Flat",
        metadata={'help': 'Faiss index factory.'}
    )
    k: int = field(
        default=100,
        metadata={'help': 'How many neighbors to retrieve?'}
    )
    save_embedding: bool = field(
        default=False,
        metadata={'help': 'Save embeddings in memmap at save_path?'}
    )
    load_embedding: bool = field(
        default=False,
        metadata={'help': 'Load embeddings from save_path?'}
    )
    # Memmap file shared by save_embedding and load_embedding.
    save_path: str = field(
        default="embeddings.memmap",
        metadata={'help': 'Path to save embeddings.'}
    )
def _load_embeddings(model, save_path):
    """Memory-map embeddings previously written to ``save_path`` (read-only)."""
    # The raw memmap file stores neither dtype nor dimension, so probe the
    # encoder once to recover both.
    # NOTE(review): assumes model.encode("test") returns a 1-D vector with the
    # same dtype/width as the image embeddings — confirm against Flag_mmret.
    probe = model.encode("test")
    return np.memmap(save_path, mode="r", dtype=probe.dtype).reshape(-1, len(probe))


def _save_embeddings(embeddings, save_path, chunk_size=10000):
    """Write ``embeddings`` to a new memmap file at ``save_path`` in chunks."""
    logger.info("saving embeddings at %s...", save_path)
    memmap = np.memmap(save_path, shape=embeddings.shape, mode="w+", dtype=embeddings.dtype)
    length = embeddings.shape[0]
    # Copy in chunks; the progress bar only appears when more than one chunk is needed.
    for start in tqdm(range(0, length, chunk_size), leave=False,
                      desc="Saving Embeddings", disable=length <= chunk_size):
        end = min(start + chunk_size, length)
        memmap[start:end] = embeddings[start:end]
    # Flush explicitly instead of relying on garbage collection to write the data.
    memmap.flush()
    del memmap


def index(model: Flag_mmret, corpus: datasets.Dataset, batch_size: int = 256, max_length: int = 512,
          index_factory: str = "Flat", save_path: Optional[str] = None,
          save_embedding: bool = False, load_embedding: bool = False):
    """Encode (or load) the image corpus and build a faiss inner-product index.

    1. Encode the entire corpus with ``model.encode_corpus(..., corpus_type='image')``,
       or, when ``load_embedding`` is set, read embeddings from the memmap at ``save_path``;
    2. Optionally save freshly encoded embeddings to ``save_path``;
    3. Create a faiss index (``METRIC_INNER_PRODUCT``) and add the embeddings.

    Args:
        model: encoder providing ``encode``, ``encode_corpus`` and ``device``.
        corpus: image corpus forwarded to ``model.encode_corpus``.
        batch_size: encoding batch size.
        max_length: max length forwarded to ``encode_corpus``.
        index_factory: faiss index factory string, e.g. ``"Flat"``.
        save_path: memmap file used for saving/loading embeddings.
        save_embedding: write encoded embeddings to ``save_path``
            (has no effect when ``load_embedding`` is set).
        load_embedding: read embeddings from ``save_path`` instead of encoding.

    Returns:
        The populated faiss index, cloned to all GPUs when
        ``model.device == torch.device("cuda")``.

    Raises:
        ValueError: if saving or loading is requested without a ``save_path``
            (checked up front so no encoding time is wasted).
    """
    if (load_embedding or save_embedding) and save_path is None:
        raise ValueError("save_path is required when save_embedding or load_embedding is set")

    if load_embedding:
        corpus_embeddings = _load_embeddings(model, save_path)
    else:
        corpus_embeddings = model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length, corpus_type='image')
        if save_embedding:
            _save_embeddings(corpus_embeddings, save_path)
    dim = corpus_embeddings.shape[-1]

    # create faiss index
    faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)
    # NOTE(review): torch.device("cuda:0") != torch.device("cuda") (the index
    # differs), so a model reporting "cuda:0" keeps the index on CPU. Confirm
    # what Flag_mmret.device returns before relying on GPU search.
    if model.device == torch.device("cuda"):
        co = faiss.GpuMultipleClonerOptions()
        co.useFloat16 = True  # faiss option: store the GPU copy in float16
        faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)

    # NOTE: faiss only accepts float32 (this also loads a memmap into memory).
    logger.info("Adding embeddings...")
    corpus_embeddings = corpus_embeddings.astype(np.float32)
    faiss_index.train(corpus_embeddings)  # no-op for "Flat"; needed for trainable factories
    faiss_index.add(corpus_embeddings)
    return faiss_index
def search(model: Flag_mmret, queries: datasets.Dataset, faiss_index: faiss.Index, k: int = 100,
           batch_size: int = 256, max_length: int = 512):
    """Encode composed text+image queries and search them against ``faiss_index``.

    1. Encode queries with ``model.encode_queries`` on the ``q_text`` and
       ``q_img`` columns (``query_type='mm_it'``);
    2. Search through the faiss index in batches of ``batch_size``.

    Args:
        model: encoder providing ``encode_queries``.
        queries: dataset with ``q_text`` and ``q_img`` columns.
        faiss_index: index built over the corpus embeddings.
        k: number of neighbors retrieved per query.
        batch_size: batch size for both encoding and searching.
        max_length: max length forwarded to ``encode_queries``.

    Returns:
        ``(scores, indices)``, each with one row of ``k`` entries per query.
        faiss pads ``indices`` with ``-1`` when fewer than ``k`` hits exist.
        An empty query set yields two empty ``(0, k)`` arrays.
    """
    query_embeddings = model.encode_queries([queries["q_text"], queries["q_img"]],
                                            batch_size=batch_size,
                                            max_length=max_length,
                                            query_type='mm_it')
    # faiss only accepts contiguous float32; convert once rather than per batch.
    query_embeddings = np.ascontiguousarray(query_embeddings, dtype=np.float32)
    query_size = len(query_embeddings)
    if query_size == 0:
        # np.concatenate on an empty list would raise; return empty results instead.
        return np.empty((0, k), dtype=np.float32), np.empty((0, k), dtype=np.int64)

    all_scores = []
    all_indices = []
    for start in tqdm(range(0, query_size, batch_size), desc="Searching"):
        end = min(start + batch_size, query_size)
        score, indice = faiss_index.search(query_embeddings[start:end], k=k)
        all_scores.append(score)
        all_indices.append(indice)
    return np.concatenate(all_scores, axis=0), np.concatenate(all_indices, axis=0)
def main():
    """Run the CIRCO evaluation and dump each query's top retrieved items to JSON.

    Reads the query and corpus JSONL files under ``./eval/data``, builds the
    corpus index, and retrieves ``args.k`` neighbors per query. It then writes
    ``{str(query id): [corpus "content" of the top 50 hits]}`` to
    ``args.result_save_path``, creating the parent directory if needed.
    """
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]
    print(f"Results will be saved in {args.result_save_path}")

    eval_data = datasets.load_dataset('json', data_files="./eval/data/circo_query.jsonl", split='train')
    image_corpus_test = datasets.load_dataset('json', data_files="./eval/data/circo_corpus.jsonl", split='train')

    model = Flag_mmret(model_name=args.model_name,
                       normlized=True,  # spelling must match Flag_mmret's keyword
                       image_dir=args.image_dir,
                       use_fp16=args.fp16,  # honor --fp16 instead of hard-coding False
                       )

    faiss_index = index(
        model=model,
        corpus=image_corpus_test,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        queries=eval_data,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results = []
    for indice in indices:
        # filter invalid (-1) indices that faiss uses as padding
        valid = indice[indice != -1].tolist()
        retrieval_results.append(image_corpus_test[valid]["content"])

    ########## results in test corpus #########
    top_n = 50  # keep only the top 50 retrieved items per query
    results = {
        str(pair_id): retrieved[:top_n]
        for pair_id, retrieved in zip(eval_data["id"], retrieval_results)
    }

    out_dir = os.path.dirname(args.result_save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.result_save_path, "w", encoding="utf-8") as f:
        json.dump(results, f)


if __name__ == "__main__":
    main()