# embed-bge-m3/FlagEmbedding/research/BGE_VL/eval/flag_mmret.py

from typing import cast, List, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor

from flag_dataset import MMIT_Dataset, MMIT_Collator, Image_Dataset, Image_Collator

class Flag_mmret(nn.Module):
    """CLIP-style multimodal retriever that encodes text, images, and fused
    image-text pairs into a shared embedding space."""

    def __init__(
        self,
        model_name: str = None,
        normalized: bool = True,
        pooling_method: str = 'cls',
        use_fp16: bool = True,
        image_dir: str = None,
    ) -> None:
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(model_name)
        self.normalize_embeddings = normalized
        self.pooling_method = pooling_method
        self.image_dir = image_dir

        self.use_fp16 = use_fp16
        if use_fp16:
            self.model.half()

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = self.model.to(self.device)

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)
        # DataParallel does not forward get_text_features/get_image_features
        # attribute lookups, so keep a handle to the underlying model for
        # feature extraction.
        self._encoder = self.model.module if self.num_gpus > 1 else self.model
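
    # NOTE: all encode_* methods below L2-normalize their outputs by default
    # (normalized=True), so the similarity between a query embedding and a
    # corpus embedding reduces to a plain dot product, i.e. cosine similarity.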
    def encode_queries(self, queries: Union[List[str], str],
                       batch_size: int = 256,
                       max_length: int = 77,
                       query_type: str = None,
                       ) -> np.ndarray:
        """Encode queries of type 'text', 'mm_it' (a (texts, image_ids) pair),
        or 'image' (image ids)."""
        if query_type == 'text':
            return self.encode_text(queries, batch_size=batch_size, max_length=max_length)
        elif query_type == 'mm_it':
            q_text, q_img = queries
            return self.encode_mm_it(q_text, q_img, batch_size=batch_size)
        elif query_type == 'image':
            return self.encode_image(queries, batch_size=batch_size)
        else:
            raise NotImplementedError(f"Unknown query_type: {query_type}")
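
    # The corpus is a dict of parallel lists: "text" holds captions and
    # "image" holds image ids that are resolved against self.image_dir.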
    def encode_corpus(self,
                      corpus: dict,
                      batch_size: int = 256,
                      max_length: int = 77,
                      corpus_type: str = None,
                      ) -> np.ndarray:
        if corpus_type == 'text':
            return self.encode_text(corpus["text"], batch_size=batch_size, max_length=max_length)
        elif corpus_type == 'mm_it':
            return self.encode_mm_it(corpus["text"], corpus["image"], batch_size=batch_size, max_length=max_length)
        elif corpus_type == 'image':
            return self.encode_image(corpus["image"], batch_size=batch_size, max_length=max_length)
        else:
            raise RuntimeError("You must choose a corpus type from: [mm_it, text, image]")
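
    # max_length defaults to 77, the fixed context window of CLIP text
    # encoders; longer inputs are truncated by the tokenizer.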
    @torch.no_grad()
    def encode_text(self, sentences: Union[List[str], str], batch_size: int = 256,
                    max_length: int = 77) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        all_embeddings = []
        for start_index in tqdm(range(0, len(sentences), batch_size),
                                desc="Inference Embeddings", disable=len(sentences) < 256):
            sentences_batch = sentences[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
            ).to(self.device)
            embeddings = self._encoder.get_text_features(**inputs)
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)
            all_embeddings.append(embeddings.cpu().numpy())

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        if input_was_string:
            return all_embeddings[0]
        return all_embeddings
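
    # Fused image-text items are embedded by encoding each modality
    # separately and summing the two feature vectors before normalization.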
    @torch.no_grad()
    def encode_mm_it(self, captions: Union[List[str], str], image_ids: Union[List[str], str],
                     batch_size: int = 256, max_length: int = 77) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()

        input_was_string = False
        if isinstance(captions, str):
            captions = [captions]
            image_ids = [image_ids]
            input_was_string = True

        mm_it_dataset = MMIT_Dataset(captions=captions,
                                     image_ids=image_ids,
                                     image_dir=self.image_dir,
                                     image_processor=self.image_processor)
        mm_it_collator = MMIT_Collator(self.tokenizer, caption_max_len=75)
        mm_it_dataloader = DataLoader(dataset=mm_it_dataset,
                                      collate_fn=mm_it_collator,
                                      num_workers=8,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      drop_last=False)

        all_embeddings = []
        for data in tqdm(mm_it_dataloader, desc="Inference Embeddings", disable=len(captions) < 256):
            captions_inputs = data[0].to(self.device)
            images = data[1].to(self.device)
            if self.use_fp16 and images.dtype != torch.float16:
                images = images.half()
            text_embeddings = self._encoder.get_text_features(**captions_inputs)
            image_embeddings = self._encoder.get_image_features(images)
            embeddings = text_embeddings + image_embeddings
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)
            all_embeddings.append(embeddings.cpu().numpy())

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        if input_was_string:
            return all_embeddings[0]
        return all_embeddings
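
    # Image-only encoding: batches are pixel tensors from the CLIP image
    # processor; max_length is kept only for signature symmetry with the
    # other encode_* methods.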
    @torch.no_grad()
    def encode_image(self, image_ids: Union[List[str], str], batch_size: int = 256,
                     max_length: int = 77) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()

        image_dataset = Image_Dataset(image_ids=image_ids,
                                      image_dir=self.image_dir,
                                      image_processor=self.image_processor)
        image_collator = Image_Collator(self.tokenizer, caption_max_len=312)
        image_dataloader = DataLoader(dataset=image_dataset,
                                      collate_fn=image_collator,
                                      num_workers=8,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      drop_last=False)

        all_embeddings = []
        for data in tqdm(image_dataloader, desc="Inference Image Embeddings"):
            images = data.to(self.device)
            if self.use_fp16 and images.dtype != torch.float16:
                images = images.half()
            embeddings = self._encoder.get_image_features(images)
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)
            all_embeddings.append(embeddings.cpu().numpy())

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        return all_embeddings
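

# ---------------------------------------------------------------------------
# Minimal usage sketch. The checkpoint name and image directory below are
# placeholders (assumptions, not part of the original file); substitute the
# paths used in your own evaluation setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    retriever = Flag_mmret(
        model_name="BAAI/BGE-VL-base",       # hypothetical checkpoint name
        image_dir="/path/to/images",         # hypothetical image root
        use_fp16=torch.cuda.is_available(),  # fp16 is only worthwhile on GPU
    )
    # Text query against a text corpus: with normalized=True (the default),
    # a plain inner product of the embeddings is the cosine similarity.
    q_emb = retriever.encode_queries(["a dog playing in the snow"], query_type="text")
    c_emb = retriever.encode_corpus({"text": ["a puppy in winter", "a cat indoors"]},
                                    corpus_type="text")
    print(q_emb @ c_emb.T)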