# sglang_v0.5.2/vision_0.23.0/references/depth/stereo/transforms.py


import random
from typing import Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch import Tensor

T_FLOW = Union[Tensor, np.ndarray, None]
T_MASK = Union[Tensor, np.ndarray, None]
T_STEREO_TENSOR = Tuple[Tensor, Tensor]
T_COLOR_AUG_PARAM = Union[float, Tuple[float, float]]


def rand_float_range(size: Sequence[int], low: float, high: float) -> Tensor:
    # draws uniformly from the interval between low and high
    return (low - high) * torch.rand(size) + high


class InterpolationStrategy:
_valid_modes: List[str] = ["mixed", "bicubic", "bilinear"]
def __init__(self, mode: str = "mixed") -> None:
if mode not in self._valid_modes:
raise ValueError(f"Invalid interpolation mode: {mode}. Valid modes are: {self._valid_modes}")
if mode == "mixed":
self.strategies = [F.InterpolationMode.BILINEAR, F.InterpolationMode.BICUBIC]
elif mode == "bicubic":
self.strategies = [F.InterpolationMode.BICUBIC]
elif mode == "bilinear":
self.strategies = [F.InterpolationMode.BILINEAR]
def __call__(self) -> F.InterpolationMode:
return random.choice(self.strategies)
    @classmethod
    def is_valid(cls, mode: str) -> bool:
        return mode in cls._valid_modes

    @property
    def valid_modes(self) -> List[str]:
        return self._valid_modes
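

# A quick illustration (added here, not in the original file): the strategy
# object is called once per transform invocation to pick an interpolation
# mode. The helper name _pick_mode_example is hypothetical.
def _pick_mode_example() -> F.InterpolationMode:
    strategy = InterpolationStrategy("mixed")
    return strategy()  # BILINEAR or BICUBIC, chosen uniformly at random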


class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, images: T_STEREO_TENSOR, disparities: T_FLOW, masks: T_MASK):
if images[0].shape != images[1].shape:
raise ValueError("img1 and img2 should have the same shape.")
h, w = images[0].shape[-2:]
if disparities[0] is not None and disparities[0].shape != (1, h, w):
raise ValueError(f"disparities[0].shape should be (1, {h}, {w}) instead of {disparities[0].shape}")
if masks[0] is not None:
if masks[0].shape != (h, w):
raise ValueError(f"masks[0].shape should be ({h}, {w}) instead of {masks[0].shape}")
if masks[0].dtype != torch.bool:
raise TypeError(f"masks[0] should be of dtype torch.bool instead of {masks[0].dtype}")
return images, disparities, masks


class ConvertToGrayscale(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
images: Tuple[PIL.Image.Image, PIL.Image.Image],
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.rgb_to_grayscale(images[0], num_output_channels=3)
img_right = F.rgb_to_grayscale(images[1], num_output_channels=3)
return (img_left, img_right), disparities, masks


class MakeValidDisparityMask(torch.nn.Module):
def __init__(self, max_disparity: Optional[int] = 256) -> None:
super().__init__()
self.max_disparity = max_disparity
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
valid_masks = tuple(
torch.ones(images[idx].shape[-2:], dtype=torch.bool, device=images[idx].device) if mask is None else mask
for idx, mask in enumerate(masks)
)
valid_masks = tuple(
torch.logical_and(mask, disparity > 0).squeeze(0) if disparity is not None else mask
for mask, disparity in zip(valid_masks, disparities)
)
if self.max_disparity is not None:
valid_masks = tuple(
torch.logical_and(mask, disparity < self.max_disparity).squeeze(0) if disparity is not None else mask
for mask, disparity in zip(valid_masks, disparities)
)
return images, disparities, valid_masks


class ToGPU(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
dev_images = tuple(image.cuda() for image in images)
dev_disparities = tuple(map(lambda x: x.cuda() if x is not None else None, disparities))
dev_masks = tuple(map(lambda x: x.cuda() if x is not None else None, masks))
return dev_images, dev_disparities, dev_masks


class ConvertImageDtype(torch.nn.Module):
def __init__(self, dtype: torch.dtype):
super().__init__()
self.dtype = dtype
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.convert_image_dtype(images[0], dtype=self.dtype)
img_right = F.convert_image_dtype(images[1], dtype=self.dtype)
img_left = img_left.contiguous()
img_right = img_right.contiguous()
return (img_left, img_right), disparities, masks


class Normalize(torch.nn.Module):
def __init__(self, mean: List[float], std: List[float]) -> None:
super().__init__()
self.mean = mean
self.std = std
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.normalize(images[0], mean=self.mean, std=self.std)
img_right = F.normalize(images[1], mean=self.mean, std=self.std)
img_left = img_left.contiguous()
img_right = img_right.contiguous()
return (img_left, img_right), disparities, masks


class ToTensor(torch.nn.Module):
def forward(
self,
images: Tuple[PIL.Image.Image, PIL.Image.Image],
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
if images[0] is None:
raise ValueError("img_left is None")
if images[1] is None:
raise ValueError("img_right is None")
img_left = F.pil_to_tensor(images[0])
img_right = F.pil_to_tensor(images[1])
disparity_tensors = ()
mask_tensors = ()
for idx in range(2):
disparity_tensors += (torch.from_numpy(disparities[idx]),) if disparities[idx] is not None else (None,)
mask_tensors += (torch.from_numpy(masks[idx]),) if masks[idx] is not None else (None,)
return (img_left, img_right), disparity_tensors, mask_tensors


class AsymmetricColorJitter(T.ColorJitter):
    # p determines the probability of doing asymmetric vs symmetric color jittering
def __init__(
self,
brightness: T_COLOR_AUG_PARAM = 0,
contrast: T_COLOR_AUG_PARAM = 0,
saturation: T_COLOR_AUG_PARAM = 0,
hue: T_COLOR_AUG_PARAM = 0,
p: float = 0.2,
):
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
if torch.rand(1) < self.p:
# asymmetric: different transform for img1 and img2
img_left = super().forward(images[0])
img_right = super().forward(images[1])
else:
# symmetric: same transform for img1 and img2
batch = torch.stack(images)
batch = super().forward(batch)
img_left, img_right = batch[0], batch[1]
return (img_left, img_right), disparities, masks


class AsymetricGammaAdjust(torch.nn.Module):
def __init__(self, p: float, gamma_range: Tuple[float, float], gain: float = 1) -> None:
super().__init__()
self.gamma_range = gamma_range
self.gain = gain
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if torch.rand(1) < self.p:
            # asymmetric: sample a different gamma for img1 and img2
            gamma_left = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()
            gamma_right = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()
            img_left = F.adjust_gamma(images[0], gamma_left, gain=self.gain)
            img_right = F.adjust_gamma(images[1], gamma_right, gain=self.gain)
        else:
            # symmetric: the same gamma for img1 and img2
            gamma = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()
            batch = torch.stack(images)
            batch = F.adjust_gamma(batch, gamma, gain=self.gain)
            img_left, img_right = batch[0], batch[1]
return (img_left, img_right), disparities, masks


class RandomErase(torch.nn.Module):
    # Produces multiple symmetric random erasures. These can be viewed as
    # occlusions present in both camera views. As in optical flow occlusion
    # prediction tasks, we mask these pixels in the disparity map.
def __init__(
self,
p: float = 0.5,
erase_px_range: Tuple[int, int] = (50, 100),
value: Union[Tensor, float] = 0,
inplace: bool = False,
max_erase: int = 2,
):
super().__init__()
self.min_px_erase = erase_px_range[0]
self.max_px_erase = erase_px_range[1]
        if self.max_px_erase < 0:
            raise ValueError("erase_px_range[1] should be greater than or equal to 0")
        if self.min_px_erase < 0:
            raise ValueError("erase_px_range[0] should be greater than or equal to 0")
        if self.min_px_erase > self.max_px_erase:
            raise ValueError("erase_px_range[0] should be less than or equal to erase_px_range[1]")
self.p = p
self.value = value
self.inplace = inplace
self.max_erase = max_erase
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if torch.rand(1) >= self.p:
            # apply the transform with probability p; otherwise pass through
            return images, disparities, masks
image_left, image_right = images
mask_left, mask_right = masks
        # sample the number of erasures, up to and including max_erase
        for _ in range(torch.randint(self.max_erase + 1, size=(1,)).item()):
y, x, h, w, v = self._get_params(image_left)
image_right = F.erase(image_right, y, x, h, w, v, self.inplace)
image_left = F.erase(image_left, y, x, h, w, v, self.inplace)
            # Similarly to optical flow occlusion prediction, we consider any
            # pixels erased in both images to be occluded and therefore mark
            # them as invalid.
if mask_left is not None:
mask_left = F.erase(mask_left, y, x, h, w, False, self.inplace)
if mask_right is not None:
mask_right = F.erase(mask_right, y, x, h, w, False, self.inplace)
return (image_left, image_right), disparities, (mask_left, mask_right)
def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
img_h, img_w = img.shape[-2:]
crop_h, crop_w = (
random.randint(self.min_px_erase, self.max_px_erase),
random.randint(self.min_px_erase, self.max_px_erase),
)
crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))
return crop_y, crop_x, crop_h, crop_w, self.value


class RandomOcclusion(torch.nn.Module):
    # This adds an occlusion in the right image only.
    # The occluded patch works like a patch erase where the erase value is the
    # mean of the pixels from the selected zone.
def __init__(self, p: float = 0.5, occlusion_px_range: Tuple[int, int] = (50, 100), inplace: bool = False):
super().__init__()
self.min_px_occlusion = occlusion_px_range[0]
self.max_px_occlusion = occlusion_px_range[1]
        if self.max_px_occlusion < 0:
            raise ValueError("occlusion_px_range[1] should be greater than or equal to 0")
        if self.min_px_occlusion < 0:
            raise ValueError("occlusion_px_range[0] should be greater than or equal to 0")
        if self.min_px_occlusion > self.max_px_occlusion:
            raise ValueError("occlusion_px_range[0] should be less than or equal to occlusion_px_range[1]")
self.p = p
self.inplace = inplace
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
left_image, right_image = images
        if torch.rand(1) >= self.p:
            # apply the occlusion with probability p; otherwise pass through
            return images, disparities, masks
y, x, h, w, v = self._get_params(right_image)
right_image = F.erase(right_image, y, x, h, w, v, self.inplace)
return ((left_image, right_image), disparities, masks)
def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
img_h, img_w = img.shape[-2:]
crop_h, crop_w = (
random.randint(self.min_px_occlusion, self.max_px_occlusion),
random.randint(self.min_px_occlusion, self.max_px_occlusion),
)
crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))
occlusion_value = img[..., crop_y : crop_y + crop_h, crop_x : crop_x + crop_w].mean(dim=(-2, -1), keepdim=True)
return (crop_y, crop_x, crop_h, crop_w, occlusion_value)


class RandomSpatialShift(torch.nn.Module):
    # This transform applies a vertical shift and a slight angle rotation at the same time
def __init__(
self, p: float = 0.5, max_angle: float = 0.1, max_px_shift: int = 2, interpolation_type: str = "bilinear"
) -> None:
super().__init__()
self.p = p
self.max_angle = max_angle
self.max_px_shift = max_px_shift
self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
# the transform is applied only on the right image
# in order to mimic slight calibration issues
img_left, img_right = images
INTERP_MODE = self._interpolation_mode_strategy()
if torch.rand(1) < self.p:
# [0, 1] -> [-a, a]
shift = rand_float_range((1,), low=-self.max_px_shift, high=self.max_px_shift).item()
angle = rand_float_range((1,), low=-self.max_angle, high=self.max_angle).item()
# sample center point for the rotation matrix
y = torch.randint(size=(1,), low=0, high=img_right.shape[-2]).item()
x = torch.randint(size=(1,), low=0, high=img_right.shape[-1]).item()
# apply affine transformations
img_right = F.affine(
img_right,
angle=angle,
translate=[0, shift], # translation only on the y-axis
center=[x, y],
scale=1.0,
shear=0.0,
interpolation=INTERP_MODE,
)
return ((img_left, img_right), disparities, masks)


class RandomHorizontalFlip(torch.nn.Module):
def __init__(self, p: float = 0.5) -> None:
super().__init__()
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left, img_right = images
dsp_left, dsp_right = disparities
mask_left, mask_right = masks
if dsp_right is not None and torch.rand(1) < self.p:
img_left, img_right = F.hflip(img_left), F.hflip(img_right)
dsp_left, dsp_right = F.hflip(dsp_left), F.hflip(dsp_right)
if mask_left is not None and mask_right is not None:
mask_left, mask_right = F.hflip(mask_left), F.hflip(mask_right)
return ((img_right, img_left), (dsp_right, dsp_left), (mask_right, mask_left))
return images, disparities, masks


class Resize(torch.nn.Module):
def __init__(self, resize_size: Tuple[int, ...], interpolation_type: str = "bilinear") -> None:
super().__init__()
self.resize_size = list(resize_size) # doing this to keep mypy happy
self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
resized_images = ()
resized_disparities = ()
resized_masks = ()
INTERP_MODE = self._interpolation_mode_strategy()
for img in images:
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the stereo models with antialias=True?
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE, antialias=False),)
for dsp in disparities:
if dsp is not None:
# rescale disparity to match the new image size
scale_x = self.resize_size[1] / dsp.shape[-1]
resized_disparities += (F.resize(dsp, self.resize_size, interpolation=INTERP_MODE) * scale_x,)
else:
resized_disparities += (None,)
for mask in masks:
if mask is not None:
resized_masks += (
                    # we squeeze and unsqueeze because the API requires tensors with at least 3 dimensions
F.resize(
mask.unsqueeze(0),
self.resize_size,
interpolation=F.InterpolationMode.NEAREST,
).squeeze(0),
)
else:
resized_masks += (None,)
return resized_images, resized_disparities, resized_masks


class RandomRescaleAndCrop(torch.nn.Module):
    # This transform will resize the input with a given probability, and then crop it.
    # These are the reversed operations of the built-in RandomResizedCrop,
    # although the order of the operations doesn't matter too much: resizing a
    # crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
    #
    # The reason we don't rely on RandomResizedCrop is a significant difference
    # in the parametrization of the two transforms: the way the random
    # parameters are sampled leads to fairly different results (and a different
    # epe). For more details see
    # https://github.com/pytorch/vision/pull/5026/files#r762932579
def __init__(
self,
crop_size: Tuple[int, int],
scale_range: Tuple[float, float] = (-0.2, 0.5),
rescale_prob: float = 0.8,
scaling_type: str = "exponential",
interpolation_type: str = "bilinear",
) -> None:
super().__init__()
self.crop_size = crop_size
self.min_scale = scale_range[0]
self.max_scale = scale_range[1]
self.rescale_prob = rescale_prob
        self.scaling_type = scaling_type
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
        if self.scaling_type not in ("exponential", "linear"):
            raise ValueError(f"Invalid scaling_type: {self.scaling_type}. Valid types are: exponential, linear")
        if self.scaling_type == "linear" and self.min_scale < 0:
            raise ValueError("min_scale must be >= 0 for linear scaling")
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left, img_right = images
dsp_left, dsp_right = disparities
mask_left, mask_right = masks
INTERP_MODE = self._interpolation_mode_strategy()
# randomly sample scale
h, w = img_left.shape[-2:]
# Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
# It shouldn't matter much
min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
# exponential scaling will draw a random scale in (min_scale, max_scale) and then raise
# 2 to the power of that random value. This final scale distribution will have a different
# mean and variance than a uniform distribution. Note that a scale of 1 will result in
# a rescaling of 2X the original size, whereas a scale of -1 will result in a rescaling
# of 0.5X the original size.
if self.scaling_type == "exponential":
scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
# linear scaling will draw a random scale in (min_scale, max_scale)
elif self.scaling_type == "linear":
scale = torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
scale = max(scale, min_scale)
new_h, new_w = round(h * scale), round(w * scale)
if torch.rand(1).item() < self.rescale_prob:
# rescale the images
img_left = F.resize(img_left, size=(new_h, new_w), interpolation=INTERP_MODE)
img_right = F.resize(img_right, size=(new_h, new_w), interpolation=INTERP_MODE)
resized_masks, resized_disparities = (), ()
for disparity, mask in zip(disparities, masks):
if disparity is not None:
if mask is None:
resized_disparity = F.resize(disparity, size=(new_h, new_w), interpolation=INTERP_MODE)
# rescale the disparity
resized_disparity = (
resized_disparity * torch.tensor([scale], device=resized_disparity.device)[:, None, None]
)
resized_mask = None
else:
resized_disparity, resized_mask = _resize_sparse_flow(
disparity, mask, scale_x=scale, scale_y=scale
)
resized_masks += (resized_mask,)
resized_disparities += (resized_disparity,)
else:
resized_disparities = disparities
resized_masks = masks
disparities = resized_disparities
masks = resized_masks
# Note: For sparse datasets (Kitti), the original code uses a "margin"
# See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
# We don't, not sure if it matters much
        y0 = torch.randint(0, img_left.shape[-2] - self.crop_size[0], size=(1,)).item()
        x0 = torch.randint(0, img_left.shape[-1] - self.crop_size[1], size=(1,)).item()
img_left = F.crop(img_left, y0, x0, self.crop_size[0], self.crop_size[1])
img_right = F.crop(img_right, y0, x0, self.crop_size[0], self.crop_size[1])
if dsp_left is not None:
dsp_left = F.crop(disparities[0], y0, x0, self.crop_size[0], self.crop_size[1])
if dsp_right is not None:
dsp_right = F.crop(disparities[1], y0, x0, self.crop_size[0], self.crop_size[1])
cropped_masks = ()
for mask in masks:
if mask is not None:
mask = F.crop(mask, y0, x0, self.crop_size[0], self.crop_size[1])
cropped_masks += (mask,)
return ((img_left, img_right), (dsp_left, dsp_right), cropped_masks)
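

# A small numeric check (illustrative only, not part of the pipeline) of the
# "exponential" scaling comment above: the default scale_range of (-0.2, 0.5)
# rescales images by a factor of roughly 0.87x to 1.41x, since the draw is
# uniform in log2-space. The helper name _scale_range_example is hypothetical.
def _scale_range_example() -> None:
    lo, hi = 2 ** -0.2, 2 ** 0.5
    assert 0.87 < lo < 0.88  # smallest possible rescale factor
    assert 1.41 < hi < 1.42  # largest possible rescale factor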


def _resize_sparse_flow(
    flow: Tensor, valid_flow_mask: Tensor, scale_x: float = 1.0, scale_y: float = 1.0
) -> Tuple[Tensor, Tensor]:
    # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse).
    # There are as many non-zero values in the original flow as in the resized flow (up to out-of-bounds),
    # so for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4.
h, w = flow.shape[-2:]
h_new = int(round(h * scale_y))
w_new = int(round(w * scale_x))
    flow_new = torch.zeros(size=[1, h_new, w_new], dtype=flow.dtype, device=flow.device)
    valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype, device=valid_flow_mask.device)
jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")
ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]
ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)
within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)
ii_valid = ii_valid[within_bounds_mask]
jj_valid = jj_valid[within_bounds_mask]
ii_valid_new = ii_valid_new[within_bounds_mask]
jj_valid_new = jj_valid_new[within_bounds_mask]
valid_flow_new = flow[:, ii_valid, jj_valid]
valid_flow_new *= scale_x
flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
valid_new[ii_valid_new, jj_valid_new] = valid_flow_mask[ii_valid, jj_valid]
return flow_new, valid_new.bool()
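

# A minimal sketch (illustrative only) of the sparsity-preserving behavior
# described above: upscaling a map with a single valid pixel by 2x keeps
# exactly one valid pixel, relocated and rescaled. The helper name
# _sparse_flow_example is hypothetical.
def _sparse_flow_example() -> None:
    flow = torch.zeros(1, 4, 4)
    flow[0, 1, 2] = 3.0
    valid = torch.zeros(4, 4, dtype=torch.bool)
    valid[1, 2] = True
    new_flow, new_valid = _resize_sparse_flow(flow, valid, scale_x=2.0, scale_y=2.0)
    assert new_valid.sum() == 1  # still exactly one valid sample
    assert new_flow[0, 2, 4] == 6.0  # moved to (2 * 1, 2 * 2), value scaled by scale_x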


class Compose(torch.nn.Module):
def __init__(self, transforms: List[Callable]):
super().__init__()
self.transforms = transforms
@torch.inference_mode()
def forward(self, images, disparities, masks):
for t in self.transforms:
images, disparities, masks = t(images, disparities, masks)
return images, disparities, masks
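

# A hedged end-to-end usage sketch: compose a few of the transforms above and
# run them on dummy in-memory data. The image sizes, the synthetic disparity,
# and the pipeline choice below are assumptions for illustration, not the
# original training recipe.
if __name__ == "__main__":
    _scale_range_example()
    _sparse_flow_example()
    rng = np.random.default_rng(0)
    img_left = PIL.Image.fromarray(rng.integers(0, 256, (240, 320, 3), dtype=np.uint8))
    img_right = PIL.Image.fromarray(rng.integers(0, 256, (240, 320, 3), dtype=np.uint8))
    disparity = rng.random((1, 240, 320), dtype=np.float32) * 64.0  # dense left disparity
    pipeline = Compose(
        [
            ToTensor(),
            ConvertImageDtype(torch.float32),
            Resize((120, 160)),
            MakeValidDisparityMask(max_disparity=256),
            ValidateModelInput(),
        ]
    )
    images, disparities, masks = pipeline((img_left, img_right), (disparity, None), (None, None))
    print(images[0].shape, disparities[0].shape, masks[0].shape)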