# Source: sglang_v0.5.2/vision_0.22.1/references/depth/stereo/train.py
# (789 lines, 33 KiB, Python)

import argparse
import os
import warnings
from pathlib import Path
from typing import List, Union
import numpy as np
import torch
import torch.distributed as dist
import torchvision.models.optical_flow
import torchvision.prototype.models.depth.stereo
import utils
import visualization
from parsing import make_dataset, make_eval_transform, make_train_transform, VALID_DATASETS
from torch import nn
from torchvision.transforms.functional import get_dimensions, InterpolationMode, resize
from utils.metrics import AVAILABLE_METRICS
from utils.norm import freeze_batch_norm
def make_stereo_flow(flow: Union[torch.Tensor, List[torch.Tensor]], model_out_channels: int) -> torch.Tensor:
    """Adapt a disparity map (or a list of them) to the model's output channel layout.

    A single-channel ``(B, 1, H, W)`` map is extended with an all-zero second channel when the
    model produces 2-channel flow. Any other input is returned unchanged. A list input is
    processed element-wise and a list is returned.
    """
    if isinstance(flow, list):
        return [make_stereo_flow(item, model_out_channels) for item in flow]

    _, channels, _, _ = flow.shape
    if not (channels == 1 and model_out_channels == 2):
        return flow

    # flow convention is (X, Y): the disparity is the X component, so the zero Y flow goes last
    return torch.cat([flow, torch.zeros_like(flow)], dim=1)
def make_lr_schedule(
    args: argparse.Namespace, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler.SequentialLR:
    """Build the warmup -> flat -> decay learning rate schedule used for CRE-stereo training.

    Up to three consecutive phases are chained with ``SequentialLR``:

    1. warmup (``args.warmup_steps`` steps): ``linear`` ramps from
       ``args.lr_warmup_factor * lr`` up to ``lr``; ``constant`` holds ``args.lr_warmup_factor * lr``.
    2. flat (until ``args.decay_after_steps``): the lr is kept at ``lr``.
    3. decay (the remaining ``args.total_iterations``): ``cosine`` / ``linear`` go towards
       ``args.min_lr``; ``exponential`` multiplies the lr by ``args.lr_decay_gamma`` each step.

    Torch schedulers scale the optimizer's initial lr by *multiplicative factors*. The linear
    decay therefore uses the relative factor ``min_lr / lr``, which assumes the optimizer was
    created with ``lr=args.lr`` (as ``main`` does). Phases of zero length are skipped.

    Raises:
        ValueError: if ``decay_after_steps`` is set but smaller than ``warmup_steps``, or the
            warmup / decay method is unknown.
    """
    warmup_steps = args.warmup_steps or 0
    decay_after_steps = args.decay_after_steps or 0
    if decay_after_steps and decay_after_steps < warmup_steps:
        raise ValueError(
            f"decay_after_steps: {args.decay_after_steps} must be greater than warmup_steps: {args.warmup_steps}"
        )

    # the lr is held flat between the end of the warmup and `decay_after_steps`
    flat_lr_steps = decay_after_steps - warmup_steps if decay_after_steps else 0
    # whatever is left of the run after the warmup and flat phases is spent decaying
    decay_lr_steps = args.total_iterations - warmup_steps - flat_lr_steps

    max_lr = args.lr
    min_lr = args.min_lr

    # (scheduler, number of steps it is active for), in execution order
    phases = []

    if warmup_steps > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        else:
            raise ValueError(f"Unknown lr warmup method {args.lr_warmup_method}")
        phases.append((warmup_lr_scheduler, warmup_steps))

    if flat_lr_steps > 0:
        # factor=1.0 keeps the lr at its base value for the whole phase
        flat_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=flat_lr_steps)
        phases.append((flat_lr_scheduler, flat_lr_steps))

    if decay_lr_steps > 0:
        if args.lr_decay_method == "cosine":
            # eta_min is an absolute lr for CosineAnnealingLR
            decay_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_lr_steps, eta_min=min_lr
            )
        elif args.lr_decay_method == "linear":
            # LinearLR factors are relative to the base lr: go from lr down to min_lr
            end_factor = min_lr / max_lr if max_lr else 0.0
            decay_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=end_factor, total_iters=decay_lr_steps
            )
        elif args.lr_decay_method == "exponential":
            decay_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=args.lr_decay_gamma, last_epoch=-1
            )
        else:
            raise ValueError(f"Unknown lr decay method {args.lr_decay_method}")
        phases.append((decay_lr_scheduler, decay_lr_steps))

    if not phases:
        # degenerate configuration: keep the lr constant so SequentialLR still gets one scheduler
        phases.append((torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=0), 0))

    # SequentialLR needs exactly one milestone per boundary between consecutive phases,
    # i.e. len(schedulers) - 1 of them, even when the run ends before the last phase
    milestones = []
    boundary = 0
    for _, n_steps in phases[:-1]:
        boundary += n_steps
        milestones.append(boundary)

    schedulers = [phase_scheduler for phase_scheduler, _ in phases]
    return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=milestones)
