104 lines
2.8 KiB
Python
104 lines
2.8 KiB
Python
import random
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torchvision import transforms as T
|
|
from torchvision.transforms import functional as F
|
|
|
|
|
|
def pad_if_smaller(img, size, fill=0):
    """Pad ``img`` on the right/bottom so neither side is shorter than ``size``.

    ``img.size`` is unpacked as ``(width, height)`` (presumably a PIL image --
    confirm at call sites). If its shorter side already reaches ``size`` the
    very same object is returned untouched.

    Args:
        img: image exposing a two-element ``.size`` attribute.
        size: minimum side length required after padding.
        fill: value written into the newly padded pixels.

    Returns:
        ``img`` itself, or a padded copy produced by ``F.pad``.
    """
    if min(img.size) >= size:
        return img
    width, height = img.size
    extra_w = max(size - width, 0)
    extra_h = max(size - height, 0)
    # F.pad reads a 4-tuple as (left, top, right, bottom): only the right and
    # bottom edges grow, so the original content keeps its top-left origin.
    return F.pad(img, (0, 0, extra_w, extra_h), fill=fill)
|
|
|
|
|
|
class Compose:
    """Chain paired transforms, each mapping ``(image, target)`` to a new pair.

    Args:
        transforms: iterable of callables applied in order as
            ``transform(image, target)``; each must return a two-item pair.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        """Feed the pair through every stored transform and return the result."""
        for transform in self.transforms:
            # Unpack after each step so a bad return shape fails right here.
            image, target = transform(image, target)
        return image, target
|
|
|
|
|
|
class RandomResize:
    """Resize image and target to one size drawn with ``random.randint``.

    Args:
        min_size: lower bound (inclusive) of the drawn size.
        max_size: upper bound (inclusive). When omitted it equals
            ``min_size``, so the resize becomes deterministic.
    """

    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        self.max_size = min_size if max_size is None else max_size

    def __call__(self, image, target):
        """Resize both inputs to the same randomly chosen ``size``.

        The image is resized with antialiasing. The target uses
        nearest-neighbour interpolation so its values are never blended
        (presumably integer class labels -- confirm).
        """
        chosen = random.randint(self.min_size, self.max_size)
        resized_image = F.resize(image, chosen, antialias=True)
        resized_target = F.resize(
            target, chosen, interpolation=T.InterpolationMode.NEAREST
        )
        return resized_image, resized_target
|
|
|
|
|
|
class RandomHorizontalFlip:
    """Mirror image and target left-to-right together with probability ``flip_prob``."""

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        """Flip both inputs or neither, decided by a single ``random.random()`` draw."""
        # One shared draw keeps image and target consistent with each other.
        should_flip = random.random() < self.flip_prob
        if should_flip:
            image, target = F.hflip(image), F.hflip(target)
        return image, target
|
|
|
|
|
|
class RandomCrop:
    """Crop the same random ``size`` x ``size`` window from image and target.

    Inputs smaller than ``size`` are first padded with ``pad_if_smaller``.
    The image gets the default fill (0) and the target gets 255 (presumably
    an ignore label -- confirm against the loss configuration).
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        """Pad if needed, then crop both inputs with one shared window."""
        side = self.size
        image = pad_if_smaller(image, side)
        target = pad_if_smaller(target, side, fill=255)
        # The window is sampled from the padded image and reused for the
        # target, so both are cropped at the same location.
        window = T.RandomCrop.get_params(image, (side, side))
        cropped_image = F.crop(image, *window)
        cropped_target = F.crop(target, *window)
        return cropped_image, cropped_target
|
|
|
|
|
|
class CenterCrop:
    """Take the central ``size`` crop of both image and target via ``F.center_crop``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        """Return the center-cropped ``(image, target)`` pair."""
        return F.center_crop(image, self.size), F.center_crop(target, self.size)
|
|
|
|
|
|
class PILToTensor:
    """Convert an image/target pair into tensors.

    The image goes through ``F.pil_to_tensor``. The target is first
    materialized with ``np.array`` and then wrapped as an ``int64`` tensor
    (presumably a PIL label mask -- confirm).
    """

    def __call__(self, image, target):
        """Return ``(image_tensor, int64 target tensor)``."""
        image_tensor = F.pil_to_tensor(image)
        target_array = np.array(target)
        return image_tensor, torch.as_tensor(target_array, dtype=torch.int64)
|
|
|
|
|
|
class ToDtype:
    """Cast the image to ``dtype``; the target passes through untouched.

    Args:
        dtype: target dtype for the image.
        scale: when falsy, the image is cast with ``image.to(dtype=...)``.
            When truthy, ``F.convert_image_dtype`` is used instead, which per
            torchvision also rescales values between integer and float ranges.
    """

    def __init__(self, dtype, scale=False):
        self.dtype = dtype
        self.scale = scale

    def __call__(self, image, target):
        """Return ``(converted_image, target)``."""
        if self.scale:
            converted = F.convert_image_dtype(image, self.dtype)
        else:
            converted = image.to(dtype=self.dtype)
        return converted, target
|
|
|
|
|
|
class Normalize:
    """Normalize the image with ``F.normalize`` using ``mean``/``std``; target unchanged.

    Args:
        mean: per-channel means forwarded to ``F.normalize``.
        std: per-channel standard deviations forwarded to ``F.normalize``.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        """Return ``(normalized_image, target)``."""
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target
|