import torch

import transforms as T

class OpticalFlowPresetEval(torch.nn.Module):
    """Evaluation preset for optical flow models.

    Converts a pair of PIL images to float32 tensors scaled into [-1, 1]
    and validates the (img1, img2, flow, valid) tuple before it reaches
    the model. No augmentation is applied.
    """

    def __init__(self):
        super().__init__()

        pipeline = [
            T.PILToTensor(),
            T.ConvertImageDtype(torch.float32),
            # map [0, 1] into [-1, 1]
            T.Normalize(mean=0.5, std=0.5),
            T.ValidateModelInput(),
        ]
        self.transforms = T.Compose(pipeline)

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)

class OpticalFlowPresetTrain(torch.nn.Module):
    """Training preset for optical flow models.

    Applies (in order): PIL-to-tensor conversion, asymmetric color jitter,
    random resize-and-crop, optional horizontal/vertical flips, dtype
    conversion and normalization into [-1, 1], random erasing, flow-mask
    construction, and model-input validation.
    """

    def __init__(
        self,
        *,
        # RandomResizeAndCrop params
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        stretch_prob=0.8,
        # AsymmetricColorJitter params
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.5 / 3.14,
        # Random[H,V]Flip params
        asymmetric_jitter_prob=0.2,
        do_flip=True,
    ):
        super().__init__()

        pipeline = [
            T.PILToTensor(),
            T.AsymmetricColorJitter(
                brightness=brightness,
                contrast=contrast,
                saturation=saturation,
                hue=hue,
                p=asymmetric_jitter_prob,
            ),
            T.RandomResizeAndCrop(
                crop_size=crop_size,
                min_scale=min_scale,
                max_scale=max_scale,
                stretch_prob=stretch_prob,
            ),
        ]

        if do_flip:
            pipeline.append(T.RandomHorizontalFlip(p=0.5))
            pipeline.append(T.RandomVerticalFlip(p=0.1))

        pipeline.extend(
            [
                T.ConvertImageDtype(torch.float32),
                # map [0, 1] into [-1, 1]
                T.Normalize(mean=0.5, std=0.5),
                T.RandomErasing(max_erase=2),
                T.MakeValidFlowMask(),
                T.ValidateModelInput(),
            ]
        )
        self.transforms = T.Compose(pipeline)

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)