def shuffle_dataset(dataset):
    """Return a randomly permuted view of ``dataset`` (no data is copied).

    The permutation is drawn from torch's global RNG via ``torch.randperm``.
    """
    return torch.utils.data.Subset(dataset=dataset, indices=torch.randperm(len(dataset)))
def resize_dataset_to_n_steps(
    dataset: torch.utils.data.Dataset, dataset_steps: int, samples_per_step: int, args: argparse.Namespace
) -> torch.utils.data.Dataset:
    """Repeat and/or truncate ``dataset`` to exactly ``dataset_steps * samples_per_step`` samples.

    With ``args.steps_is_epochs`` the step size becomes the dataset length, so ``dataset_steps``
    counts full passes over the dataset. Full copies are concatenated and, if needed, a prefix
    of the dataset fills the remainder. With ``args.dataset_shuffle`` each copy (and the
    remainder) is permuted independently.

    Raises:
        ValueError: if ``dataset`` is empty (it cannot be repeated to a non-empty size).
    """
    original_size = len(dataset)
    if original_size == 0:
        raise ValueError("Cannot resize an empty dataset to a fixed number of steps")
    if args.steps_is_epochs:
        samples_per_step = original_size
    target_size = dataset_steps * samples_per_step

    n_expands, remainder = divmod(target_size, original_size)
    dataset_copies = [dataset] * n_expands
    if remainder > 0:
        # top up with the first `remainder` samples
        dataset_copies.append(torch.utils.data.Subset(dataset, list(range(remainder))))

    if not dataset_copies:
        # zero requested samples: ConcatDataset rejects an empty list, so return an empty view
        return torch.utils.data.Subset(dataset, [])

    if args.dataset_shuffle:
        dataset_copies = [shuffle_dataset(part) for part in dataset_copies]
    return torch.utils.data.ConcatDataset(dataset_copies)
def get_train_dataset(dataset_root: str, args: argparse.Namespace) -> torch.utils.data.Dataset:
    """Build the concatenated, step-sized training dataset.

    Each entry of ``args.train_datasets`` is built with the training transform and resized to
    the matching entry of ``args.dataset_steps``. One step consumes
    ``world_size * batch_size`` samples. With ``args.dataset_order_shuffle`` the concatenation
    is additionally permuted as a whole.

    Raises:
        ValueError: if no training dataset is given, or the number of ``--dataset-steps``
            values does not match the number of ``--train-datasets``. Previously ``zip``
            silently dropped the unmatched entries.
    """
    if not args.train_datasets:
        raise ValueError("No datasets specified for training")
    if len(args.dataset_steps) != len(args.train_datasets):
        raise ValueError(
            f"Got {len(args.train_datasets)} training dataset(s) but {len(args.dataset_steps)} "
            "--dataset-steps value(s); pass exactly one step count per training dataset"
        )

    datasets = []
    for dataset_name in args.train_datasets:
        transform = make_train_transform(args)
        dataset = make_dataset(dataset_name, dataset_root, transform)
        datasets.append(dataset)

    samples_per_step = args.world_size * args.batch_size
    datasets = [
        resize_dataset_to_n_steps(dataset, steps_per_dataset, samples_per_step, args)
        for dataset, steps_per_dataset in zip(datasets, args.dataset_steps)
    ]

    dataset = torch.utils.data.ConcatDataset(datasets)
    if args.dataset_order_shuffle:
        dataset = shuffle_dataset(dataset)

    print(f"Training dataset: {len(dataset)} samples")
    return dataset
