from typing import cast, List, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor

from flag_dataset import MMIT_Dataset, MMIT_Collator, Image_Dataset, Image_Collator


class Flag_mmret(nn.Module):
    """Multimodal retrieval model that embeds text, images, and interleaved
    image-text pairs with a CLIP-style backbone for dense retrieval."""

    def __init__(
        self,
        model_name: str = None,
        normalized: bool = True,
        pooling_method: str = 'cls',
        use_fp16: bool = True,
        image_dir: str = None,
    ) -> None:
        super().__init__()

        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(model_name)

        self.normalize_embeddings = normalized
        self.pooling_method = pooling_method
        self.image_dir = image_dir

        # FP16 halves memory use at inference; image tensors are cast to match
        # inside the encode_* methods below.
        self.use_fp16 = use_fp16
        if use_fp16:
            self.model.half()

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = self.model.to(self.device)

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    def encode_queries(self, queries: Union[str, List[str], Tuple[List[str], List[str]]],
                       batch_size: int = 256,
                       max_length: int = 77,
                       query_type: str = None,
                       ) -> np.ndarray:
        """Encode queries, dispatching on query_type: 'text', 'mm_it' (text + image), or 'image'."""
        if query_type == 'text':
            return self.encode_text(queries, batch_size=batch_size, max_length=max_length)
        elif query_type == 'mm_it':
            # Multimodal queries arrive as a (captions, image_ids) pair.
            q_text, q_img = queries
            return self.encode_mm_it(q_text, q_img, batch_size=batch_size)
        elif query_type == 'image':
            return self.encode_image(queries, batch_size=batch_size)
        else:
            raise NotImplementedError(f"Unsupported query_type: {query_type}")

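    # Expected query layouts, a sketch inferred from the dispatch above; the
    # example strings and file names are hypothetical, not part of any dataset:
    #   query_type='text':   ["a red double-decker bus", ...]
    #   query_type='mm_it':  (["a red double-decker bus", ...], ["000123.jpg", ...])
    #   query_type='image':  ["000123.jpg", ...]
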
def encode_corpus(self,
|
|
corpus: dict,
|
|
batch_size: int=256,
|
|
max_length: int=77,
|
|
corpus_type: str = None,
|
|
) -> np.ndarray:
|
|
if corpus_type == 'text':
|
|
return self.encode_text(corpus["text"], batch_size=batch_size, max_length=max_length)
|
|
elif corpus_type == 'mm_it':
|
|
return self.encode_mm_it(corpus["text"], corpus["image"], batch_size=batch_size, max_length=max_length)
|
|
elif corpus_type == 'image':
|
|
return self.encode_image(corpus["image"], batch_size=batch_size, max_length=max_length)
|
|
else:
|
|
raise RuntimeError(f"You must choose a corpus type from: [mm_it, text, image]")
|
|
|
|
|
|
|
|
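    # Illustrative corpus layouts for encode_corpus (hypothetical values, shown
    # only to document the expected dict keys used above):
    #   corpus_type='text':  {"text": ["caption one", "caption two"]}
    #   corpus_type='mm_it': {"text": ["caption one"], "image": ["000123.jpg"]}
    #   corpus_type='image': {"image": ["000123.jpg", "000124.jpg"]}
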
    @torch.no_grad()
    def encode_text(self, sentences: Union[List[str], str], batch_size: int = 256, max_length: int = 77) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()
        # DataParallel only parallelizes forward(), so unwrap it to reach get_text_features.
        model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        all_embeddings = []
        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings", disable=len(sentences) < 256):
            sentences_batch = sentences[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
            ).to(self.device)

            embeddings = model.get_text_features(**inputs)
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)
            all_embeddings.append(embeddings.cpu().numpy())

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        if input_was_string:
            # Preserve the single-input contract: one string in, one vector out.
            return all_embeddings[0]
        return all_embeddings

    @torch.no_grad()
    def encode_mm_it(self, captions: Union[List[str], str], image_ids: Union[List[str], str], batch_size: int = 256, max_length: int = 77) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()
        # DataParallel only parallelizes forward(), so unwrap it to reach get_*_features.
        model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model

        input_was_string = False
        if isinstance(captions, str):
            captions = [captions]
            image_ids = [image_ids]
            input_was_string = True

        all_embeddings = []
        mm_it_dataset = MMIT_Dataset(captions=captions,
                                     image_ids=image_ids,
                                     image_dir=self.image_dir,
                                     image_processor=self.image_processor)
        mm_it_collator = MMIT_Collator(self.tokenizer, caption_max_len=75)

        mm_it_dataloader = DataLoader(dataset=mm_it_dataset,
                                      collate_fn=mm_it_collator,
                                      num_workers=8,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      drop_last=False)

        for data in tqdm(mm_it_dataloader, desc="Inference Embeddings", disable=len(captions) < 256):
            captions_inputs = data[0].to(self.device)
            images = data[1].to(self.device)
            if self.use_fp16 and images.dtype != torch.float16:
                images = images.half()

            text_embeddings = model.get_text_features(**captions_inputs)
            image_embeddings = model.get_image_features(images)

            # Fuse the two modalities by element-wise addition, then normalize.
            embeddings = text_embeddings + image_embeddings
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)

            embeddings = cast(torch.Tensor, embeddings)
            all_embeddings.append(embeddings.cpu().numpy())

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        if input_was_string:
            return all_embeddings[0]
        return all_embeddings

    @torch.no_grad()
    def encode_image(self, image_ids: Union[List[str], str], batch_size: int = 256, max_length: int = 77) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()
        # DataParallel only parallelizes forward(), so unwrap it to reach get_image_features.
        model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model

        all_embeddings = []
        image_dataset = Image_Dataset(image_ids=image_ids,
                                      image_dir=self.image_dir,
                                      image_processor=self.image_processor)
        image_collator = Image_Collator(self.tokenizer, caption_max_len=312)

        image_dataloader = DataLoader(dataset=image_dataset,
                                      collate_fn=image_collator,
                                      num_workers=8,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      drop_last=False)

        for data in tqdm(image_dataloader, desc="Inference Image Embeddings"):
            images = data.to(self.device)
            if self.use_fp16 and images.dtype != torch.float16:
                images = images.half()

            embeddings = model.get_image_features(images)
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)
            all_embeddings.append(embeddings.cpu().numpy())

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        return all_embeddings
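

# A minimal usage sketch (not part of the original module). The checkpoint,
# image directory, and file name below are placeholders, and the sketch assumes
# `model_name` points to a CLIP-style checkpoint exposing get_text_features /
# get_image_features, as used above.
if __name__ == "__main__":
    retriever = Flag_mmret(
        model_name="openai/clip-vit-base-patch16",  # placeholder checkpoint
        image_dir="./images",                       # placeholder image folder
        use_fp16=torch.cuda.is_available(),
    )

    # Text queries against an image corpus.
    query_emb = retriever.encode_queries(["a dog playing in the snow"], query_type="text")
    corpus_emb = retriever.encode_corpus({"image": ["000123.jpg"]}, corpus_type="image")

    # Cosine similarity reduces to a dot product because embeddings are L2-normalized.
    scores = query_emb @ corpus_emb.T
    print(scores.shape)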