sglang_v0.5.2/vision_0.23.0/references/optical_flow/presets.py
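
"""Transform presets for the optical-flow reference scripts.

Both presets compose the joint transforms from the local ``transforms``
module and are called as ``preset(img1, img2, flow, valid)``.
"""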

import torch

import transforms as T


class OpticalFlowPresetEval(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transforms = T.Compose(
            [
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.ValidateModelInput(),
            ]
        )

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)
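

# A minimal usage sketch for the eval preset, assuming a torchvision
# optical-flow dataset such as KittiFlow; any dataset whose `transforms`
# callable receives (img1, img2, flow, valid_flow_mask) works the same way:
#
#   from torchvision.datasets import KittiFlow
#
#   dataset = KittiFlow(root="/path/to/kitti", split="train",
#                       transforms=OpticalFlowPresetEval())
#   img1, img2, flow, valid = dataset[0]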


class OpticalFlowPresetTrain(torch.nn.Module):
    def __init__(
        self,
        *,
        # RandomResizeAndCrop params
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        stretch_prob=0.8,
        # AsymmetricColorJitter params
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.5 / 3.14,
        asymmetric_jitter_prob=0.2,
        # Random[H,V]Flip params
        do_flip=True,
    ):
        super().__init__()

        transforms = [
            T.PILToTensor(),
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.RandomResizeAndCrop(
                crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
            ),
        ]

        if do_flip:
            transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]

        transforms += [
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            T.RandomErasing(max_erase=2),
            T.MakeValidFlowMask(),
            T.ValidateModelInput(),
        ]

        self.transforms = T.Compose(transforms)

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)
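

# A minimal construction sketch for the training preset. The values below are
# assumed FlyingChairs settings (matching the RAFT recipe); adjust per dataset:
#
#   transforms = OpticalFlowPresetTrain(
#       crop_size=(368, 496), min_scale=-0.1, max_scale=1.0, do_flip=True
#   )
#   img1, img2, flow, valid = transforms(img1, img2, flow, valid)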