import torch
|
|
from torchvision.transforms import transforms
|
|
from transforms import ConvertBCHWtoCBHW
|
|
|
|
|
|
class VideoClassificationPresetTrain:
    """Training-time transform pipeline for video classification clips.

    Builds a ``transforms.Compose`` that converts the clip to float32,
    resizes it, optionally applies a random horizontal flip, normalizes,
    takes a random crop, and finally permutes BCHW -> CBHW.
    """

    def __init__(
        self,
        *,
        crop_size,
        resize_size,
        mean=(0.43216, 0.394666, 0.37645),
        std=(0.22803, 0.22145, 0.216989),
        hflip_prob=0.5,
    ):
        # We hard-code antialias=False to preserve results after we changed
        # its default from None to True (see
        # https://github.com/pytorch/vision/pull/7160)
        # TODO: we could re-train the video models with antialias=True?
        pipeline = [
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize(resize_size, antialias=False),
        ]
        # Flipping is skipped entirely when the probability is non-positive.
        if hflip_prob > 0:
            pipeline.append(transforms.RandomHorizontalFlip(hflip_prob))
        pipeline += [
            transforms.Normalize(mean=mean, std=std),
            transforms.RandomCrop(crop_size),
            ConvertBCHWtoCBHW(),
        ]
        self.transforms = transforms.Compose(pipeline)

    def __call__(self, x):
        """Apply the composed transform pipeline to *x*."""
        return self.transforms(x)
|
|
|
|
|
|
class VideoClassificationPresetEval:
    """Evaluation-time transform pipeline for video classification clips.

    Deterministic counterpart of the training preset: float32 conversion,
    resize, normalize, center crop, then BCHW -> CBHW permutation.
    """

    def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
        # We hard-code antialias=False to preserve results after we changed
        # its default from None to True (see
        # https://github.com/pytorch/vision/pull/7160)
        # TODO: we could re-train the video models with antialias=True?
        steps = [
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize(resize_size, antialias=False),
            transforms.Normalize(mean=mean, std=std),
            transforms.CenterCrop(crop_size),
            ConvertBCHWtoCBHW(),
        ]
        self.transforms = transforms.Compose(steps)

    def __call__(self, x):
        """Apply the composed transform pipeline to *x*."""
        return self.transforms(x)
|