import math
import os.path
import random
from dataclasses import dataclass
from typing import Iterator

import datasets
from torch.utils.data import Dataset, IterableDataset
from transformers import DataCollatorWithPadding
from transformers import PreTrainedTokenizer, BatchEncoding
from transformers import CLIPImageProcessor

from PIL import Image
import json
import torch
import torch.distributed

from io import BytesIO
import warnings


class MMIT_Dataset(Dataset):
    """Paired image-caption dataset; returns (caption, processed image) per item."""

    def __init__(self, captions, image_ids, image_dir, image_processor) -> None:
        img_id_example = str(image_ids[0])
        # If the ids already carry a file extension, use them as-is;
        # otherwise fall back to appending '.jpg'.
        if os.path.splitext(img_id_example)[1].lower() in {'.jpg', '.jpeg', '.png'}:
            self.image_path = [os.path.join(image_dir, str(img_id)) for img_id in image_ids]
        else:
            warnings.warn("No file extension found in image_ids; '.jpg' will be appended.", UserWarning)
            self.image_path = [os.path.join(image_dir, str(img_id) + '.jpg') for img_id in image_ids]
        self.captions = captions
        self.image_processor = image_processor

    def __getitem__(self, item):
        pil_data = Image.open(self.image_path[item])
        pil_data = pil_data.convert('RGB')
        image = self.image_processor(pil_data)

        caption = self.captions[item]

        return caption, image

    def __len__(self):
        return len(self.image_path)


class MMIT_Collator:
    """Batches (caption, image) pairs: tokenizes captions and stacks CLIP pixel values."""

    def __init__(self, tokenizer, caption_max_len):
        self.tokenizer = tokenizer
        self.caption_max_len = caption_max_len

    def __call__(self, features):
        caption = [f[0] for f in features]
        images = [f[1] for f in features]

        c_collated = self.tokenizer(
            caption,
            truncation=True,
            padding=True,
            max_length=self.caption_max_len,
            return_tensors="pt",
        )

        # i_collated = torch.stack(images)

        # For the CLIP image processor: each item is a BatchFeature whose 'pixel_values'
        # holds a single processed image; pull it out, convert to a tensor, and stack.
        images = [f["pixel_values"][0] for f in images]
        images = [torch.tensor(arr) for arr in images]
        i_collated = torch.stack(images)
        # clip_end

        return c_collated, i_collated


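# Usage sketch (an assumption, not part of the original pipeline): build the paired
# dataset and collator and feed them to a torch DataLoader. The checkpoint name
# "openai/clip-vit-base-patch32" and the variables `captions`, `image_ids`, and
# `image_dir` are illustrative placeholders.
#
#   from torch.utils.data import DataLoader
#   from transformers import AutoTokenizer
#
#   image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
#   tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
#   dataset = MMIT_Dataset(captions, image_ids, image_dir, image_processor)
#   collator = MMIT_Collator(tokenizer, caption_max_len=77)
#   loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collator)
#   for c_collated, i_collated in loader:
#       # c_collated: BatchEncoding with input_ids / attention_mask
#       # i_collated: float tensor of shape (batch, 3, 224, 224) for this checkpoint
#       ...

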
class Image_Dataset(Dataset):
    """Image-only dataset; expects image_ids that already include a file extension."""

    def __init__(self, image_ids, image_dir, image_processor) -> None:
        self.image_path = [os.path.join(image_dir, str(img_id)) for img_id in image_ids]
        self.image_processor = image_processor

    def __getitem__(self, item):
        pil_data = Image.open(self.image_path[item])
        pil_data = pil_data.convert('RGB')  # match MMIT_Dataset: ensure 3-channel input
        image = self.image_processor(pil_data)

        return image

    def __len__(self):
        return len(self.image_path)


class Image_Collator:
    """Batches image processor outputs into a single pixel-value tensor.

    The tokenizer and caption_max_len arguments are unused here; they are kept so this
    collator has the same constructor signature as MMIT_Collator.
    """

    def __init__(self, tokenizer, caption_max_len):
        self.tokenizer = tokenizer
        self.caption_max_len = caption_max_len

    def __call__(self, features):
        # images = features
        # i_collated = torch.stack(images)

        # For the CLIP image processor: pull the single processed image out of each
        # BatchFeature, convert to a tensor, and stack into a batch.
        images = [f["pixel_values"][0] for f in features]
        images = [torch.tensor(arr) for arr in images]
        i_collated = torch.stack(images)
        # clip_end
        return i_collated
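

# Usage sketch (an assumption, not part of the original file): encode images only,
# e.g. when precomputing image features. `image_ids` and `image_dir` are placeholders.
#
#   from torch.utils.data import DataLoader
#
#   image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
#   dataset = Image_Dataset(image_ids, image_dir, image_processor)
#   collator = Image_Collator(tokenizer=None, caption_max_len=None)  # both unused
#   loader = DataLoader(dataset, batch_size=64, shuffle=False, collate_fn=collator)
#   for i_collated in loader:
#       ...  # i_collated: float tensor of shape (batch, 3, 224, 224) for this checkpoint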