@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=None,
    batch_size=None,
    header=None,
):
    """Compute the configured metrics (epe, etc.) for ``model`` on a single validation loader.

    Args:
        model: stereo model, called as ``model(left, right, flow_init=None, num_iters=...)``;
            only the first channel of its last prediction is evaluated.
        args: CLI namespace; reads ``device``, ``metrics``, ``recurrent_updates`` and
            ``mixed_precision``.
        val_loader: loader yielding ``(image_left, image_right, disp_gt, valid_disp_mask)``.
        padder_mode: mode forwarded to ``utils.InputPadder``.
        print_freq: progress print interval, in batches.
        writer: optional summary writer; scalars are written only by the main process.
        step: global step used as the x-axis of written scalars.
        iterations: recurrent refinement iterations; defaults to ``args.recurrent_updates``.
        batch_size: unused, kept for call-site compatibility.
        header: prefix for progress prints and scalar names (default ``"Test:"``).
    """
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)
    metric_logger = utils.MetricLogger(delimiter=" ")

    iterations = iterations or args.recurrent_updates

    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")

    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        for blob in metric_logger.log_every(val_loader, print_freq, header):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)

            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=iterations)
            # keep only the first (horizontal) channel; some models output 2-channel flow
            disp_pred = disp_predictions[-1][:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)

            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    print("Num_processed_samples: ", num_processed_samples)

    # get_rank() raises without an initialized process group, so only query it in that case
    is_main_process = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
    if (
        hasattr(val_loader.dataset, "__len__")
        and len(val_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. Try lowering the batch size or GPU number for more accurate results."
        )

    # aggregate meters across processes *before* reporting, so the written scalars match the
    # printed (global) averages instead of one rank's local values
    logger.synchronize_between_processes()

    if writer is not None and is_main_process:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            # global_avg is the value shown by the "{global_avg:.4f}" format printed below
            writer.add_scalar(scalar_name, meter_value.global_avg, step)

    print(header, logger)
def make_eval_loader(dataset_name: str, args: argparse.Namespace) -> torch.utils.data.DataLoader:
    """Build the evaluation ``DataLoader`` for ``dataset_name``.

    If ``args.weights`` is set, the preprocessing bundled with those weights is applied to
    the image pair. The ground-truth disparity and validity mask are then resized to the
    transformed image size. Disparity values are also multiplied by the width ratio, because
    disparity is a horizontal pixel offset. Otherwise the transform returned by
    ``make_eval_transform(args)`` is used.

    Samples are read in order: ``DistributedSampler(shuffle=False, drop_last=False)`` when
    distributed, ``SequentialSampler`` otherwise.
    """
    if args.weights:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(image_left, image_right, disp, valid_disp_mask):
            _, H_o, W_o = get_dimensions(image_left)
            image_left, image_right = trans(image_left, image_right)
            _, H_t, W_t = get_dimensions(image_left)

            # disparity is measured along the width, so its values scale with the width ratio
            scale_factor = W_t / W_o
            # targets must follow the images whenever either dimension changed
            size_changed = (H_t, W_t) != (H_o, W_o)

            if disp is not None:
                if not isinstance(disp, torch.Tensor):
                    disp = torch.from_numpy(disp)
                if size_changed:
                    # `resize` takes `interpolation=`, not `mode=`
                    disp = resize(disp, (H_t, W_t), interpolation=InterpolationMode.BILINEAR) * scale_factor
            if valid_disp_mask is not None:
                if not isinstance(valid_disp_mask, torch.Tensor):
                    valid_disp_mask = torch.from_numpy(valid_disp_mask)
                if size_changed:
                    # nearest interpolation does not blend mask values
                    valid_disp_mask = resize(valid_disp_mask, (H_t, W_t), interpolation=InterpolationMode.NEAREST)
            return image_left, image_right, disp, valid_disp_mask

    else:
        preprocessing = make_eval_transform(args)

    val_dataset = make_dataset(dataset_name, args.dataset_root, transforms=preprocessing)
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    return val_loader
def evaluate(model, loaders, args, writer=None, step=None):
    """Evaluate ``model`` on every named validation loader in ``loaders``.

    Each loader is run through ``_evaluate`` with the recurrent-update count, padder type and
    batch size taken from ``args``. Results are reported under the header
    ``"<name> evaluation"``.
    """
    common_kwargs = {
        "iterations": args.recurrent_updates,
        "padder_mode": args.padder_type,
        "batch_size": args.batch_size,
        "writer": writer,
        "step": step,
    }
    for name, val_loader in loaders.items():
        _evaluate(model, args, val_loader, header=f"{name} evaluation", **common_kwargs)
def _save_checkpoint(model, optimizer, scheduler, step, args):
    """Write the training state at ``step`` to ``args.checkpoint_dir`` (main process only).

    Two copies are written: ``{name}_{step}.pth`` and ``{name}.pth``. The second is
    overwritten on every save, so it is always the latest one.
    """
    if args.distributed and args.rank != 0:
        return
    model_without_ddp = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    checkpoint = {
        "model": model_without_ddp.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step,
        # a plain dict instead of the Namespace itself: torch.load(..., weights_only=True), which
        # main() uses for resuming, refuses to unpickle arbitrary classes such as argparse.Namespace
        "args": dict(vars(args)),
    }
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
    torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")


def run(model, optimizer, scheduler, train_loader, val_loaders, logger, writer, scaler, args):
    """Run the optimization loop from step ``args.start_step + 1`` to ``args.total_iterations``.

    Each step does the following:
    - Draws one batch and computes the weighted sum of the enabled losses. The sequence loss
      is always on; the consistency, PSNR, photometric and smoothness losses are added when
      their ``--*-weight`` is > 0.
    - Back-propagates, through ``scaler`` when it is given.
    - Optionally clips gradients, then steps the optimizer and the lr scheduler.

    Periodically it also logs to ``writer``, saves a checkpoint and evaluates on ``val_loaders``.
    A final checkpoint is written after the loop.
    """
    device = torch.device(args.device)
    # wrap the loader in a logger
    loader = iter(logger.log_every(train_loader))
    # output channels
    model_out_channels = model.module.output_channels if args.distributed else model.output_channels

    torch.set_num_threads(args.threads)

    sequence_criterion = utils.SequenceLoss(
        gamma=args.gamma,
        max_flow=args.max_disparity,
        exclude_large_flows=args.flow_loss_exclude_large,
    ).to(device)

    if args.consistency_weight:
        consistency_criterion = utils.FlowSequenceConsistencyLoss(
            args.gamma,
            # honour --consistency-resize-factor (its default equals the former hard-coded 0.25)
            resize_factor=args.consistency_resize_factor,
            rescale_factor=0.25,
            rescale_mode="bilinear",
        ).to(device)
    else:
        consistency_criterion = None

    if args.psnr_weight:
        psnr_criterion = utils.PSNRLoss().to(device)
    else:
        psnr_criterion = None

    if args.smoothness_weight:
        smoothness_criterion = utils.SmoothnessLoss().to(device)
    else:
        smoothness_criterion = None

    if args.photometric_weight:
        photometric_criterion = utils.FlowPhotoMetricLoss(
            ssim_weight=args.photometric_ssim_weight,
            max_displacement_ratio=args.photometric_max_displacement_ratio,
            ssim_use_padding=False,
        ).to(device)
    else:
        photometric_criterion = None

    # bound up-front so the final checkpoint below still works when the loop body never runs
    step = args.start_step
    for step in range(args.start_step + 1, args.total_iterations + 1):
        data_blob = next(loader)
        optimizer.zero_grad()

        # unpack the data blob
        image_left, image_right, disp_mask, valid_disp_mask = (x.to(device) for x in data_blob)
        with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=args.recurrent_updates)
            # different models have different outputs, make sure we get the right ones for this task
            disp_predictions = make_stereo_flow(disp_predictions, model_out_channels)
            # should the architecture or training loop require it, we have to adjust the disparity mask
            # target to possibly look like an optical flow mask
            disp_mask = make_stereo_flow(disp_mask, model_out_channels)

        # sequence loss on top of the model outputs
        loss = sequence_criterion(disp_predictions, disp_mask, valid_disp_mask) * args.flow_loss_weight

        if args.consistency_weight > 0:
            loss_consistency = consistency_criterion(disp_predictions)
            loss += loss_consistency * args.consistency_weight

        if args.psnr_weight > 0:
            loss_psnr = 0.0
            for pred in disp_predictions:
                # predictions might have 2 channels
                loss_psnr += psnr_criterion(
                    pred * valid_disp_mask.unsqueeze(1),
                    disp_mask * valid_disp_mask.unsqueeze(1),
                ).mean()  # mean the psnr loss over the batch
            loss += loss_psnr / len(disp_predictions) * args.psnr_weight

        if args.photometric_weight > 0:
            loss_photometric = 0.0
            for pred in disp_predictions:
                # predictions might have 1 channel, therefore we need to impute 0s for the second channel
                if model_out_channels == 1:
                    pred = torch.cat([pred, torch.zeros_like(pred)], dim=1)

                loss_photometric += photometric_criterion(
                    image_left, image_right, pred, valid_disp_mask
                )  # photometric loss already comes out meaned over the batch
            loss += loss_photometric / len(disp_predictions) * args.photometric_weight

        if args.smoothness_weight > 0:
            loss_smoothness = 0.0
            for pred in disp_predictions:
                # predictions might have 2 channels
                loss_smoothness += smoothness_criterion(
                    image_left, pred[:, :1, :, :]
                ).mean()  # mean the smoothness loss over the batch
            loss += loss_smoothness / len(disp_predictions) * args.smoothness_weight

        with torch.no_grad():
            metrics, _ = utils.compute_metrics(
                disp_predictions[-1][:, :1, :, :],  # predictions might have 2 channels
                disp_mask[:, :1, :, :],  # so does the ground truth
                valid_disp_mask,
                args.metrics,
            )

        metrics.pop("fl-all", None)
        logger.update(loss=loss, **metrics)

        if scaler is not None:
            scaler.scale(loss).backward()
            # unscale first so that gradient clipping sees the true gradient magnitudes
            scaler.unscale_(optimizer)
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            optimizer.step()

        scheduler.step()

        if not dist.is_initialized() or dist.get_rank() == 0:
            if writer is not None and step % args.tensorboard_log_frequency == 0:
                # log the loss and metrics to tensorboard
                writer.add_scalar("loss", loss, step)
                for name, value in logger.meters.items():
                    writer.add_scalar(name, value.avg, step)
                # log the images to tensorboard
                pred_grid = visualization.make_training_sample_grid(
                    image_left, image_right, disp_mask, valid_disp_mask, disp_predictions
                )
                writer.add_image("predictions", pred_grid, step, dataformats="HWC")

                # second thing we want to see is how relevant the iterative refinement is
                pred_sequence_grid = visualization.make_disparity_sequence_grid(disp_predictions, disp_mask)
                writer.add_image("sequence", pred_sequence_grid, step, dataformats="HWC")

        if step % args.save_frequency == 0:
            _save_checkpoint(model, optimizer, scheduler, step, args)

        if step % args.valid_frequency == 0:
            evaluate(model, val_loaders, args, writer, step)
            # evaluation switches the model to eval mode; restore training mode
            model.train()
            if args.freeze_batch_norm:
                if isinstance(model, nn.parallel.DistributedDataParallel):
                    freeze_batch_norm(model.module)
                else:
                    freeze_batch_norm(model)

    # one final save at the end
    _save_checkpoint(model, optimizer, scheduler, step, args)
