# embed-bge-m3/FlagEmbedding/research/BGE_VL/eval/eval_fashioniq.py
#
# 342 lines
# 10 KiB
# Python

import os
# Expose only GPU 0 to this process; set here, before torch and faiss are imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import sys
# Print the working directory (presumably a debugging aid: the dataset paths
# used in main() are relative, "./eval/data/...").
print(os.getcwd())
import faiss
import torch
import logging
import datasets
import numpy as np
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from flag_mmret import Flag_mmret
import json
# Module-level logger used by the indexing code.
logger = logging.getLogger(__name__)
@dataclass
class Args:
    """Command-line options for the FashionIQ evaluation script.

    Every field carries its CLI help text in ``metadata["help"]``; the class
    is handed to ``HfArgumentParser`` in ``main``.
    """

    # --- model and data location ---
    model_name: str = field(default="BAAI/BGE-VL-large", metadata={"help": "Model Name"})
    image_dir: str = field(
        default="YOUR_FASHIONIQ_IMAGE_DIRECTORY",
        metadata={"help": "Where are the images located on."},
    )
    fp16: bool = field(default=False, metadata={"help": "Use fp16 in inference?"})

    # --- encoding limits ---
    max_query_length: int = field(default=64, metadata={"help": "Max query length."})
    max_passage_length: int = field(default=77, metadata={"help": "Max passage length."})
    batch_size: int = field(default=256, metadata={"help": "Inference batch size."})

    # --- retrieval ---
    index_factory: str = field(default="Flat", metadata={"help": "Faiss index factory."})
    k: int = field(default=100, metadata={"help": "How many neighbors to retrieve?"})

    # --- embedding cache ---
    save_embedding: bool = field(
        default=False,
        metadata={"help": "Save embeddings in memmap at save_dir?"},
    )
    load_embedding: bool = field(
        default=False,
        metadata={"help": "Load embeddings from save_dir?"},
    )
    save_path: str = field(
        default="embeddings.memmap",
        metadata={"help": "Path to save embeddings."},
    )
def index(model: Flag_mmret, corpus: datasets.Dataset, batch_size: int = 256, max_length: int = 512, index_factory: str = "Flat", save_path: Optional[str] = None, save_embedding: bool = False, load_embedding: bool = False):
    """
    Build a faiss inner-product index over the image corpus.

    1. Encode the entire corpus into dense embeddings (or load them from
       ``save_path`` when ``load_embedding`` is set);
    2. Optionally save freshly encoded embeddings to a memmap at ``save_path``;
    3. Create the faiss index, train it and add the embeddings.

    Args:
        model: Encoder providing ``encode_corpus``, ``encode`` and ``device``.
        corpus: Corpus dataset passed to ``model.encode_corpus``.
        batch_size: Encoding batch size.
        max_length: Max length forwarded to the encoder.
        index_factory: faiss index factory string.
        save_path: Memmap file used for saving / loading embeddings.
        save_embedding: Write encoded embeddings to ``save_path``. Ignored
            when ``load_embedding`` is set.
        load_embedding: Read embeddings from ``save_path`` instead of encoding.

    Returns:
        The populated faiss index.

    Raises:
        ValueError: If saving or loading is requested without ``save_path``.
    """
    if (load_embedding or save_embedding) and not save_path:
        raise ValueError("`save_path` is required when `load_embedding` or `save_embedding` is set.")

    if load_embedding:
        # The raw memmap file stores neither dtype nor dimension, so encode a
        # probe input to recover both.
        probe = model.encode("test")
        dtype = probe.dtype
        # Use the last axis (as the encode branch does) so a (1, dim) probe
        # yields `dim` rather than 1.
        dim = probe.shape[-1]
        corpus_embeddings = np.memmap(
            save_path,
            mode="r",
            dtype=dtype
        ).reshape(-1, dim)
    else:
        corpus_embeddings = model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length, corpus_type='image')
        dim = corpus_embeddings.shape[-1]

        if save_embedding:
            logger.info("saving embeddings at %s...", save_path)
            memmap = np.memmap(
                save_path,
                shape=corpus_embeddings.shape,
                mode="w+",
                dtype=corpus_embeddings.dtype
            )
            length = corpus_embeddings.shape[0]
            # Copy in batches so large corpora show progress.
            save_batch_size = 10000
            if length > save_batch_size:
                for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                    j = min(i + save_batch_size, length)
                    memmap[i: j] = corpus_embeddings[i: j]
            else:
                memmap[:] = corpus_embeddings
            # Persist explicitly instead of relying on garbage collection.
            memmap.flush()
            del memmap

    # create faiss index
    faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)

    # Compare the device *type*: `torch.device("cuda:0") == torch.device("cuda")`
    # is False, which would silently keep the index on CPU.
    if getattr(model.device, "type", None) == "cuda":
        co = faiss.GpuMultipleClonerOptions()
        co.useFloat16 = True
        faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)

    # NOTE: faiss only accepts float32
    logger.info("Adding embeddings...")
    corpus_embeddings = corpus_embeddings.astype(np.float32)
    faiss_index.train(corpus_embeddings)
    faiss_index.add(corpus_embeddings)
    return faiss_index
