embed-bge-m3/FlagEmbedding/research/BGE_VL/eval/flag_dataset.py

import json
import math
import os.path
import random
import warnings
from dataclasses import dataclass
from io import BytesIO
from typing import Iterator

import datasets
import torch
import torch.distributed
from PIL import Image
from torch.utils.data import Dataset, IterableDataset
from transformers import (
    BatchEncoding,
    CLIPImageProcessor,
    DataCollatorWithPadding,
    PreTrainedTokenizer,
)


class MMIT_Dataset(Dataset):
    """Caption-image pairs for multimodal evaluation: returns (caption, processed image)."""

    def __init__(self, captions, image_ids, image_dir, image_processor) -> None:
        # Infer whether the ids already carry a file extension ("JPEG" matches
        # the tail of an upper-case ".JPEG" suffix); otherwise append ".jpg".
        img_id_example = str(image_ids[0])
        if img_id_example[-4:] in (".jpg", ".png", "JPEG"):
            self.image_path = [os.path.join(image_dir, str(img_id)) for img_id in image_ids]
        else:
            warnings.warn("No file extension found in image_ids; forcefully appending '.jpg'.", UserWarning)
            self.image_path = [os.path.join(image_dir, str(img_id) + ".jpg") for img_id in image_ids]
        self.captions = captions
        self.image_processor = image_processor

    def __getitem__(self, item):
        # Convert to RGB so grayscale/RGBA files do not break the processor.
        pil_data = Image.open(self.image_path[item]).convert("RGB")
        image = self.image_processor(pil_data)
        caption = self.captions[item]
        return caption, image

    def __len__(self):
        return len(self.image_path)

class MMIT_Collator:
    """Tokenizes captions and stacks CLIP-processed images into batch tensors."""

    def __init__(self, tokenizer, caption_max_len):
        self.tokenizer = tokenizer
        self.caption_max_len = caption_max_len

    def __call__(self, features):
        captions = [f[0] for f in features]
        images = [f[1] for f in features]
        c_collated = self.tokenizer(
            captions,
            truncation=True,
            padding=True,
            max_length=self.caption_max_len,
            return_tensors="pt",
        )
        # CLIPImageProcessor returns a dict-like BatchFeature whose
        # "pixel_values" holds one array per call; unpack each and stack
        # into a single (B, C, H, W) tensor.
        images = [f["pixel_values"][0] for f in images]
        images = [torch.tensor(arr) for arr in images]
        i_collated = torch.stack(images)
        return c_collated, i_collated
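
# Usage sketch (illustrative, not from the original file): wiring
# MMIT_Dataset + MMIT_Collator into a DataLoader. The checkpoint name,
# captions, ids, and paths below are assumptions for demonstration only;
# the example is kept commented out so importing this module stays
# side-effect free.
#
#   from torch.utils.data import DataLoader
#   from transformers import AutoTokenizer
#
#   ckpt = "openai/clip-vit-base-patch32"   # assumed public CLIP checkpoint
#   tokenizer = AutoTokenizer.from_pretrained(ckpt)
#   image_processor = CLIPImageProcessor.from_pretrained(ckpt)
#   dataset = MMIT_Dataset(
#       captions=["a dog on the beach"],    # toy caption
#       image_ids=["000001.jpg"],           # extension present, so nothing is appended
#       image_dir="/path/to/images",        # placeholder directory
#       image_processor=image_processor,
#   )
#   loader = DataLoader(dataset, batch_size=8,
#                       collate_fn=MMIT_Collator(tokenizer, caption_max_len=77))
#   for captions, pixel_values in loader:
#       ...  # captions: BatchEncoding; pixel_values: (B, 3, H, W) float tensor
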
class Image_Dataset(Dataset):
    """Image-only dataset; mirrors MMIT_Dataset without captions."""

    def __init__(self, image_ids, image_dir, image_processor) -> None:
        self.image_path = [os.path.join(image_dir, str(img_id)) for img_id in image_ids]
        self.image_processor = image_processor

    def __getitem__(self, item):
        # Convert to RGB (as in MMIT_Dataset) so non-RGB files do not break
        # the processor.
        pil_data = Image.open(self.image_path[item]).convert("RGB")
        return self.image_processor(pil_data)

    def __len__(self):
        return len(self.image_path)

class Image_Collator:
    """Stacks CLIP-processed images into one batch tensor. The tokenizer and
    caption_max_len arguments are unused but kept for interface parity with
    MMIT_Collator."""

    def __init__(self, tokenizer, caption_max_len):
        self.tokenizer = tokenizer
        self.caption_max_len = caption_max_len

    def __call__(self, features):
        # Each feature is a CLIPImageProcessor output; unpack "pixel_values"
        # and stack into a single (B, C, H, W) tensor.
        images = [f["pixel_values"][0] for f in features]
        images = [torch.tensor(arr) for arr in images]
        i_collated = torch.stack(images)
        return i_collated
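
# A runnable smoke test for the image-only pipeline. Everything below is a
# hedged sketch added for illustration: the checkpoint is the standard public
# CLIP model, and the directory/ids are placeholders to adapt to your data.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    ckpt = "openai/clip-vit-base-patch32"  # assumed public CLIP checkpoint
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    image_processor = CLIPImageProcessor.from_pretrained(ckpt)

    dataset = Image_Dataset(
        image_ids=["000001.jpg", "000002.jpg"],  # placeholder ids
        image_dir="/path/to/images",             # placeholder directory
        image_processor=image_processor,
    )
    # Image_Collator ignores the tokenizer; it is passed only for interface parity.
    loader = DataLoader(dataset, batch_size=2,
                        collate_fn=Image_Collator(tokenizer, caption_max_len=77))
    for pixel_values in loader:
        print(pixel_values.shape)  # expected: torch.Size([2, 3, 224, 224])
        break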