import os
import json
import logging
from dataclasses import dataclass, field
from typing import Optional

import faiss
import torch
import datasets
import numpy as np
from tqdm import tqdm
from transformers import HfArgumentParser

from flag_mmret import Flag_mmret

logger = logging.getLogger(__name__)


@dataclass
class Args:
    model_name: str = field(
        default="BAAI/BGE-VL-large",
        metadata={'help': 'Model name or path.'}
    )
    result_save_path: str = field(
        default="./eval/mmret_large_circo.json",
        metadata={'help': 'Where to save the retrieval results.'}
    )
    image_dir: str = field(
        default="YOUR_COCO_IMAGE_DIRECTORY",
        metadata={'help': 'Directory where the images are located.'}
    )
    fp16: bool = field(
        default=False,
        metadata={'help': 'Use fp16 for inference?'}
    )
    max_query_length: int = field(
        default=64,
        metadata={'help': 'Max query length.'}
    )
    max_passage_length: int = field(
        default=77,
        metadata={'help': 'Max passage length.'}
    )
    batch_size: int = field(
        default=256,
        metadata={'help': 'Inference batch size.'}
    )
    index_factory: str = field(
        default="Flat",
        metadata={'help': 'Faiss index factory string.'}
    )
    k: int = field(
        default=100,
        metadata={'help': 'How many neighbors to retrieve?'}
    )
    save_embedding: bool = field(
        default=False,
        metadata={'help': 'Save corpus embeddings as a memmap at save_path?'}
    )
    load_embedding: bool = field(
        default=False,
        metadata={'help': 'Load corpus embeddings from save_path?'}
    )
    save_path: str = field(
        default="embeddings.memmap",
        metadata={'help': 'Path for saving/loading the corpus embeddings.'}
    )


def index(
    model: Flag_mmret,
    corpus: datasets.Dataset,
    batch_size: int = 256,
    max_length: int = 512,
    index_factory: str = "Flat",
    save_path: Optional[str] = None,
    save_embedding: bool = False,
    load_embedding: bool = False,
):
    """
    1. Encode the entire corpus into dense embeddings;
    2. Create the faiss index;
    3. Optionally save the embeddings to (or load them from) a memmap at save_path.
    """
    if load_embedding:
        # Encode a dummy input only to recover the embedding dtype and dimension,
        # then memory-map the previously saved corpus embeddings.
        test = model.encode("test")
        dtype = test.dtype
        dim = len(test)

        corpus_embeddings = np.memmap(
            save_path,
            mode="r",
            dtype=dtype
        ).reshape(-1, dim)

    else:
        corpus_embeddings = model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length, corpus_type='image')

        dim = corpus_embeddings.shape[-1]

        if save_embedding:
            logger.info(f"saving embeddings at {save_path}...")
            memmap = np.memmap(
                save_path,
                shape=corpus_embeddings.shape,
                mode="w+",
                dtype=corpus_embeddings.dtype
            )

            length = corpus_embeddings.shape[0]
            # write to the memmap in batches to keep peak memory low
            save_batch_size = 10000
            if length > save_batch_size:
                for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                    j = min(i + save_batch_size, length)
                    memmap[i: j] = corpus_embeddings[i: j]
            else:
                memmap[:] = corpus_embeddings

    # create the faiss index; with normalized embeddings, inner product equals cosine similarity
    faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)

    if model.device == torch.device("cuda"):
        # co = faiss.GpuClonerOptions()
        co = faiss.GpuMultipleClonerOptions()
        co.useFloat16 = True
        # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
        faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)

    # NOTE: faiss only accepts float32
    logger.info("Adding embeddings...")
    corpus_embeddings = corpus_embeddings.astype(np.float32)
    faiss_index.train(corpus_embeddings)
    faiss_index.add(corpus_embeddings)

    return faiss_index

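
# Note on the embedding cache: `--save_embedding` writes the raw corpus embeddings to
# save_path as a bare memmap (no dtype/shape metadata is stored alongside it), so a
# later run with `--load_embedding` must use the same model, because the dtype and
# dimension are re-derived from the dummy model.encode("test") call above.

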
def search(model: Flag_mmret, queries: datasets.Dataset, faiss_index: faiss.Index, k: int = 100, batch_size: int = 256, max_length: int = 512):
    """
    1. Encode queries into dense embeddings;
    2. Search through the faiss index.
    """
    # Each CIRCO query is composed: a relative caption ("q_text") plus a reference image ("q_img").
    query_embeddings = model.encode_queries([queries["q_text"], queries["q_img"]],
                                            batch_size=batch_size,
                                            max_length=max_length,
                                            query_type='mm_it')

    query_size = len(query_embeddings)

    all_scores = []
    all_indices = []

    for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
        j = min(i + batch_size, query_size)
        query_embedding = query_embeddings[i: j]
        score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)
    return all_scores, all_indices

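
# search() returns (scores, indices), each of shape (num_queries, k); faiss fills
# missing slots with index -1, which main() filters out before mapping indices back
# to corpus entries.

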
def main():
    # basic logging setup so the logger.info calls above are visible
    logging.basicConfig(level=logging.INFO)

    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    print(f"Results will be saved in {args.result_save_path}")
    eval_data = datasets.load_dataset('json', data_files="./eval/data/circo_query.jsonl", split='train')
    image_corpus_test = datasets.load_dataset('json', data_files="./eval/data/circo_corpus.jsonl", split='train')

    model = Flag_mmret(model_name=args.model_name,
                       normlized=True,
                       image_dir=args.image_dir,
                       use_fp16=args.fp16,
                       )

    faiss_index = index(
        model=model,
        corpus=image_corpus_test,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        queries=eval_data,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results = []
    for indice in indices:
        # filter invalid indices
        indice = indice[indice != -1].tolist()
        retrieval_results.append(image_corpus_test[indice]["content"])

    ########## results in test corpus #########
    q_images = eval_data["q_img"]

    # query ids derived from the query image filenames (not used when writing results)
    q_ids = []
    for _img in q_images:
        _id = os.path.basename(_img)
        _id = os.path.splitext(_id)[0]
        q_ids.append(_id)

    # map each query pair id to its top-50 retrieved images
    pairids = eval_data["id"]
    results = {}
    for pairid, re_results in zip(pairids, retrieval_results):
        results[str(pairid)] = re_results[0:50]

    with open(args.result_save_path, "w") as f:
        json.dump(results, f)


if __name__ == "__main__":
    main()
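

# Example invocation (a sketch; the script filename and local paths are assumptions,
# not taken from the repository):
#
#   python eval_circo.py \
#       --model_name BAAI/BGE-VL-large \
#       --image_dir /path/to/coco/images \
#       --result_save_path ./eval/mmret_large_circo.json \
#       --batch_size 256 \
#       --k 100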