def search(model: Flag_mmret, queries: datasets, faiss_index: faiss.Index, k:int = 100, batch_size: int = 256, max_length: int=512):
    """
    Retrieve the top-``k`` neighbours of every query from ``faiss_index``.

    The ``q_text`` and ``q_img`` columns of ``queries`` are encoded together
    with ``query_type='mm_it'``, then the index is searched one batch of
    ``batch_size`` queries at a time.

    Returns:
        ``(scores, indices)``: the per-batch outputs of ``faiss_index.search``
        concatenated along axis 0.
    """
    encoded = model.encode_queries(
        [queries["q_text"], queries["q_img"]],
        batch_size=batch_size,
        max_length=max_length,
        query_type='mm_it',
    )
    total = len(encoded)

    score_chunks = []
    index_chunks = []
    for start in tqdm(range(0, total, batch_size), desc="Searching"):
        stop = min(start + batch_size, total)
        # faiss expects float32 input
        batch = encoded[start:stop].astype(np.float32)
        batch_scores, batch_indices = faiss_index.search(batch, k=k)
        score_chunks.append(batch_scores)
        index_chunks.append(batch_indices)

    return np.concatenate(score_chunks, axis=0), np.concatenate(index_chunks, axis=0)
def evaluate(preds, labels, cutoffs=(1, 5, 10, 20, 50, 100)):
    """
    Evaluate MRR and Recall at cutoffs.

    Args:
        preds: One ranked list of retrieved keys per query (best first).
        labels: One ground truth per query: either a single key or a
            list/tuple/set of keys.
        cutoffs: Rank cutoffs to report.

    Returns:
        Dict with ``MRR@{cutoff}`` entries followed by ``Recall@{cutoff}``
        entries, each averaged over ``len(preds)``.

        * MRR@c adds ``1 / rank`` of the first relevant prediction when that
          rank is ``<= c``.
        * Recall@c is the fraction of a query's positives found in its top-c
          predictions.
    """
    cutoffs = list(cutoffs)
    mrrs = np.zeros(len(cutoffs))
    recalls = np.zeros(len(cutoffs))

    for pred, label in zip(preds, labels):
        # Normalise the label once for BOTH metrics. Testing `x in label` on a
        # raw string would be a substring match, and on an int would raise.
        if isinstance(label, (list, tuple, set, frozenset)):
            positives = list(label)
        else:
            positives = [label]
        if not positives:
            # A query without positives contributes 0 to every metric
            # (and must not divide by zero below).
            continue

        # MRR: only the first relevant prediction counts.
        for rank, candidate in enumerate(pred, 1):
            if candidate in positives:
                for k, cutoff in enumerate(cutoffs):
                    if rank <= cutoff:
                        mrrs[k] += 1 / rank
                break

        # Recall
        for k, cutoff in enumerate(cutoffs):
            hits = np.intersect1d(positives, pred[:cutoff])
            recalls[k] += len(hits) / len(positives)

    mrrs /= len(preds)
    recalls /= len(preds)

    metrics = {}
    for cutoff, mrr in zip(cutoffs, mrrs):
        metrics[f"MRR@{cutoff}"] = mrr
    for cutoff, recall in zip(cutoffs, recalls):
        metrics[f"Recall@{cutoff}"] = recall
    return metrics
# FashionIQ categories evaluated by `main`, in reporting order.
_FASHIONIQ_CATEGORIES = ("shirt", "dress", "toptee")


def _category_save_path(save_path, category):
    """Derive a per-category embedding file, e.g. ``embeddings.memmap`` -> ``embeddings_shirt.memmap``."""
    if not save_path:
        return save_path
    root, ext = os.path.splitext(save_path)
    return f"{root}_{category}{ext}"


def _evaluate_category(model, args, category):
    """Index one category's image corpus, run its validation queries and return the metrics dict."""
    eval_data = datasets.load_dataset('json', data_files=f"./eval/data/fashioniq_{category}_query_val.jsonl", split='train')
    image_corpus = datasets.load_dataset('json', data_files=f"./eval/data/fashioniq_{category}_corpus.jsonl", split='train')

    faiss_index = index(
        model=model,
        corpus=image_corpus,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        # Each category has its own corpus. A shared file would be overwritten
        # on save and would load the wrong corpus' embeddings on load.
        save_path=_category_save_path(args.save_path, category),
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    _, indices = search(
        model=model,
        queries=eval_data,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results = []
    for indice in indices:
        # filter invalid indices (-1 placeholders)
        valid = indice[indice != -1].tolist()
        retrieval_results.append(image_corpus[valid]["content"])

    ground_truths = [sample["positive_key"] for sample in eval_data]
    return evaluate(retrieval_results, ground_truths)


def main():
    """
    Evaluate the model on the FashionIQ shirt, dress and toptee validation sets.

    Prints the full metrics dict of every category as soon as it is computed,
    then a ``Recall@10 / Recall@50`` summary per category and their average.
    """
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    model = Flag_mmret(model_name=args.model_name,
                       normlized=True,
                       image_dir=args.image_dir,
                       # Honour the --fp16 flag (previously hard-coded to False).
                       use_fp16=args.fp16,
                       )

    metrics = {}
    for category in _FASHIONIQ_CATEGORIES:
        metrics[category] = _evaluate_category(model, args, category)
        print(f"FashionIQ tasks ({category}):")
        print(metrics[category])

    for category in _FASHIONIQ_CATEGORIES:
        result = metrics[category]
        print(f"{category}: {result['Recall@10'] * 100:.2f} / {result['Recall@50'] * 100:.2f}")

    count = len(_FASHIONIQ_CATEGORIES)
    recall_10 = sum(metrics[category]['Recall@10'] for category in _FASHIONIQ_CATEGORIES)
    recall_50 = sum(metrics[category]['Recall@50'] for category in _FASHIONIQ_CATEGORIES)
    print(f"overall: {recall_10 * 100 / count:.2f} / {recall_50 * 100 / count:.2f}")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()