"""Training and evaluation script for stereo-matching depth models (CRE-Stereo / RAFT-Stereo)."""
import argparse
|
|
import os
|
|
import warnings
|
|
from pathlib import Path
|
|
from typing import List, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torchvision.models.optical_flow
|
|
import torchvision.prototype.models.depth.stereo
|
|
import utils
|
|
import visualization
|
|
|
|
from parsing import make_dataset, make_eval_transform, make_train_transform, VALID_DATASETS
|
|
from torch import nn
|
|
from torchvision.transforms.functional import get_dimensions, InterpolationMode, resize
|
|
from utils.metrics import AVAILABLE_METRICS
|
|
from utils.norm import freeze_batch_norm
|
|
|
|
|
|
def make_stereo_flow(flow: Union[torch.Tensor, List[torch.Tensor]], model_out_channels: int) -> torch.Tensor:
    """Helper function to make stereo flow from a given model output.

    Lists are handled recursively (one entry per refinement step). When a
    prediction has a single (disparity) channel but the model is expected to
    produce two, a zero vertical-flow channel is appended along dim=1.
    """
    if isinstance(flow, list):
        return [make_stereo_flow(item, model_out_channels) for item in flow]

    num_channels = flow.shape[1]
    if num_channels == 1 and model_out_channels == 2:
        # by convention the flow is X-Y axis, so the zero Y flow goes last
        flow = torch.cat([flow, torch.zeros_like(flow)], dim=1)
    return flow
|
|
|
|
|
|
def make_lr_schedule(args: argparse.Namespace, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.SequentialLR:
    """Helper function to return a learning rate scheduler for CRE-stereo.

    The schedule is: optional warmup (``warmup_steps``), an optional flat
    stage until ``decay_after_steps``, then a decay stage that fills the
    remainder of ``total_iterations``.

    Raises:
        ValueError: if ``decay_after_steps`` precedes ``warmup_steps``, or an
            unknown warmup/decay method is requested.
    """
    if args.decay_after_steps < args.warmup_steps:
        # fix: the original message referenced the nonexistent ``args.function``
        raise ValueError(
            f"decay_after_steps: {args.decay_after_steps} must be greater than warmup_steps: {args.warmup_steps}"
        )

    warmup_steps = args.warmup_steps if args.warmup_steps else 0
    flat_lr_steps = args.decay_after_steps - warmup_steps if args.decay_after_steps else 0
    # decay covers whatever remains after warmup + flat so the schedule ends
    # exactly at total_iterations (``total - flat`` alone over-counted warmup)
    decay_lr_steps = args.total_iterations - warmup_steps - flat_lr_steps

    max_lr = args.lr
    min_lr = args.min_lr

    schedulers = []
    milestones = []

    if warmup_steps > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        else:
            raise ValueError(f"Unknown lr warmup method {args.lr_warmup_method}")
        schedulers.append(warmup_lr_scheduler)
        milestones.append(warmup_steps)

    if flat_lr_steps > 0:
        # NOTE(review): ConstantLR's ``factor`` multiplies the base lr, so passing
        # the absolute ``max_lr`` here scales the lr by lr**2 — confirm intent.
        flat_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=max_lr, total_iters=flat_lr_steps)
        schedulers.append(flat_lr_scheduler)
        milestones.append(flat_lr_steps + warmup_steps)

    if decay_lr_steps > 0:
        if args.lr_decay_method == "cosine":
            decay_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_lr_steps, eta_min=min_lr
            )
        elif args.lr_decay_method == "linear":
            # NOTE(review): LinearLR factors are also relative to the base lr —
            # absolute ``max_lr``/``min_lr`` values look suspicious here as well.
            decay_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=max_lr, end_factor=min_lr, total_iters=decay_lr_steps
            )
        elif args.lr_decay_method == "exponential":
            decay_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=args.lr_decay_gamma, last_epoch=-1
            )
        else:
            raise ValueError(f"Unknown lr decay method {args.lr_decay_method}")
        schedulers.append(decay_lr_scheduler)

    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=milestones)
    return scheduler
