from typing import Optional, Tuple, Union

import torch

import transforms as T


class StereoMatchingEvalPreset(torch.nn.Module):
    """Transform preset used when evaluating stereo matching models."""

    def __init__(
        self,
        mean: float = 0.5,
        std: float = 0.5,
        resize_size: Optional[Tuple[int, ...]] = None,
        max_disparity: Optional[float] = None,
        interpolation_type: str = "bilinear",
        use_grayscale: bool = False,
    ) -> None:
        super().__init__()

        transforms = [
            T.ToTensor(),
            T.ConvertImageDtype(torch.float32),
        ]

        if use_grayscale:
            transforms.append(T.ConvertToGrayscale())

        if resize_size is not None:
            transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))

        transforms.extend(
            [
                T.Normalize(mean=mean, std=std),
                T.MakeValidDisparityMask(max_disparity=max_disparity),
                T.ValidateModelInput(),
            ]
        )

        self.transforms = T.Compose(transforms)

    def forward(self, images, disparities, masks):
        return self.transforms(images, disparities, masks)
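
# Example usage (a minimal sketch): the preset is called on an
# (images, disparities, masks) triple, matching forward() above. The paired
# left/right convention shown here is an assumption about the local
# `transforms` module, and `left_img`, `right_img`, `disparity_map`, and
# `valid_mask` are hypothetical inputs:
#
#   preset = StereoMatchingEvalPreset(resize_size=(384, 768), max_disparity=256.0)
#   images, disparities, masks = preset(
#       (left_img, right_img), (disparity_map, None), (valid_mask, None)
#   )
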
class StereoMatchingTrainPreset(torch.nn.Module):
    """Transform preset with data augmentations used when training stereo matching models."""

    def __init__(
        self,
        *,
        resize_size: Optional[Tuple[int, ...]],
        resize_interpolation_type: str = "bilinear",
        # RandomRescaleAndCrop params
        crop_size: Tuple[int, int],
        rescale_prob: float = 1.0,
        scaling_type: str = "exponential",
        scale_range: Tuple[float, float] = (-0.2, 0.5),
        scale_interpolation_type: str = "bilinear",
        # convert to grayscale
        use_grayscale: bool = False,
        # normalization params
        mean: float = 0.5,
        std: float = 0.5,
        # processing device
        gpu_transforms: bool = False,
        # masking
        max_disparity: Optional[int] = 256,
        # RandomSpatialShift params
        spatial_shift_prob: float = 0.5,
        spatial_shift_max_angle: float = 0.5,
        spatial_shift_max_displacement: float = 0.5,
        spatial_shift_interpolation_type: str = "bilinear",
        # AsymmetricColorJitter params
        gamma_range: Tuple[float, float] = (0.8, 1.2),
        brightness: Union[float, Tuple[float, float]] = (0.8, 1.2),
        contrast: Union[float, Tuple[float, float]] = (0.8, 1.2),
        saturation: Union[float, Tuple[float, float]] = 0.0,
        hue: Union[float, Tuple[float, float]] = 0.0,
        asymmetric_jitter_prob: float = 1.0,
        # RandomHorizontalFlip params
        horizontal_flip_prob: float = 0.5,
        # RandomOcclusion params
        occlusion_prob: float = 0.0,
        occlusion_px_range: Tuple[int, int] = (50, 100),
        # RandomErase params
        erase_prob: float = 0.0,
        erase_px_range: Tuple[int, int] = (50, 100),
        erase_num_repeats: int = 1,
    ) -> None:
        if scaling_type not in ["linear", "exponential"]:
            raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")

        super().__init__()
        transforms = [T.ToTensor()]

        # When fixing the size across multiple datasets, we ensure that the
        # same size is used for all datasets when cropping.
        if resize_size is not None:
            transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))

        if gpu_transforms:
            transforms.append(T.ToGPU())

        # color handling
        color_transforms = [
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
        ]

        if use_grayscale:
            color_transforms.append(T.ConvertToGrayscale())

        transforms.extend(color_transforms)

        transforms.extend(
            [
                T.RandomSpatialShift(
                    p=spatial_shift_prob,
                    max_angle=spatial_shift_max_angle,
                    max_px_shift=spatial_shift_max_displacement,
                    interpolation_type=spatial_shift_interpolation_type,
                ),
                T.ConvertImageDtype(torch.float32),
                T.RandomRescaleAndCrop(
                    crop_size=crop_size,
                    scale_range=scale_range,
                    rescale_prob=rescale_prob,
                    scaling_type=scaling_type,
                    interpolation_type=scale_interpolation_type,
                ),
                T.RandomHorizontalFlip(horizontal_flip_prob),
                # occlusion after flip, otherwise we would be occluding the reference image
                T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
                T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
                T.Normalize(mean=mean, std=std),
                T.MakeValidDisparityMask(max_disparity),
                T.ValidateModelInput(),
            ]
        )

        self.transforms = T.Compose(transforms)

    def forward(self, images, disparities, masks):
        return self.transforms(images, disparities, masks)
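
# Example usage (a minimal sketch): the hyperparameter values below are
# illustrative assumptions, not tuned settings, and `image_pair`,
# `disparity_pair`, and `mask_pair` are hypothetical (left, right) inputs:
#
#   preset = StereoMatchingTrainPreset(
#       resize_size=None,         # keep the native resolution before cropping
#       crop_size=(384, 512),     # (H, W) of the random training crop
#       max_disparity=256,
#       occlusion_prob=0.5,
#   )
#   images, disparities, masks = preset(image_pair, disparity_pair, mask_pair)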