import random
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch import Tensor

T_FLOW = Union[Tensor, np.ndarray, None]
T_MASK = Union[Tensor, np.ndarray, None]
T_STEREO_TENSOR = Tuple[Tensor, Tensor]
T_COLOR_AUG_PARAM = Union[float, Tuple[float, float]]


def rand_float_range(size: Sequence[int], low: float, high: float) -> Tensor:
    # torch.rand is uniform on [0, 1), so (low - high) * u + high maps
    # u=0 -> high and u->1 -> low, i.e. a uniform sample over (low, high].
    return (low - high) * torch.rand(size) + high
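# Example (illustrative): rand_float_range((1,), low=-2.0, high=2.0) draws a
# single value uniformly from the interval (-2.0, 2.0].

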
class InterpolationStrategy:
    # Samples an interpolation mode: "mixed" picks bilinear or bicubic at random
    # on each call, while "bilinear" / "bicubic" always return that mode.

    _valid_modes: List[str] = ["mixed", "bicubic", "bilinear"]

    def __init__(self, mode: str = "mixed") -> None:
        if mode not in self._valid_modes:
            raise ValueError(f"Invalid interpolation mode: {mode}. Valid modes are: {self._valid_modes}")

        if mode == "mixed":
            self.strategies = [F.InterpolationMode.BILINEAR, F.InterpolationMode.BICUBIC]
        elif mode == "bicubic":
            self.strategies = [F.InterpolationMode.BICUBIC]
        elif mode == "bilinear":
            self.strategies = [F.InterpolationMode.BILINEAR]

    def __call__(self) -> F.InterpolationMode:
        return random.choice(self.strategies)

    @classmethod
    def is_valid(cls, mode: str) -> bool:
        return mode in cls._valid_modes

    @property
    def valid_modes(self) -> List[str]:
        return InterpolationStrategy._valid_modes
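# Example (illustrative): InterpolationStrategy("mixed")() returns either
# InterpolationMode.BILINEAR or InterpolationMode.BICUBIC, re-sampled on each call.

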
class ValidateModelInput(torch.nn.Module):
    # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
    def forward(self, images: T_STEREO_TENSOR, disparities: Tuple[T_FLOW, T_FLOW], masks: Tuple[T_MASK, T_MASK]):
        if images[0].shape != images[1].shape:
            raise ValueError("img1 and img2 should have the same shape.")

        h, w = images[0].shape[-2:]
        if disparities[0] is not None and disparities[0].shape != (1, h, w):
            raise ValueError(f"disparities[0].shape should be (1, {h}, {w}) instead of {disparities[0].shape}")
        if masks[0] is not None:
            if masks[0].shape != (h, w):
                raise ValueError(f"masks[0].shape should be ({h}, {w}) instead of {masks[0].shape}")
            if masks[0].dtype != torch.bool:
                raise TypeError(f"masks[0] should be of dtype torch.bool instead of {masks[0].dtype}")

        return images, disparities, masks


class ConvertToGrayscale(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        images: Tuple[PIL.Image.Image, PIL.Image.Image],
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left = F.rgb_to_grayscale(images[0], num_output_channels=3)
        img_right = F.rgb_to_grayscale(images[1], num_output_channels=3)

        return (img_left, img_right), disparities, masks


class MakeValidDisparityMask(torch.nn.Module):
    def __init__(self, max_disparity: Optional[int] = 256) -> None:
        super().__init__()
        self.max_disparity = max_disparity

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        # start from an all-valid mask when none is provided
        valid_masks = tuple(
            torch.ones(images[idx].shape[-2:], dtype=torch.bool, device=images[idx].device) if mask is None else mask
            for idx, mask in enumerate(masks)
        )

        # non-positive disparities are invalid
        valid_masks = tuple(
            torch.logical_and(mask, disparity > 0).squeeze(0) if disparity is not None else mask
            for mask, disparity in zip(valid_masks, disparities)
        )

        # disparities above the configured maximum are invalid as well
        if self.max_disparity is not None:
            valid_masks = tuple(
                torch.logical_and(mask, disparity < self.max_disparity).squeeze(0) if disparity is not None else mask
                for mask, disparity in zip(valid_masks, disparities)
            )

        return images, disparities, valid_masks
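# Example (illustrative): with max_disparity=256 a pixel stays valid only
# where 0 < disparity < 256.
#
#   transform = MakeValidDisparityMask(max_disparity=256)
#   images, disparities, (valid_left, valid_right) = transform(images, disparities, (None, None))

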
class ToGPU(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        dev_images = tuple(image.cuda() for image in images)
        dev_disparities = tuple(map(lambda x: x.cuda() if x is not None else None, disparities))
        dev_masks = tuple(map(lambda x: x.cuda() if x is not None else None, masks))
        return dev_images, dev_disparities, dev_masks


class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype: torch.dtype):
        super().__init__()
        self.dtype = dtype

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left = F.convert_image_dtype(images[0], dtype=self.dtype)
        img_right = F.convert_image_dtype(images[1], dtype=self.dtype)

        img_left = img_left.contiguous()
        img_right = img_right.contiguous()

        return (img_left, img_right), disparities, masks


class Normalize(torch.nn.Module):
    def __init__(self, mean: List[float], std: List[float]) -> None:
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left = F.normalize(images[0], mean=self.mean, std=self.std)
        img_right = F.normalize(images[1], mean=self.mean, std=self.std)

        img_left = img_left.contiguous()
        img_right = img_right.contiguous()

        return (img_left, img_right), disparities, masks
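# Example (illustrative): the usual ImageNet statistics, assuming float images
# in [0, 1] (e.g. after ConvertImageDtype(torch.float32)):
#
#   normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

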
class ToTensor(torch.nn.Module):
    def forward(
        self,
        images: Tuple[PIL.Image.Image, PIL.Image.Image],
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if images[0] is None:
            raise ValueError("img_left is None")
        if images[1] is None:
            raise ValueError("img_right is None")

        img_left = F.pil_to_tensor(images[0])
        img_right = F.pil_to_tensor(images[1])
        disparity_tensors = ()
        mask_tensors = ()

        for idx in range(2):
            disparity_tensors += (torch.from_numpy(disparities[idx]),) if disparities[idx] is not None else (None,)
            mask_tensors += (torch.from_numpy(masks[idx]),) if masks[idx] is not None else (None,)

        return (img_left, img_right), disparity_tensors, mask_tensors


class AsymmetricColorJitter(T.ColorJitter):
    # p determines the probability of doing asymmetric vs symmetric color jittering
    def __init__(
        self,
        brightness: T_COLOR_AUG_PARAM = 0,
        contrast: T_COLOR_AUG_PARAM = 0,
        saturation: T_COLOR_AUG_PARAM = 0,
        hue: T_COLOR_AUG_PARAM = 0,
        p: float = 0.2,
    ):
        super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.p = p

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if torch.rand(1) < self.p:
            # asymmetric: different transform for img1 and img2
            img_left = super().forward(images[0])
            img_right = super().forward(images[1])
        else:
            # symmetric: same transform for img1 and img2
            batch = torch.stack(images)
            batch = super().forward(batch)
            img_left, img_right = batch[0], batch[1]

        return (img_left, img_right), disparities, masks
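# Example (illustrative): with p=0.2, roughly one in five samples gets an
# independent jitter per view; the rest are jittered identically:
#
#   jitter = AsymmetricColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2, p=0.2)

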
class AsymetricGammaAdjust(torch.nn.Module):
    # p determines the probability of doing asymmetric vs symmetric gamma adjustment
    def __init__(self, p: float, gamma_range: Tuple[float, float], gain: float = 1) -> None:
        super().__init__()
        self.gamma_range = gamma_range
        self.gain = gain
        self.p = p

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        gamma = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()

        if torch.rand(1) < self.p:
            # asymmetric: a separate gamma is sampled for each view
            gamma_right = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()
            img_left = F.adjust_gamma(images[0], gamma, gain=self.gain)
            img_right = F.adjust_gamma(images[1], gamma_right, gain=self.gain)
        else:
            # symmetric: same transform for img1 and img2
            batch = torch.stack(images)
            batch = F.adjust_gamma(batch, gamma, gain=self.gain)
            img_left, img_right = batch[0], batch[1]

        return (img_left, img_right), disparities, masks


class RandomErase(torch.nn.Module):
    # Produces multiple symmetric random erasures.
    # These can be viewed as occlusions present in both camera views.
    # Similarly to optical flow occlusion prediction tasks, we mask these pixels in the disparity map.
    def __init__(
        self,
        p: float = 0.5,
        erase_px_range: Tuple[int, int] = (50, 100),
        value: Union[Tensor, float] = 0,
        inplace: bool = False,
        max_erase: int = 2,
    ):
        super().__init__()
        self.min_px_erase = erase_px_range[0]
        self.max_px_erase = erase_px_range[1]
        if self.max_px_erase < 0:
            raise ValueError("erase_px_range[1] should be greater than or equal to 0")
        if self.min_px_erase < 0:
            raise ValueError("erase_px_range[0] should be greater than or equal to 0")
        if self.min_px_erase > self.max_px_erase:
            raise ValueError("erase_px_range[0] should be less than or equal to erase_px_range[1]")

        self.p = p
        self.value = value
        self.inplace = inplace
        self.max_erase = max_erase

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        # the transform is skipped with probability p
        if torch.rand(1) < self.p:
            return images, disparities, masks

        image_left, image_right = images
        mask_left, mask_right = masks
        # the number of erasures is sampled uniformly from [0, max_erase)
        for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
            y, x, h, w, v = self._get_params(image_left)
            image_right = F.erase(image_right, y, x, h, w, v, self.inplace)
            image_left = F.erase(image_left, y, x, h, w, v, self.inplace)
            # similarly to optical flow occlusion prediction, we consider
            # any erased pixels that are in both images to be occluded,
            # therefore we mark them as invalid
            if mask_left is not None:
                mask_left = F.erase(mask_left, y, x, h, w, False, self.inplace)
            if mask_right is not None:
                mask_right = F.erase(mask_right, y, x, h, w, False, self.inplace)

        return (image_left, image_right), disparities, (mask_left, mask_right)

    def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
        img_h, img_w = img.shape[-2:]
        crop_h, crop_w = (
            random.randint(self.min_px_erase, self.max_px_erase),
            random.randint(self.min_px_erase, self.max_px_erase),
        )
        crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))

        return crop_y, crop_x, crop_h, crop_w, self.value
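# Example (illustrative): erase up to max_erase - 1 rectangles of 50-100 px per
# side from both views, invalidating the corresponding mask pixels:
#
#   erase = RandomErase(p=0.5, erase_px_range=(50, 100), max_erase=2)

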
class RandomOcclusion(torch.nn.Module):
    # This adds an occlusion in the right image.
    # The occluded patch works as a patch erase where the erase value is the mean
    # of the pixels from the selected zone.
    def __init__(self, p: float = 0.5, occlusion_px_range: Tuple[int, int] = (50, 100), inplace: bool = False):
        super().__init__()

        self.min_px_occlusion = occlusion_px_range[0]
        self.max_px_occlusion = occlusion_px_range[1]

        if self.max_px_occlusion < 0:
            raise ValueError("occlusion_px_range[1] should be greater than or equal to 0")
        if self.min_px_occlusion < 0:
            raise ValueError("occlusion_px_range[0] should be greater than or equal to 0")
        if self.min_px_occlusion > self.max_px_occlusion:
            raise ValueError("occlusion_px_range[0] should be lower than occlusion_px_range[1]")

        self.p = p
        self.inplace = inplace

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        left_image, right_image = images

        # the transform is skipped with probability p
        if torch.rand(1) < self.p:
            return images, disparities, masks

        y, x, h, w, v = self._get_params(right_image)
        right_image = F.erase(right_image, y, x, h, w, v, self.inplace)

        return ((left_image, right_image), disparities, masks)

    def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
        img_h, img_w = img.shape[-2:]
        crop_h, crop_w = (
            random.randint(self.min_px_occlusion, self.max_px_occlusion),
            random.randint(self.min_px_occlusion, self.max_px_occlusion),
        )

        crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))
        # fill value: per-channel mean of the selected patch
        occlusion_value = img[..., crop_y : crop_y + crop_h, crop_x : crop_x + crop_w].mean(dim=(-2, -1), keepdim=True)

        return (crop_y, crop_x, crop_h, crop_w, occlusion_value)


class RandomSpatialShift(torch.nn.Module):
    # This transform applies a vertical shift and a slight angle rotation at the same time
    def __init__(
        self, p: float = 0.5, max_angle: float = 0.1, max_px_shift: int = 2, interpolation_type: str = "bilinear"
    ) -> None:
        super().__init__()
        self.p = p
        self.max_angle = max_angle
        self.max_px_shift = max_px_shift
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        # the transform is applied only on the right image
        # in order to mimic slight calibration issues
        img_left, img_right = images

        INTERP_MODE = self._interpolation_mode_strategy()

        if torch.rand(1) < self.p:
            # [0, 1] -> [-a, a]
            shift = rand_float_range((1,), low=-self.max_px_shift, high=self.max_px_shift).item()
            angle = rand_float_range((1,), low=-self.max_angle, high=self.max_angle).item()
            # sample center point for the rotation matrix
            y = torch.randint(size=(1,), low=0, high=img_right.shape[-2]).item()
            x = torch.randint(size=(1,), low=0, high=img_right.shape[-1]).item()
            # apply affine transformations
            img_right = F.affine(
                img_right,
                angle=angle,
                translate=[0, shift],  # translation only on the y-axis
                center=[x, y],
                scale=1.0,
                shear=0.0,
                interpolation=INTERP_MODE,
            )

        return ((img_left, img_right), disparities, masks)
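# Example (illustrative): with the defaults, half of the samples get their
# right view shifted vertically by up to ±2 px and rotated by up to ±0.1
# degrees around a random center:
#
#   shift = RandomSpatialShift(p=0.5, max_angle=0.1, max_px_shift=2)

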
class RandomHorizontalFlip(torch.nn.Module):
    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left, img_right = images
        dsp_left, dsp_right = disparities
        mask_left, mask_right = masks

        if dsp_right is not None and torch.rand(1) < self.p:
            img_left, img_right = F.hflip(img_left), F.hflip(img_right)
            dsp_left, dsp_right = F.hflip(dsp_left), F.hflip(dsp_right)
            if mask_left is not None and mask_right is not None:
                mask_left, mask_right = F.hflip(mask_left), F.hflip(mask_right)
            # flipping turns the left view into the right view and vice-versa,
            # so the outputs are returned swapped
            return ((img_right, img_left), (dsp_right, dsp_left), (mask_right, mask_left))

        return images, disparities, masks
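# Example (illustrative): the flip is only attempted when the right disparity
# is available, since both maps are needed to keep the swapped pair consistent:
#
#   hflip = RandomHorizontalFlip(p=0.5)

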
class Resize(torch.nn.Module):
    def __init__(self, resize_size: Tuple[int, ...], interpolation_type: str = "bilinear") -> None:
        super().__init__()
        self.resize_size = list(resize_size)  # doing this to keep mypy happy
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        resized_images = ()
        resized_disparities = ()
        resized_masks = ()

        INTERP_MODE = self._interpolation_mode_strategy()

        for img in images:
            # We hard-code antialias=False to preserve results after we changed
            # its default from None to True (see
            # https://github.com/pytorch/vision/pull/7160)
            # TODO: we could re-train the stereo models with antialias=True?
            resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE, antialias=False),)

        for dsp in disparities:
            if dsp is not None:
                # rescale the disparity to match the new image size
                scale_x = self.resize_size[1] / dsp.shape[-1]
                resized_disparities += (F.resize(dsp, self.resize_size, interpolation=INTERP_MODE) * scale_x,)
            else:
                resized_disparities += (None,)

        for mask in masks:
            if mask is not None:
                resized_masks += (
                    # we squeeze and unsqueeze because the API requires > 3D tensors
                    F.resize(
                        mask.unsqueeze(0),
                        self.resize_size,
                        interpolation=F.InterpolationMode.NEAREST,
                    ).squeeze(0),
                )
            else:
                resized_masks += (None,)

        return resized_images, resized_disparities, resized_masks
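# Example (illustrative): resizing a 960 px wide pair to 480 px halves every
# horizontal offset, so disparities are multiplied by scale_x = 480 / 960 = 0.5.

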
class RandomRescaleAndCrop(torch.nn.Module):
    # This transform will resize the input with a given probability, and then crop it.
    # These are the reversed operations of the built-in RandomResizedCrop,
    # although the order of the operations doesn't matter too much: resizing a
    # crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
    #
    # The reason we don't rely on RandomResizedCrop is because of a significant
    # difference in the parametrization of both transforms, in particular,
    # because of the way the random parameters are sampled in both transforms,
    # which leads to fairly different results (and different epe). For more details see
    # https://github.com/pytorch/vision/pull/5026/files#r762932579
    def __init__(
        self,
        crop_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (-0.2, 0.5),
        rescale_prob: float = 0.8,
        scaling_type: str = "exponential",
        interpolation_type: str = "bilinear",
    ) -> None:
        super().__init__()
        self.crop_size = crop_size
        self.min_scale = scale_range[0]
        self.max_scale = scale_range[1]
        self.rescale_prob = rescale_prob
        self.scaling_type = scaling_type
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

        if self.scaling_type not in ("exponential", "linear"):
            raise ValueError(f"scaling_type should be 'exponential' or 'linear', got {self.scaling_type}")
        if self.scaling_type == "linear" and self.min_scale < 0:
            raise ValueError("min_scale must be >= 0 for linear scaling")

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left, img_right = images
        dsp_left, dsp_right = disparities
        mask_left, mask_right = masks
        INTERP_MODE = self._interpolation_mode_strategy()

        # randomly sample scale
        h, w = img_left.shape[-2:]
        # Note: in the original code, + 1 is used instead of + 8 for sparse datasets (e.g. Kitti).
        # It shouldn't matter much.
        min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)

        # Exponential scaling will draw a random value in (min_scale, max_scale) and then raise
        # 2 to the power of that random value. This final scale distribution will have a different
        # mean and variance than a uniform distribution. Note that a drawn value of 1 will result
        # in a rescaling of 2x the original size, whereas a drawn value of -1 will result in a
        # rescaling of 0.5x the original size.
        if self.scaling_type == "exponential":
            scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
        # linear scaling will draw a random scale in (min_scale, max_scale)
        elif self.scaling_type == "linear":
            scale = torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()

        scale = max(scale, min_scale)

        new_h, new_w = round(h * scale), round(w * scale)

        if torch.rand(1).item() < self.rescale_prob:
            # rescale the images
            img_left = F.resize(img_left, size=(new_h, new_w), interpolation=INTERP_MODE)
            img_right = F.resize(img_right, size=(new_h, new_w), interpolation=INTERP_MODE)

            resized_masks, resized_disparities = (), ()

            for disparity, mask in zip(disparities, masks):
                if disparity is not None:
                    if mask is None:
                        # dense disparity: resize, then rescale the values to the new resolution
                        resized_disparity = F.resize(disparity, size=(new_h, new_w), interpolation=INTERP_MODE)
                        resized_disparity = (
                            resized_disparity * torch.tensor([scale], device=resized_disparity.device)[:, None, None]
                        )
                        resized_mask = None
                    else:
                        # sparse disparity: move the valid samples instead of interpolating
                        resized_disparity, resized_mask = _resize_sparse_flow(
                            disparity, mask, scale_x=scale, scale_y=scale
                        )
                else:
                    resized_disparity, resized_mask = None, None
                resized_masks += (resized_mask,)
                resized_disparities += (resized_disparity,)

        else:
            resized_disparities = disparities
            resized_masks = masks

        disparities = resized_disparities
        masks = resized_masks

        # Note: For sparse datasets (Kitti), the original code uses a "margin".
        # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
        # We don't; not sure if it matters much.
        y0 = torch.randint(0, img_left.shape[-2] - self.crop_size[0], size=(1,)).item()
        x0 = torch.randint(0, img_left.shape[-1] - self.crop_size[1], size=(1,)).item()

        img_left = F.crop(img_left, y0, x0, self.crop_size[0], self.crop_size[1])
        img_right = F.crop(img_right, y0, x0, self.crop_size[0], self.crop_size[1])
        if dsp_left is not None:
            dsp_left = F.crop(disparities[0], y0, x0, self.crop_size[0], self.crop_size[1])
        if dsp_right is not None:
            dsp_right = F.crop(disparities[1], y0, x0, self.crop_size[0], self.crop_size[1])

        cropped_masks = ()
        for mask in masks:
            if mask is not None:
                mask = F.crop(mask, y0, x0, self.crop_size[0], self.crop_size[1])
            cropped_masks += (mask,)

        return ((img_left, img_right), (dsp_left, dsp_right), cropped_masks)
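# Example (illustrative): with the default scale_range=(-0.2, 0.5) and
# exponential scaling, the sampled scale lies in roughly [2**-0.2, 2**0.5],
# i.e. about [0.87, 1.41], before being clamped to fit the crop size.

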
def _resize_sparse_flow(
    flow: Tensor, valid_flow_mask: Tensor, scale_x: float = 1.0, scale_y: float = 1.0
) -> Tuple[Tensor, Tensor]:
    # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse).
    # There are as many non-zero values in the original flow as in the resized flow (up to values
    # that fall out of bounds), so for example if scale_x = scale_y = 2, the sparsity of the
    # output flow is multiplied by 4.

    h, w = flow.shape[-2:]

    h_new = int(round(h * scale_y))
    w_new = int(round(w * scale_x))
    flow_new = torch.zeros(size=[1, h_new, w_new], dtype=flow.dtype)
    valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)

    # coordinates of the valid samples at the original resolution
    jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")

    ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]

    # nearest target coordinates at the new resolution
    ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
    jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)

    within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)

    ii_valid = ii_valid[within_bounds_mask]
    jj_valid = jj_valid[within_bounds_mask]
    ii_valid_new = ii_valid_new[within_bounds_mask]
    jj_valid_new = jj_valid_new[within_bounds_mask]

    valid_flow_new = flow[:, ii_valid, jj_valid]
    # disparities are horizontal offsets, so their magnitude scales with the width
    valid_flow_new *= scale_x

    flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
    valid_new[ii_valid_new, jj_valid_new] = valid_flow_mask[ii_valid, jj_valid]

    return flow_new, valid_new.bool()
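# Example (illustrative): with scale_x = scale_y = 0.5, a valid sample at
# (row, col) = (10, 20) with disparity d is written to (5, 10) with value
# 0.5 * d; samples that collide after rounding overwrite one another, so only
# one of them is kept.

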
class Compose(torch.nn.Module):
    def __init__(self, transforms: List[Callable]):
        super().__init__()
        self.transforms = transforms

    @torch.inference_mode()
    def forward(self, images, disparities, masks):
        for t in self.transforms:
            images, disparities, masks = t(images, disparities, masks)
        return images, disparities, masks
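# Minimal smoke test (illustrative only; not part of the training pipeline).
# Shapes, parameter values, and the transform selection below are arbitrary
# assumptions chosen to exercise the API end to end.
if __name__ == "__main__":
    torch.manual_seed(0)
    img_left = torch.randint(0, 256, (3, 128, 256), dtype=torch.uint8)
    img_right = torch.randint(0, 256, (3, 128, 256), dtype=torch.uint8)
    # dense left disparity, no right disparity, no validity masks
    disparities = (torch.rand(1, 128, 256) * 32.0, None)
    masks = (None, None)

    pipeline = Compose(
        [
            ConvertImageDtype(torch.float32),
            RandomRescaleAndCrop(crop_size=(96, 192)),
            MakeValidDisparityMask(max_disparity=64),
            ValidateModelInput(),
        ]
    )
    (out_left, out_right), (dsp_left, _), (mask_left, _) = pipeline(
        (img_left, img_right), disparities, masks
    )
    print(out_left.shape, out_right.shape, dsp_left.shape, mask_left.shape)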