|
|
|
|
|
|
def shuffle_dataset(dataset):
    """Return a randomly permuted view of ``dataset``.

    The underlying dataset is not copied; the result is a ``Subset`` indexed
    by a random permutation of ``range(len(dataset))``.
    """
    order = torch.randperm(len(dataset))
    return torch.utils.data.Subset(dataset, order)
|
|
|
|
|
|
def resize_dataset_to_n_steps(
    dataset: torch.utils.data.Dataset, dataset_steps: int, samples_per_step: int, args: argparse.Namespace
) -> torch.utils.data.Dataset:
    """Repeat (and truncate) ``dataset`` so it covers ``dataset_steps`` steps.

    When ``args.steps_is_epochs`` is set, one step consumes a full epoch.
    Each repetition is independently shuffled when ``args.dataset_shuffle``
    is enabled.
    """
    original_size = len(dataset)
    if args.steps_is_epochs:
        samples_per_step = original_size
    target_size = dataset_steps * samples_per_step

    n_full_copies, leftover = divmod(target_size, original_size)
    dataset_copies = [dataset] * n_full_copies
    if leftover > 0:
        # partial copy covers the remainder of the last step
        dataset_copies.append(torch.utils.data.Subset(dataset, list(range(leftover))))

    if args.dataset_shuffle:
        dataset_copies = [shuffle_dataset(copy) for copy in dataset_copies]

    return torch.utils.data.ConcatDataset(dataset_copies)
|
|
|
|
|
|
def get_train_dataset(dataset_root: str, args: argparse.Namespace) -> torch.utils.data.Dataset:
    """Compose the training dataset from every source in ``args.train_datasets``.

    Each source dataset is resized so it covers its configured number of steps,
    then all sources are concatenated (optionally shuffling the overall order).
    """
    datasets = []
    for name in args.train_datasets:
        datasets.append(make_dataset(name, dataset_root, make_train_transform(args)))

    if len(datasets) == 0:
        raise ValueError("No datasets specified for training")

    # one optimization step consumes one batch on every process
    samples_per_step = args.world_size * args.batch_size

    for idx, (source, steps_for_source) in enumerate(zip(datasets, args.dataset_steps)):
        datasets[idx] = resize_dataset_to_n_steps(source, steps_for_source, samples_per_step, args)

    dataset = torch.utils.data.ConcatDataset(datasets)
    if args.dataset_order_shuffle:
        dataset = shuffle_dataset(dataset)

    print(f"Training dataset: {len(dataset)} samples")
    return dataset
