# sglang_v0.5.2/vision_0.23.0/references/depth/stereo/presets.py
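"""Transform presets for stereo matching models.

StereoMatchingEvalPreset applies deterministic preprocessing (tensor
conversion, optional grayscale/resize, normalization, disparity masking),
while StereoMatchingTrainPreset adds the data augmentations used for
training (color jitter, spatial shift, rescale-and-crop, flip, occlusion,
erase).
"""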

from typing import Optional, Tuple, Union

import torch

import transforms as T  # local transforms module shipped alongside this reference script

class StereoMatchingEvalPreset(torch.nn.Module):
    """Deterministic preprocessing pipeline for stereo matching evaluation."""

    def __init__(
        self,
        mean: float = 0.5,
        std: float = 0.5,
        resize_size: Optional[Tuple[int, ...]] = None,
        max_disparity: Optional[float] = None,
        interpolation_type: str = "bilinear",
        use_grayscale: bool = False,
    ) -> None:
        super().__init__()

        transforms = [
            T.ToTensor(),
            T.ConvertImageDtype(torch.float32),
        ]

        if use_grayscale:
            transforms.append(T.ConvertToGrayscale())

        if resize_size is not None:
            transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))

        transforms.extend(
            [
                T.Normalize(mean=mean, std=std),
                T.MakeValidDisparityMask(max_disparity=max_disparity),
                T.ValidateModelInput(),
            ]
        )

        self.transforms = T.Compose(transforms)

    def forward(self, images, disparities, masks):
        return self.transforms(images, disparities, masks)
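
# A minimal usage sketch (hedged: the resize dimensions and max_disparity
# value below are illustrative assumptions, not defaults from this file):
#
#   preset = StereoMatchingEvalPreset(resize_size=(384, 768), max_disparity=256.0)
#   images, disparities, masks = preset(images, disparities, masks)
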
class StereoMatchingTrainPreset(torch.nn.Module):
    """Data augmentation pipeline for stereo matching training."""

    def __init__(
        self,
        *,
        resize_size: Optional[Tuple[int, ...]],
        resize_interpolation_type: str = "bilinear",
        # RandomRescaleAndCrop params
        crop_size: Tuple[int, int],
        rescale_prob: float = 1.0,
        scaling_type: str = "exponential",
        scale_range: Tuple[float, float] = (-0.2, 0.5),
        scale_interpolation_type: str = "bilinear",
        # convert to grayscale
        use_grayscale: bool = False,
        # normalization params
        mean: float = 0.5,
        std: float = 0.5,
        # processing device
        gpu_transforms: bool = False,
        # masking
        max_disparity: Optional[int] = 256,
        # RandomSpatialShift params
        spatial_shift_prob: float = 0.5,
        spatial_shift_max_angle: float = 0.5,
        spatial_shift_max_displacement: float = 0.5,
        spatial_shift_interpolation_type: str = "bilinear",
        # AsymmetricColorJitter and AsymetricGammaAdjust params
        gamma_range: Tuple[float, float] = (0.8, 1.2),
        brightness: Union[float, Tuple[float, float]] = (0.8, 1.2),
        contrast: Union[float, Tuple[float, float]] = (0.8, 1.2),
        saturation: Union[float, Tuple[float, float]] = 0.0,
        hue: Union[float, Tuple[float, float]] = 0.0,
        asymmetric_jitter_prob: float = 1.0,
        # RandomHorizontalFlip params
        horizontal_flip_prob: float = 0.5,
        # RandomOcclusion params
        occlusion_prob: float = 0.0,
        occlusion_px_range: Tuple[int, int] = (50, 100),
        # RandomErase params
        erase_prob: float = 0.0,
        erase_px_range: Tuple[int, int] = (50, 100),
        erase_num_repeats: int = 1,
    ) -> None:
        if scaling_type not in ["linear", "exponential"]:
            raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")

        super().__init__()

        transforms = [T.ToTensor()]

        # when fixing the size across multiple datasets, we ensure that the
        # same size is used for all datasets before cropping
        if resize_size is not None:
            transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))

        if gpu_transforms:
            transforms.append(T.ToGPU())

        # color handling
        color_transforms = [
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
        ]

        if use_grayscale:
            color_transforms.append(T.ConvertToGrayscale())

        transforms.extend(color_transforms)

        transforms.extend(
            [
                T.RandomSpatialShift(
                    p=spatial_shift_prob,
                    max_angle=spatial_shift_max_angle,
                    max_px_shift=spatial_shift_max_displacement,
                    interpolation_type=spatial_shift_interpolation_type,
                ),
                T.ConvertImageDtype(torch.float32),
                T.RandomRescaleAndCrop(
                    crop_size=crop_size,
                    scale_range=scale_range,
                    rescale_prob=rescale_prob,
                    scaling_type=scaling_type,
                    interpolation_type=scale_interpolation_type,
                ),
                T.RandomHorizontalFlip(horizontal_flip_prob),
                # occlude after flipping, otherwise we would be occluding the reference image
                T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
                T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
                T.Normalize(mean=mean, std=std),
                T.MakeValidDisparityMask(max_disparity),
                T.ValidateModelInput(),
            ]
        )

        self.transforms = T.Compose(transforms)

    def forward(self, images, disparities, masks):
        return self.transforms(images, disparities, masks)
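
# A minimal usage sketch for the train preset (hedged: the crop size and the
# decision to run transforms on GPU are illustrative assumptions; only
# resize_size and crop_size are required, keyword-only arguments):
#
#   preset = StereoMatchingTrainPreset(
#       resize_size=None,  # keep each dataset's native size before cropping
#       crop_size=(384, 512),
#       gpu_transforms=torch.cuda.is_available(),
#   )
#   images, disparities, masks = preset(images, disparities, masks)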