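"""Evaluate a Flag_mmret (BGE-VL) model on FashionIQ composed-image retrieval.

For each category (shirt, dress, toptee), the script builds a faiss index over
the image corpus, encodes the multimodal (image + text) queries, and reports
MRR and Recall at several cutoffs.
"""
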
import os

# restrict GPU visibility before torch / faiss are imported
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

print(os.getcwd())

import faiss
import torch
import logging
import datasets
import numpy as np
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from flag_mmret import Flag_mmret

logger = logging.getLogger(__name__)


@dataclass
class Args:
    model_name: str = field(
        default="BAAI/BGE-VL-large",
        metadata={'help': 'Model name or path.'}
    )
    image_dir: str = field(
        default="YOUR_FASHIONIQ_IMAGE_DIRECTORY",
        metadata={'help': 'Directory containing the FashionIQ images.'}
    )
    fp16: bool = field(
        default=False,
        metadata={'help': 'Use fp16 in inference?'}
    )
    max_query_length: int = field(
        default=64,
        metadata={'help': 'Max query length.'}
    )
    max_passage_length: int = field(
        default=77,
        metadata={'help': 'Max passage length.'}
    )
    batch_size: int = field(
        default=256,
        metadata={'help': 'Inference batch size.'}
    )
    index_factory: str = field(
        default="Flat",
        metadata={'help': 'Faiss index factory string.'}
    )
    k: int = field(
        default=100,
        metadata={'help': 'How many neighbors to retrieve?'}
    )
    save_embedding: bool = field(
        default=False,
        metadata={'help': 'Save embeddings as a memmap at save_path?'}
    )
    load_embedding: bool = field(
        default=False,
        metadata={'help': 'Load embeddings from save_path?'}
    )
    save_path: str = field(
        default="embeddings.memmap",
        metadata={'help': 'Path to save embeddings.'}
    )


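# A hypothetical invocation (the script filename and paths are placeholders);
# every flag maps onto a field of Args via HfArgumentParser:
#
#   python eval_fashioniq.py \
#       --model_name BAAI/BGE-VL-large \
#       --image_dir /data/fashioniq/images \
#       --batch_size 256 --k 100 \
#       --save_embedding True --save_path embeddings.memmap

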
def index(model: Flag_mmret, corpus: datasets.Dataset, batch_size: int = 256, max_length: int = 512,
          index_factory: str = "Flat", save_path: Optional[str] = None,
          save_embedding: bool = False, load_embedding: bool = False):
    """
    1. Encode the entire corpus into dense embeddings;
    2. Create the faiss index;
    3. Optionally save the embeddings.
    """
    if load_embedding:
        # encode a dummy input to recover the embedding dtype and dimension,
        # which are needed to reshape the flat memmap
        test = model.encode("test")
        dtype = test.dtype
        dim = len(test)

        corpus_embeddings = np.memmap(
            save_path,
            mode="r",
            dtype=dtype
        ).reshape(-1, dim)

    else:
        corpus_embeddings = model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length, corpus_type='image')
        dim = corpus_embeddings.shape[-1]

        if save_embedding:
            logger.info(f"saving embeddings at {save_path}...")
            memmap = np.memmap(
                save_path,
                shape=corpus_embeddings.shape,
                mode="w+",
                dtype=corpus_embeddings.dtype
            )

            length = corpus_embeddings.shape[0]
            # write in batches to bound peak memory
            save_batch_size = 10000
            if length > save_batch_size:
                for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                    j = min(i + save_batch_size, length)
                    memmap[i: j] = corpus_embeddings[i: j]
            else:
                memmap[:] = corpus_embeddings

    # create the faiss index; embeddings are L2-normalized, so inner product equals cosine similarity
    faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)

    if model.device == torch.device("cuda"):
        # shard the index across all visible GPUs
        # co = faiss.GpuClonerOptions()
        co = faiss.GpuMultipleClonerOptions()
        co.useFloat16 = True
        # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
        faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)

    # NOTE: faiss only accepts float32
    logger.info("Adding embeddings...")
    corpus_embeddings = corpus_embeddings.astype(np.float32)
    faiss_index.train(corpus_embeddings)
    faiss_index.add(corpus_embeddings)

    return faiss_index


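# A minimal sketch of swapping in an approximate index, assuming the corpus is
# too large for exact "Flat" search (the factory string below is only an example):
#
#   faiss_index = index(model, image_corpus, index_factory="IVF4096,Flat")
#
# Non-Flat indexes are fitted by the faiss_index.train() call above; for a
# plain "Flat" index, train() is a no-op.

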
def search(model: Flag_mmret, queries: datasets.Dataset, faiss_index: faiss.Index, k: int = 100, batch_size: int = 256, max_length: int = 512):
    """
    1. Encode queries into dense embeddings;
    2. Search through the faiss index.
    """
    # FashionIQ queries are multimodal: a reference image plus a modification text
    query_embeddings = model.encode_queries([queries["q_text"], queries["q_img"]],
                                            batch_size=batch_size,
                                            max_length=max_length,
                                            query_type='mm_it')

    query_size = len(query_embeddings)

    all_scores = []
    all_indices = []

    for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
        j = min(i + batch_size, query_size)
        query_embedding = query_embeddings[i: j]
        score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)
    return all_scores, all_indices


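# Sketch of consuming the search output, mirroring the logic in main() below;
# faiss pads result rows with -1 when fewer than k hits exist:
#
#   scores, indices = search(model, eval_data, faiss_index, k=5)
#   top_hits = [image_corpus[int(i)]["content"] for i in indices[0] if i != -1]

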
def evaluate(preds, labels, cutoffs=[1, 5, 10, 20, 50, 100]):
    """
    Evaluate MRR and Recall at the given cutoffs.
    """
    metrics = {}

    # MRR: reciprocal rank of the first relevant prediction, per cutoff
    mrrs = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        jump = False
        for i, x in enumerate(pred, 1):
            if x in label:
                for k, cutoff in enumerate(cutoffs):
                    if i <= cutoff:
                        mrrs[k] += 1 / i
                jump = True
            if jump:
                break
    mrrs /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        mrr = mrrs[i]
        metrics[f"MRR@{cutoff}"] = mrr

    # Recall: fraction of ground-truth labels retrieved within each cutoff
    recalls = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        if not isinstance(label, list):
            label = [label]
        for k, cutoff in enumerate(cutoffs):
            recall = np.intersect1d(label, pred[:cutoff])
            recalls[k] += len(recall) / len(label)
    recalls /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        recall = recalls[i]
        metrics[f"Recall@{cutoff}"] = recall

    return metrics


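# A toy sanity check for evaluate(), with hand-verified values:
#   evaluate(preds=[["a", "b"]], labels=[["b"]], cutoffs=[1, 2])
#   -> {'MRR@1': 0.0, 'MRR@2': 0.5, 'Recall@1': 0.0, 'Recall@2': 1.0}

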
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    model = Flag_mmret(model_name=args.model_name,
                       normlized=True,  # keyword spelled as in flag_mmret
                       image_dir=args.image_dir,
                       use_fp16=args.fp16,
                       )

    # ---- FashionIQ: shirt ----
    # query rows provide "q_text", "q_img" and "positive_key"; corpus rows provide "content"
    eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_shirt_query_val.jsonl", split='train')
    image_corpus = datasets.load_dataset('json', data_files="./eval/data/fashioniq_shirt_corpus.jsonl", split='train')

    faiss_index = index(
        model=model,
        corpus=image_corpus,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        queries=eval_data,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results = []
    for indice in indices:
        # filter invalid indices (faiss pads with -1 when fewer than k hits exist)
        indice = indice[indice != -1].tolist()
        retrieval_results.append(image_corpus[indice]["content"])

    ground_truths = []
    for sample in eval_data:
        ground_truths.append(sample["positive_key"])

    metrics_shirt = evaluate(retrieval_results, ground_truths)
    print("FashionIQ tasks (shirt):")
    print(metrics_shirt)

    # ---- FashionIQ: dress ----
    eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_dress_query_val.jsonl", split='train')
    image_corpus = datasets.load_dataset('json', data_files="./eval/data/fashioniq_dress_corpus.jsonl", split='train')

    faiss_index = index(
        model=model,
        corpus=image_corpus,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        queries=eval_data,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results = []
    for indice in indices:
        # filter invalid indices
        indice = indice[indice != -1].tolist()
        retrieval_results.append(image_corpus[indice]["content"])

    ground_truths = []
    for sample in eval_data:
        ground_truths.append(sample["positive_key"])

    metrics_dress = evaluate(retrieval_results, ground_truths)
    print("FashionIQ tasks (dress):")
    print(metrics_dress)

    # ---- FashionIQ: toptee ----
    eval_data = datasets.load_dataset('json', data_files="./eval/data/fashioniq_toptee_query_val.jsonl", split='train')
    image_corpus = datasets.load_dataset('json', data_files="./eval/data/fashioniq_toptee_corpus.jsonl", split='train')

    faiss_index = index(
        model=model,
        corpus=image_corpus,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        queries=eval_data,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results = []
    for indice in indices:
        # filter invalid indices
        indice = indice[indice != -1].tolist()
        retrieval_results.append(image_corpus[indice]["content"])

    ground_truths = []
    for sample in eval_data:
        ground_truths.append(sample["positive_key"])

    metrics_toptee = evaluate(retrieval_results, ground_truths)
    print("FashionIQ tasks (toptee):")
    print(metrics_toptee)

print(f"shirt: {metrics_shirt['Recall@10'] * 100:.2f} / {metrics_shirt['Recall@50'] * 100:.2f}")
|
|
print(f"dress: {metrics_dress['Recall@10'] * 100:.2f} / {metrics_dress['Recall@50'] * 100:.2f}")
|
|
print(f"toptee: {metrics_toptee['Recall@10'] * 100:.2f} / {metrics_toptee['Recall@50'] * 100:.2f}")
|
|
print(f"overall: {(metrics_shirt['Recall@10'] + metrics_dress['Recall@10'] + metrics_toptee['Recall@10']) * 100 / 3:.2f} / {(metrics_shirt['Recall@50'] + metrics_dress['Recall@50'] + metrics_toptee['Recall@50']) * 100 / 3:.2f}")
|
|
|
|
|
|
if __name__ == "__main__":
    main()