|
|
|
|
|
|
@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=None,
    batch_size=None,
    header=None,
):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    Runs the model over ``val_loader`` (padding inputs per ``padder_mode``),
    accumulates the metrics listed in ``args.metrics`` plus ``fl-all``,
    synchronizes them across processes and optionally logs them to ``writer``
    at ``step``.
    """
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)
    # progress logger: only drives the periodic console prints below
    metric_logger = utils.MetricLogger(delimiter=" ")

    iterations = iterations or args.recurrent_updates

    # metric accumulator: holds the running averages that get reported/logged
    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")

    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        for blob in metric_logger.log_every(val_loader, print_freq, header):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)

            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=iterations)
            # keep only the disparity channel of the final refinement step
            disp_pred = disp_predictions[-1][:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)

            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)

    print("Num_processed_samples: ", num_processed_samples)
    if (
        hasattr(val_loader.dataset, "__len__")
        and len(val_loader.dataset) != num_processed_samples
        # fix: get_rank() raises when torch.distributed was never initialized
        and (not dist.is_initialized() or dist.get_rank() == 0)
    ):
        warnings.warn(
            # fix: the concatenated fragments were missing separating spaces
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. Try lowering the batch size or GPU number for more accurate results."
        )

    if writer is not None and args.rank == 0:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            writer.add_scalar(scalar_name, meter_value.avg, step)

    logger.synchronize_between_processes()
    print(header, logger)
|
|
|
|
|
|
def make_eval_loader(dataset_name: str, args: argparse.Namespace) -> torch.utils.data.DataLoader:
    """Build a validation ``DataLoader`` for ``dataset_name``.

    When ``args.weights`` is set, the transforms bundled with the weights are
    used and the ground-truth disparity / validity mask are rescaled to the
    transformed image size; otherwise the generic eval transform is used.
    """
    if args.weights:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(image_left, image_right, disp, valid_disp_mask):
            C_o, H_o, W_o = get_dimensions(image_left)
            image_left, image_right = trans(image_left, image_right)

            C_t, H_t, W_t = get_dimensions(image_left)
            # disparities are horizontal offsets in pixels, so they scale with width
            scale_factor = W_t / W_o

            if disp is not None and not isinstance(disp, torch.Tensor):
                disp = torch.from_numpy(disp)
                if W_t != W_o:
                    # fix: torchvision's resize takes ``interpolation=``, not ``mode=``
                    disp = resize(disp, (H_t, W_t), interpolation=InterpolationMode.BILINEAR) * scale_factor
            if valid_disp_mask is not None and not isinstance(valid_disp_mask, torch.Tensor):
                valid_disp_mask = torch.from_numpy(valid_disp_mask)
                if W_t != W_o:
                    valid_disp_mask = resize(valid_disp_mask, (H_t, W_t), interpolation=InterpolationMode.NEAREST)
            return image_left, image_right, disp, valid_disp_mask

    else:
        preprocessing = make_eval_transform(args)

    val_dataset = make_dataset(dataset_name, args.dataset_root, transforms=preprocessing)
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    return val_loader
|
|
|
|
|
|
def evaluate(model, loaders, args, writer=None, step=None):
    """Run ``_evaluate`` once per named validation loader."""
    for loader_name, loader in loaders.items():
        _evaluate(
            model,
            args,
            loader,
            iterations=args.recurrent_updates,
            padder_mode=args.padder_type,
            header=f"{loader_name} evaluation",
            batch_size=args.batch_size,
            writer=writer,
            step=step,
        )
|
|
|
|
|
|
def run(model, optimizer, scheduler, train_loader, val_loaders, logger, writer, scaler, args):
    """Main optimization loop.

    Iterates ``train_loader`` from ``args.start_step + 1`` to
    ``args.total_iterations``, computing a weighted sum of the configured
    losses under autocast, stepping optimizer + scheduler, and periodically
    logging to tensorboard, saving checkpoints and running validation.
    """
    device = torch.device(args.device)
    # wrap the loader in a logger so progress is printed as batches are consumed
    loader = iter(logger.log_every(train_loader))
    # number of output channels of the model (1 = disparity only, 2 = flow-like X/Y)
    model_out_channels = model.module.output_channels if args.distributed else model.output_channels

    torch.set_num_threads(args.threads)

    # main supervision: sequence loss over all recurrent refinement outputs
    sequence_criterion = utils.SequenceLoss(
        gamma=args.gamma,
        max_flow=args.max_disparity,
        exclude_large_flows=args.flow_loss_exclude_large,
    ).to(device)

    # the remaining criteria are optional; each is only built when its weight is set
    if args.consistency_weight:
        consistency_criterion = utils.FlowSequenceConsistencyLoss(
            args.gamma,
            resize_factor=0.25,
            rescale_factor=0.25,
            rescale_mode="bilinear",
        ).to(device)
    else:
        consistency_criterion = None

    if args.psnr_weight:
        psnr_criterion = utils.PSNRLoss().to(device)
    else:
        psnr_criterion = None

    if args.smoothness_weight:
        smoothness_criterion = utils.SmoothnessLoss().to(device)
    else:
        smoothness_criterion = None

    if args.photometric_weight:
        photometric_criterion = utils.FlowPhotoMetricLoss(
            ssim_weight=args.photometric_ssim_weight,
            max_displacement_ratio=args.photometric_max_displacement_ratio,
            ssim_use_padding=False,
        ).to(device)
    else:
        photometric_criterion = None

    for step in range(args.start_step + 1, args.total_iterations + 1):
        data_blob = next(loader)
        optimizer.zero_grad()

        # unpack the data blob
        image_left, image_right, disp_mask, valid_disp_mask = (x.to(device) for x in data_blob)
        with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=args.recurrent_updates)
            # different models have different outputs, make sure we get the right ones for this task
            disp_predictions = make_stereo_flow(disp_predictions, model_out_channels)
            # should the architecture or training loop require it, we have to adjust the disparity mask
            # target to possibly look like an optical flow mask
            disp_mask = make_stereo_flow(disp_mask, model_out_channels)
            # sequence loss on top of the model outputs

            loss = sequence_criterion(disp_predictions, disp_mask, valid_disp_mask) * args.flow_loss_weight

            if args.consistency_weight > 0:
                loss_consistency = consistency_criterion(disp_predictions)
                loss += loss_consistency * args.consistency_weight

            if args.psnr_weight > 0:
                loss_psnr = 0.0
                for pred in disp_predictions:
                    # predictions might have 2 channels
                    loss_psnr += psnr_criterion(
                        pred * valid_disp_mask.unsqueeze(1),
                        disp_mask * valid_disp_mask.unsqueeze(1),
                    ).mean()  # mean the psnr loss over the batch
                loss += loss_psnr / len(disp_predictions) * args.psnr_weight

            if args.photometric_weight > 0:
                loss_photometric = 0.0
                for pred in disp_predictions:
                    # predictions might have 1 channel, therefore we need to inpute 0s for the second channel
                    if model_out_channels == 1:
                        pred = torch.cat([pred, torch.zeros_like(pred)], dim=1)

                    loss_photometric += photometric_criterion(
                        image_left, image_right, pred, valid_disp_mask
                    )  # photometric loss already comes out meaned over the batch
                loss += loss_photometric / len(disp_predictions) * args.photometric_weight

            if args.smoothness_weight > 0:
                loss_smoothness = 0.0
                for pred in disp_predictions:
                    # predictions might have 2 channels
                    loss_smoothness += smoothness_criterion(
                        image_left, pred[:, :1, :, :]
                    ).mean()  # mean the smoothness loss over the batch
                loss += loss_smoothness / len(disp_predictions) * args.smoothness_weight

            with torch.no_grad():
                metrics, _ = utils.compute_metrics(
                    disp_predictions[-1][:, :1, :, :],  # predictions might have 2 channels
                    disp_mask[:, :1, :, :],  # so does the ground truth
                    valid_disp_mask,
                    args.metrics,
                )

            metrics.pop("fl-all", None)
            logger.update(loss=loss, **metrics)

        # backward + step, through the GradScaler when mixed precision is on
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            optimizer.step()

        scheduler.step()

        # rank-0 only side effects: tensorboard logging
        if not dist.is_initialized() or dist.get_rank() == 0:
            if writer is not None and step % args.tensorboard_log_frequency == 0:
                # log the loss and metrics to tensorboard

                writer.add_scalar("loss", loss, step)
                for name, value in logger.meters.items():
                    writer.add_scalar(name, value.avg, step)
                # log the images to tensorboard
                pred_grid = visualization.make_training_sample_grid(
                    image_left, image_right, disp_mask, valid_disp_mask, disp_predictions
                )
                writer.add_image("predictions", pred_grid, step, dataformats="HWC")

                # second thing we want to see is how relevant the iterative refinement is
                pred_sequence_grid = visualization.make_disparity_sequence_grid(disp_predictions, disp_mask)
                writer.add_image("sequence", pred_sequence_grid, step, dataformats="HWC")

        if step % args.save_frequency == 0:
            if not args.distributed or args.rank == 0:
                model_without_ddp = (
                    model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
                )
                checkpoint = {
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "step": step,
                    "args": args,
                }
                os.makedirs(args.checkpoint_dir, exist_ok=True)
                # keep both a step-stamped checkpoint and a rolling "latest" one
                torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
                torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")

        if step % args.valid_frequency == 0:
            evaluate(model, val_loaders, args, writer, step)
            # evaluate() put the model into eval mode; restore training state
            # and re-freeze batch norm if requested
            model.train()
            if args.freeze_batch_norm:
                if isinstance(model, nn.parallel.DistributedDataParallel):
                    freeze_batch_norm(model.module)
                else:
                    freeze_batch_norm(model)

    # one final save at the end
    if not args.distributed or args.rank == 0:
        model_without_ddp = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "step": step,
            "args": args,
        }
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
        torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")
|
|
|
|
|
|
def main(args):
    """Entry point: set up DDP, build model/data/optimizer, then either
    evaluate (implied by missing train datasets) or run the training loop."""
    args.total_iterations = sum(args.dataset_steps)

    # initialize DDP setting
    utils.setup_ddp(args)
    print(args)

    # training is skipped entirely when no training datasets were requested
    args.test_only = args.train_datasets is None

    # set the appropriate devices
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # select model architecture
    model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model.to(device)
        model_without_ddp = model

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    val_loaders = {name: make_eval_loader(name, args) for name in args.test_datasets}

    # EVAL ONLY configurations
    if args.test_only:
        evaluate(model, val_loaders, args)
        return

    # Sanity check for the parameter count
    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Compose the training dataset
    train_dataset = get_train_dataset(args.dataset_root, args)

    # initialize the optimizer
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
    else:
        raise ValueError(f"Unknown optimizer {args.optimizer}. Please choose between adam and sgd")

    # initialize the learning rate schedule
    scheduler = make_lr_schedule(args, optimizer)

    # load them from checkpoint if needed
    args.start_step = 0
    if args.resume_path is not None:
        # NOTE(review): the saved checkpoint stores an argparse.Namespace under
        # "args"; ``weights_only=True`` may refuse to unpickle it — confirm that
        # resuming from a training checkpoint actually loads.
        checkpoint = torch.load(args.resume_path, map_location="cpu", weights_only=True)
        if "model" in checkpoint:
            # this means the user requested to resume from a training checkpoint
            model_without_ddp.load_state_dict(checkpoint["model"])
            # this means the user wants to continue training from where it was left off
            if args.resume_schedule:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                args.start_step = checkpoint["step"] + 1
                # modify starting point of the data consumption; ConcatDataset /
                # Subset do not support slice indexing, so skip the already
                # consumed samples through a Subset view instead
                sample_start_step = args.start_step * args.batch_size * args.world_size
                train_dataset = torch.utils.data.Subset(
                    train_dataset, range(sample_start_step, len(train_dataset))
                )

        else:
            # this means the user wants to finetune on top of a model state dict
            # and that no other changes are required
            model_without_ddp.load_state_dict(checkpoint)

    torch.backends.cudnn.benchmark = True

    # enable training mode
    model.train()
    if args.freeze_batch_norm:
        freeze_batch_norm(model_without_ddp)

    # put dataloader on top of the dataset
    # make sure to disable shuffling since the dataset is already shuffled
    # in order to guarantee quasi randomness whilst retaining a deterministic
    # dataset consumption order
    if args.distributed:
        # the train dataset is preshuffled in order to respect the iteration order
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False, drop_last=True)
    else:
        # the train dataset is already shuffled, so we can use a simple SequentialSampler
        sampler = torch.utils.data.SequentialSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    # initialize the logger
    if args.tensorboard_summaries:
        from torch.utils.tensorboard import SummaryWriter

        tensorboard_path = Path(args.checkpoint_dir) / "tensorboard"
        os.makedirs(tensorboard_path, exist_ok=True)

        tensorboard_run = tensorboard_path / f"{args.name}"
        writer = SummaryWriter(tensorboard_run)
    else:
        writer = None

    logger = utils.MetricLogger(delimiter=" ")

    scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
    # run the training loop
    # this will perform optimization, respectively logging and saving checkpoints
    # when need be
    run(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loaders=val_loaders,
        logger=logger,
        writer=writer,
        scaler=scaler,
        args=args,
    )
|
|
|
|
|
|
def get_args_parser(add_help=True):
    """Build the command-line argument parser for stereo matching training/evaluation."""
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Training", add_help=add_help)
    # checkpointing
    parser.add_argument("--name", default="crestereo", help="name of the experiment")
    # NOTE(review): only --resume-path is read elsewhere in this file; --resume looks unused — confirm
    parser.add_argument("--resume", type=str, default=None, help="from which checkpoint to resume")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="path to the checkpoint directory")

    # dataset
    parser.add_argument("--dataset-root", type=str, default="", help="path to the dataset root directory")
    parser.add_argument(
        "--train-datasets",
        type=str,
        nargs="+",
        default=["crestereo"],
        help="dataset(s) to train on",
        choices=list(VALID_DATASETS.keys()),
    )
    parser.add_argument(
        "--dataset-steps", type=int, nargs="+", default=[300_000], help="number of steps for each dataset"
    )
    parser.add_argument(
        "--steps-is-epochs", action="store_true", help="if set, dataset-steps are interpreted as epochs"
    )
    parser.add_argument(
        "--test-datasets",
        type=str,
        nargs="+",
        default=["middlebury2014-train"],
        help="dataset(s) to test on",
        choices=["middlebury2014-train"],
    )
    parser.add_argument("--dataset-shuffle", type=bool, help="shuffle the dataset", default=True)
    parser.add_argument("--dataset-order-shuffle", type=bool, help="shuffle the dataset order", default=True)
    parser.add_argument("--batch-size", type=int, default=2, help="batch size per GPU")
    parser.add_argument("--workers", type=int, default=4, help="number of workers per GPU")
    parser.add_argument(
        "--threads",
        type=int,
        default=16,
        help="number of CPU threads per GPU. This can be changed around to speed-up transforms if needed. This can lead to worker thread contention so use with care.",
    )

    # model architecture
    parser.add_argument(
        "--model",
        type=str,
        default="crestereo_base",
        help="model architecture",
        choices=["crestereo_base", "raft_stereo"],
    )
    parser.add_argument("--recurrent-updates", type=int, default=10, help="number of recurrent updates")
    parser.add_argument("--freeze-batch-norm", action="store_true", help="freeze batch norm parameters")

    # loss parameters
    parser.add_argument("--gamma", type=float, default=0.8, help="gamma parameter for the flow sequence loss")
    parser.add_argument("--flow-loss-weight", type=float, default=1.0, help="weight for the flow loss")
    parser.add_argument(
        "--flow-loss-exclude-large",
        action="store_true",
        help="exclude large flow values from the loss. A large value is defined as a value greater than the ground truth flow norm",
        default=False,
    )
    parser.add_argument("--consistency-weight", type=float, default=0.0, help="consistency loss weight")
    parser.add_argument(
        "--consistency-resize-factor",
        type=float,
        default=0.25,
        help="consistency loss resize factor to account for the fact that the flow is computed on a downsampled image",
    )
    parser.add_argument("--psnr-weight", type=float, default=0.0, help="psnr loss weight")
    parser.add_argument("--smoothness-weight", type=float, default=0.0, help="smoothness loss weight")
    parser.add_argument("--photometric-weight", type=float, default=0.0, help="photometric loss weight")
    parser.add_argument(
        "--photometric-max-displacement-ratio",
        type=float,
        default=0.15,
        help="Only pixels with a displacement smaller than this ratio of the image width will be considered for the photometric loss",
    )
    parser.add_argument("--photometric-ssim-weight", type=float, default=0.85, help="photometric ssim loss weight")

    # transforms parameters
    parser.add_argument("--gpu-transforms", action="store_true", help="use GPU transforms")
    parser.add_argument(
        "--eval-size", type=int, nargs="+", default=[384, 512], help="size of the images for evaluation"
    )
    parser.add_argument("--resize-size", type=int, nargs=2, default=None, help="resize size")
    parser.add_argument("--crop-size", type=int, nargs=2, default=[384, 512], help="crop size")
    parser.add_argument("--scale-range", type=float, nargs=2, default=[0.6, 1.0], help="random scale range")
    parser.add_argument("--rescale-prob", type=float, default=1.0, help="probability of resizing the image")
    parser.add_argument(
        "--scaling-type", type=str, default="linear", help="scaling type", choices=["exponential", "linear"]
    )
    parser.add_argument("--flip-prob", type=float, default=0.5, help="probability of flipping the image")
    parser.add_argument(
        "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
    )
    parser.add_argument(
        "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
    )
    parser.add_argument(
        "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
    )
    parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
    parser.add_argument(
        "--interpolation-strategy",
        type=str,
        default="bilinear",
        help="interpolation strategy",
        choices=["bilinear", "bicubic", "mixed"],
    )
    parser.add_argument("--spatial-shift-prob", type=float, default=1.0, help="probability of shifting the image")
    parser.add_argument(
        "--spatial-shift-max-angle", type=float, default=0.1, help="maximum angle for the spatial shift"
    )
    parser.add_argument(
        "--spatial-shift-max-displacement", type=float, default=2.0, help="maximum displacement for the spatial shift"
    )
    parser.add_argument("--gamma-range", type=float, nargs="+", default=[0.8, 1.2], help="range for gamma correction")
    parser.add_argument(
        "--brightness-range", type=float, nargs="+", default=[0.8, 1.2], help="range for brightness correction"
    )
    parser.add_argument(
        "--contrast-range", type=float, nargs="+", default=[0.8, 1.2], help="range for contrast correction"
    )
    parser.add_argument(
        "--saturation-range", type=float, nargs="+", default=0.0, help="range for saturation correction"
    )
    parser.add_argument("--hue-range", type=float, nargs="+", default=0.0, help="range for hue correction")
    parser.add_argument(
        "--asymmetric-jitter-prob",
        type=float,
        default=1.0,
        help="probability of using asymmetric jitter instead of symmetric jitter",
    )
    parser.add_argument("--occlusion-prob", type=float, default=0.5, help="probability of occluding the rightimage")
    parser.add_argument(
        "--occlusion-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of occluded pixels"
    )
    parser.add_argument("--erase-prob", type=float, default=0.0, help="probability of erasing in both images")
    parser.add_argument(
        "--erase-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of erased pixels"
    )
    parser.add_argument(
        "--erase-num-repeats", type=int, default=1, help="number of times to repeat the erase operation"
    )

    # optimizer parameters
    parser.add_argument("--optimizer", type=str, default="adam", help="optimizer", choices=["adam", "sgd"])
    parser.add_argument("--lr", type=float, default=4e-4, help="learning rate")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay")
    parser.add_argument("--clip-grad-norm", type=float, default=0.0, help="clip grad norm")

    # lr_scheduler parameters
    parser.add_argument("--min-lr", type=float, default=2e-5, help="minimum learning rate")
    parser.add_argument("--warmup-steps", type=int, default=6_000, help="number of warmup steps")
    parser.add_argument(
        "--decay-after-steps", type=int, default=180_000, help="number of steps after which to start decay the lr"
    )
    # NOTE(review): make_lr_schedule implements "linear" and "constant" warmup, but
    # the choices here allow "cosine" (which would raise at schedule build time) — confirm.
    parser.add_argument(
        "--lr-warmup-method", type=str, default="linear", help="warmup method", choices=["linear", "cosine"]
    )
    parser.add_argument("--lr-warmup-factor", type=float, default=0.02, help="warmup factor for the learning rate")
    parser.add_argument(
        "--lr-decay-method",
        type=str,
        default="linear",
        help="decay method",
        choices=["linear", "cosine", "exponential"],
    )
    parser.add_argument("--lr-decay-gamma", type=float, default=0.8, help="decay factor for the learning rate")

    # deterministic behaviour
    parser.add_argument("--seed", type=int, default=42, help="seed for random number generators")

    # mixed precision training
    parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")

    # logging
    parser.add_argument("--tensorboard-summaries", action="store_true", help="log to tensorboard")
    parser.add_argument("--tensorboard-log-frequency", type=int, default=100, help="log frequency")
    parser.add_argument("--save-frequency", type=int, default=1_000, help="save frequency")
    parser.add_argument("--valid-frequency", type=int, default=1_000, help="validation frequency")
    parser.add_argument(
        "--metrics",
        type=str,
        nargs="+",
        default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
        help="metrics to log",
        choices=AVAILABLE_METRICS,
    )

    # distributed parameters
    parser.add_argument("--world-size", type=int, default=8, help="number of distributed processes")
    parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
    parser.add_argument("--device", type=str, default="cuda", help="device to use for training")

    # weights API
    parser.add_argument("--weights", type=str, default=None, help="weights API url")
    parser.add_argument(
        "--resume-path", type=str, default=None, help="a path from which to resume or start fine-tuning"
    )
    parser.add_argument("--resume-schedule", action="store_true", help="resume optimizer state")

    # padder parameters
    parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
    return parser
|
|
|
|
|
|
if __name__ == "__main__":
    # parse CLI arguments and launch training / evaluation
    args = get_args_parser().parse_args()
    main(args)
|