sglang_v0.5.2/vision_0.22.1/references/depth/stereo/parsing.py

import argparse
from functools import partial

import torch
from presets import StereoMatchingEvalPreset, StereoMatchingTrainPreset
from torchvision.datasets import (
    CarlaStereo,
    CREStereo,
    ETH3DStereo,
    FallingThingsStereo,
    InStereo2k,
    Kitti2012Stereo,
    Kitti2015Stereo,
    Middlebury2014Stereo,
    SceneFlowStereo,
    SintelStereo,
)
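
# Note: StereoMatchingTrainPreset / StereoMatchingEvalPreset come from the
# local presets module imported above (presumably the presets.py shipped
# alongside this reference script), not from the torchvision package itself.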

# Maps dataset names to partially-applied torchvision dataset constructors;
# make_dataset() below completes the call with root and transforms.
VALID_DATASETS = {
    "crestereo": partial(CREStereo),
    "carla-highres": partial(CarlaStereo),
    "instereo2k": partial(InStereo2k),
    "sintel": partial(SintelStereo),
    "sceneflow-monkaa": partial(SceneFlowStereo, variant="Monkaa", pass_name="both"),
    "sceneflow-flyingthings": partial(SceneFlowStereo, variant="FlyingThings3D", pass_name="both"),
    "sceneflow-driving": partial(SceneFlowStereo, variant="Driving", pass_name="both"),
    "fallingthings": partial(FallingThingsStereo, variant="both"),
    "eth3d-train": partial(ETH3DStereo, split="train"),
    "eth3d-test": partial(ETH3DStereo, split="test"),
    "kitti2015-train": partial(Kitti2015Stereo, split="train"),
    "kitti2015-test": partial(Kitti2015Stereo, split="test"),
    "kitti2012-train": partial(Kitti2012Stereo, split="train"),
    "kitti2012-test": partial(Kitti2012Stereo, split="test"),
    "middlebury2014-other": partial(
        Middlebury2014Stereo, split="additional", use_ambient_views=True, calibration="both"
    ),
    "middlebury2014-train": partial(Middlebury2014Stereo, split="train", calibration="perfect"),
    "middlebury2014-test": partial(Middlebury2014Stereo, split="test", calibration=None),
    "middlebury2014-train-ambient": partial(
        Middlebury2014Stereo, split="train", use_ambient_views=True, calibration="perfect"
    ),
}
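
# Usage sketch (illustrative, not part of the original script): each value
# above is a functools.partial over a torchvision dataset class, so calling it
# with the remaining keyword arguments yields a dataset instance. "./datasets"
# is a placeholder root:
#
#   kitti_train = VALID_DATASETS["kitti2015-train"](root="./datasets", transforms=None)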


def make_train_transform(args: argparse.Namespace) -> torch.nn.Module:
    return StereoMatchingTrainPreset(
        resize_size=args.resize_size,
        crop_size=args.crop_size,
        rescale_prob=args.rescale_prob,
        scaling_type=args.scaling_type,
        scale_range=args.scale_range,
        scale_interpolation_type=args.interpolation_strategy,
        use_grayscale=args.use_grayscale,
        mean=args.norm_mean,
        std=args.norm_std,
        horizontal_flip_prob=args.flip_prob,
        gpu_transforms=args.gpu_transforms,
        max_disparity=args.max_disparity,
        spatial_shift_prob=args.spatial_shift_prob,
        spatial_shift_max_angle=args.spatial_shift_max_angle,
        spatial_shift_max_displacement=args.spatial_shift_max_displacement,
        spatial_shift_interpolation_type=args.interpolation_strategy,
        gamma_range=args.gamma_range,
        brightness=args.brightness_range,
        contrast=args.contrast_range,
        saturation=args.saturation_range,
        hue=args.hue_range,
        asymmetric_jitter_prob=args.asymmetric_jitter_prob,
    )
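
# args.interpolation_strategy feeds both scale_interpolation_type and
# spatial_shift_interpolation_type above, so one parsed argument controls the
# interpolation mode for both geometric augmentations.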


def make_eval_transform(args: argparse.Namespace) -> torch.nn.Module:
    if args.eval_size is None:
        resize_size = args.crop_size
    else:
        resize_size = args.eval_size

    return StereoMatchingEvalPreset(
        mean=args.norm_mean,
        std=args.norm_std,
        use_grayscale=args.use_grayscale,
        resize_size=resize_size,
        interpolation_type=args.interpolation_strategy,
    )


def make_dataset(dataset_name: str, dataset_root: str, transforms: torch.nn.Module) -> torch.utils.data.Dataset:
    return VALID_DATASETS[dataset_name](root=dataset_root, transforms=transforms)
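

# Minimal usage sketch (illustrative; these Namespace fields mirror the
# attributes read by make_eval_transform, and the values are placeholders,
# not the script's real defaults):
if __name__ == "__main__":
    args = argparse.Namespace(
        eval_size=None,  # falls back to crop_size inside make_eval_transform
        crop_size=(384, 512),
        norm_mean=0.5,
        norm_std=0.5,
        use_grayscale=False,
        interpolation_strategy="bilinear",
    )
    eval_transform = make_eval_transform(args)
    # Building a dataset additionally requires the data on disk, e.g.:
    #   dataset = make_dataset("kitti2015-test", "./datasets", eval_transform)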