def main(args):
    """Script entry point.

    Sets up (distributed) execution and builds the model and the validation loaders. If
    ``args.train_datasets`` is None, it only evaluates. Otherwise it does the following:
    - Builds the training data, the optimizer and the lr schedule.
    - Optionally restores state from ``args.resume_path``.
    - Calls :func:`run`.
    """
    args.total_iterations = sum(args.dataset_steps)

    # initialize DDP setting
    utils.setup_ddp(args)
    print(args)

    args.test_only = args.train_datasets is None

    # set the appropriate devices
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # select model architecture
    model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model.to(device)
        model_without_ddp = model

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    val_loaders = {name: make_eval_loader(name, args) for name in args.test_datasets}

    # EVAL ONLY configurations
    if args.test_only:
        evaluate(model, val_loaders, args)
        return

    # Sanity check for the parameter count
    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # get_train_dataset pre-shuffles with torch.randperm and the DistributedSampler below does
    # not shuffle, so every rank must draw identical permutations (and a resumed run must
    # reproduce the same order); seed the RNG with --seed for that
    torch.manual_seed(args.seed)

    # Compose the training dataset
    train_dataset = get_train_dataset(args.dataset_root, args)

    if args.steps_is_epochs:
        # dataset_steps are epochs here, so their sum is not a step count: derive the number of
        # optimization steps from the actual number of samples instead
        args.total_iterations = len(train_dataset) // (args.batch_size * args.world_size)

    # initialize the optimizer
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
    else:
        raise ValueError(f"Unknown optimizer {args.optimizer}. Please choose between adam and sgd")

    # initialize the learning rate schedule
    scheduler = make_lr_schedule(args, optimizer)

    # load them from checkpoint if needed
    # start_step is the last *completed* step; run() starts at start_step + 1
    args.start_step = 0
    if args.resume_path is not None:
        checkpoint = torch.load(args.resume_path, map_location="cpu", weights_only=True)
        if "model" in checkpoint:
            # this means the user requested to resume from a training checkpoint
            model_without_ddp.load_state_dict(checkpoint["model"])
            # this means the user wants to continue training from where it was left off
            if args.resume_schedule:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                # the checkpoint is written after `step` was fully applied, so resume right after it
                args.start_step = checkpoint["step"]
                # skip the samples consumed by steps 1..start_step; datasets do not support
                # slicing, hence the index-range Subset
                sample_start_step = args.start_step * args.batch_size * args.world_size
                train_dataset = torch.utils.data.Subset(
                    train_dataset, range(sample_start_step, len(train_dataset))
                )
        else:
            # this means the user wants to finetune on top of a model state dict
            # and that no other changes are required
            model_without_ddp.load_state_dict(checkpoint)

    torch.backends.cudnn.benchmark = True

    # enable training mode
    model.train()
    if args.freeze_batch_norm:
        freeze_batch_norm(model_without_ddp)

    # put dataloader on top of the dataset
    # make sure to disable shuffling since the dataset is already shuffled
    # in order to guarantee quasi randomness whilst retaining a deterministic
    # dataset consumption order
    if args.distributed:
        # the train dataset is preshuffled in order to respect the iteration order
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False, drop_last=True)
    else:
        # the train dataset is already shuffled, so we can use a simple SequentialSampler
        sampler = torch.utils.data.SequentialSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    # initialize the logger
    if args.tensorboard_summaries:
        from torch.utils.tensorboard import SummaryWriter

        tensorboard_path = Path(args.checkpoint_dir) / "tensorboard"
        os.makedirs(tensorboard_path, exist_ok=True)

        tensorboard_run = tensorboard_path / f"{args.name}"
        writer = SummaryWriter(tensorboard_run)
    else:
        writer = None

    logger = utils.MetricLogger(delimiter=" ")

    scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
    # run the training loop
    # this will perform optimization, respectively logging and saving checkpoints
    # when need be
    run(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loaders=val_loaders,
        logger=logger,
        writer=writer,
        scaler=scaler,
        args=args,
    )
def str_to_bool(value) -> bool:
    """Parse a boolean command-line value such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True`` because every
    non-empty string is truthy, so such an option could never be switched off.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser(add_help=True):
    """Build the command-line parser for stereo-matching training and evaluation.

    Two fixes are applied here:
    - ``--lr-warmup-method`` accepts exactly the methods ``make_lr_schedule`` implements
      (``linear`` / ``constant``).
    - ``--dataset-shuffle`` / ``--dataset-order-shuffle`` parse ``False`` correctly.
    """
    parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Training", add_help=add_help)
    # checkpointing
    parser.add_argument("--name", default="crestereo", help="name of the experiment")
    parser.add_argument("--resume", type=str, default=None, help="from which checkpoint to resume")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="path to the checkpoint directory")

    # dataset
    parser.add_argument("--dataset-root", type=str, default="", help="path to the dataset root directory")
    parser.add_argument(
        "--train-datasets",
        type=str,
        nargs="+",
        default=["crestereo"],
        help="dataset(s) to train on",
        choices=list(VALID_DATASETS.keys()),
    )
    parser.add_argument(
        "--dataset-steps", type=int, nargs="+", default=[300_000], help="number of steps for each dataset"
    )
    parser.add_argument(
        "--steps-is-epochs", action="store_true", help="if set, dataset-steps are interpreted as epochs"
    )
    parser.add_argument(
        "--test-datasets",
        type=str,
        nargs="+",
        default=["middlebury2014-train"],
        help="dataset(s) to test on",
        choices=["middlebury2014-train"],
    )
    # type=bool would turn "False" into True, hence the explicit parser
    parser.add_argument("--dataset-shuffle", type=str_to_bool, help="shuffle the dataset", default=True)
    parser.add_argument("--dataset-order-shuffle", type=str_to_bool, help="shuffle the dataset order", default=True)
    parser.add_argument("--batch-size", type=int, default=2, help="batch size per GPU")
    parser.add_argument("--workers", type=int, default=4, help="number of workers per GPU")
    parser.add_argument(
        "--threads",
        type=int,
        default=16,
        help="number of CPU threads per GPU. This can be changed around to speed-up transforms if needed. This can lead to worker thread contention so use with care.",
    )

    # model architecture
    parser.add_argument(
        "--model",
        type=str,
        default="crestereo_base",
        help="model architecture",
        choices=["crestereo_base", "raft_stereo"],
    )
    parser.add_argument("--recurrent-updates", type=int, default=10, help="number of recurrent updates")
    parser.add_argument("--freeze-batch-norm", action="store_true", help="freeze batch norm parameters")

    # loss parameters
    parser.add_argument("--gamma", type=float, default=0.8, help="gamma parameter for the flow sequence loss")
    parser.add_argument("--flow-loss-weight", type=float, default=1.0, help="weight for the flow loss")
    parser.add_argument(
        "--flow-loss-exclude-large",
        action="store_true",
        help="exclude large flow values from the loss. A large value is defined as a value greater than the ground truth flow norm",
        default=False,
    )
    parser.add_argument("--consistency-weight", type=float, default=0.0, help="consistency loss weight")
    parser.add_argument(
        "--consistency-resize-factor",
        type=float,
        default=0.25,
        help="consistency loss resize factor to account for the fact that the flow is computed on a downsampled image",
    )
    parser.add_argument("--psnr-weight", type=float, default=0.0, help="psnr loss weight")
    parser.add_argument("--smoothness-weight", type=float, default=0.0, help="smoothness loss weight")
    parser.add_argument("--photometric-weight", type=float, default=0.0, help="photometric loss weight")
    parser.add_argument(
        "--photometric-max-displacement-ratio",
        type=float,
        default=0.15,
        help="Only pixels with a displacement smaller than this ratio of the image width will be considered for the photometric loss",
    )
    parser.add_argument("--photometric-ssim-weight", type=float, default=0.85, help="photometric ssim loss weight")

    # transforms parameters
    parser.add_argument("--gpu-transforms", action="store_true", help="use GPU transforms")
    parser.add_argument(
        "--eval-size", type=int, nargs="+", default=[384, 512], help="size of the images for evaluation"
    )
    parser.add_argument("--resize-size", type=int, nargs=2, default=None, help="resize size")
    parser.add_argument("--crop-size", type=int, nargs=2, default=[384, 512], help="crop size")
    parser.add_argument("--scale-range", type=float, nargs=2, default=[0.6, 1.0], help="random scale range")
    parser.add_argument("--rescale-prob", type=float, default=1.0, help="probability of resizing the image")
    parser.add_argument(
        "--scaling-type", type=str, default="linear", help="scaling type", choices=["exponential", "linear"]
    )
    parser.add_argument("--flip-prob", type=float, default=0.5, help="probability of flipping the image")
    parser.add_argument(
        "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
    )
    parser.add_argument(
        "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
    )
    parser.add_argument(
        "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
    )
    parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
    parser.add_argument(
        "--interpolation-strategy",
        type=str,
        default="bilinear",
        help="interpolation strategy",
        choices=["bilinear", "bicubic", "mixed"],
    )
    parser.add_argument("--spatial-shift-prob", type=float, default=1.0, help="probability of shifting the image")
    parser.add_argument(
        "--spatial-shift-max-angle", type=float, default=0.1, help="maximum angle for the spatial shift"
    )
    parser.add_argument(
        "--spatial-shift-max-displacement", type=float, default=2.0, help="maximum displacement for the spatial shift"
    )
    parser.add_argument("--gamma-range", type=float, nargs="+", default=[0.8, 1.2], help="range for gamma correction")
    parser.add_argument(
        "--brightness-range", type=float, nargs="+", default=[0.8, 1.2], help="range for brightness correction"
    )
    parser.add_argument(
        "--contrast-range", type=float, nargs="+", default=[0.8, 1.2], help="range for contrast correction"
    )
    parser.add_argument(
        "--saturation-range", type=float, nargs="+", default=0.0, help="range for saturation correction"
    )
    parser.add_argument("--hue-range", type=float, nargs="+", default=0.0, help="range for hue correction")
    parser.add_argument(
        "--asymmetric-jitter-prob",
        type=float,
        default=1.0,
        help="probability of using asymmetric jitter instead of symmetric jitter",
    )
    parser.add_argument("--occlusion-prob", type=float, default=0.5, help="probability of occluding the rightimage")
    parser.add_argument(
        "--occlusion-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of occluded pixels"
    )
    parser.add_argument("--erase-prob", type=float, default=0.0, help="probability of erasing in both images")
    parser.add_argument(
        "--erase-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of erased pixels"
    )
    parser.add_argument(
        "--erase-num-repeats", type=int, default=1, help="number of times to repeat the erase operation"
    )

    # optimizer parameters
    parser.add_argument("--optimizer", type=str, default="adam", help="optimizer", choices=["adam", "sgd"])
    parser.add_argument("--lr", type=float, default=4e-4, help="learning rate")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay")
    parser.add_argument("--clip-grad-norm", type=float, default=0.0, help="clip grad norm")

    # lr_scheduler parameters
    parser.add_argument("--min-lr", type=float, default=2e-5, help="minimum learning rate")
    parser.add_argument("--warmup-steps", type=int, default=6_000, help="number of warmup steps")
    parser.add_argument(
        "--decay-after-steps", type=int, default=180_000, help="number of steps after which to start decay the lr"
    )
    # must match the warmup methods implemented by make_lr_schedule
    parser.add_argument(
        "--lr-warmup-method", type=str, default="linear", help="warmup method", choices=["linear", "constant"]
    )
    parser.add_argument("--lr-warmup-factor", type=float, default=0.02, help="warmup factor for the learning rate")
    parser.add_argument(
        "--lr-decay-method",
        type=str,
        default="linear",
        help="decay method",
        choices=["linear", "cosine", "exponential"],
    )
    parser.add_argument("--lr-decay-gamma", type=float, default=0.8, help="decay factor for the learning rate")

    # deterministic behaviour
    parser.add_argument("--seed", type=int, default=42, help="seed for random number generators")

    # mixed precision training
    parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")

    # logging
    parser.add_argument("--tensorboard-summaries", action="store_true", help="log to tensorboard")
    parser.add_argument("--tensorboard-log-frequency", type=int, default=100, help="log frequency")
    parser.add_argument("--save-frequency", type=int, default=1_000, help="save frequency")
    parser.add_argument("--valid-frequency", type=int, default=1_000, help="validation frequency")
    parser.add_argument(
        "--metrics",
        type=str,
        nargs="+",
        default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
        help="metrics to log",
        choices=AVAILABLE_METRICS,
    )

    # distributed parameters
    parser.add_argument("--world-size", type=int, default=8, help="number of distributed processes")
    parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
    parser.add_argument("--device", type=str, default="cuda", help="device to use for training")

    # weights API
    parser.add_argument("--weights", type=str, default=None, help="weights API url")
    parser.add_argument(
        "--resume-path", type=str, default=None, help="a path from which to resume or start fine-tuning"
    )
    parser.add_argument("--resume-schedule", action="store_true", help="resume optimizer state")

    # padder parameters
    parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
    return parser
if __name__ == "__main__":
    # parse the command line and hand the resulting namespace straight to main()
    main(get_args_parser().parse_args())