# sglang_v0.5.2/vision_0.22.1/references/depth/stereo/cascade_evaluation.py
#
# 300 lines
# 12 KiB
# Python

import os
import warnings
import torch
import torchvision
import torchvision.prototype.models.depth.stereo
import utils
from torch.nn import functional as F
from train import make_eval_loader
from utils.metrics import AVAILABLE_METRICS
from visualization import make_prediction_image_side_to_side
def get_args_parser(add_help=True):
    """Build the command-line parser for the stereo matching evaluation script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; controls whether the
            ``-h/--help`` option is registered.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    import argparse

    cli = argparse.ArgumentParser(add_help=add_help, description="PyTorch Stereo Matching Evaluation")
    add = cli.add_argument

    # --- dataset and model selection ---
    add("--dataset", default="middlebury2014-train", type=str, help="dataset to use")
    add("--dataset-root", default="", type=str, help="root of the dataset")
    add("--checkpoint", default="", type=str, help="path to weights")
    add("--weights", default=None, type=str, help="torchvision API weight")
    add(
        "--model",
        default="crestereo_base",
        type=str,
        help="which model to use if not speciffying a training checkpoint",
    )
    add("--img-folder", default="images", type=str)

    # --- data loading and preprocessing ---
    add("--batch-size", default=1, type=int, help="batch size")
    add("--workers", default=0, type=int, help="number of workers")
    add("--eval-size", default=[384, 512], nargs="+", type=int, help="resize size")
    add("--norm-mean", default=[0.5, 0.5, 0.5], nargs="+", type=float, help="mean for image normalization")
    add("--norm-std", default=[0.5, 0.5, 0.5], nargs="+", type=float, help="std for image normalization")
    add("--use-grayscale", default=False, action="store_true", help="use grayscale images instead of RGB")
    add("--max-disparity", default=None, type=float, help="maximum disparity")
    add(
        "--interpolation-strategy",
        default="bilinear",
        choices=["bilinear", "bicubic", "mixed"],
        type=str,
        help="interpolation strategy",
    )

    # --- inference configurations to sweep over ---
    add("--n_iterations", default=[10], nargs="+", type=int, help="number of recurent iterations")
    add("--n_cascades", default=[1], nargs="+", type=int, help="number of cascades")

    # --- reported metrics ---
    add(
        "--metrics",
        default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
        choices=AVAILABLE_METRICS,
        nargs="+",
        type=str,
        help="metrics to log",
    )

    # --- runtime / distributed setup ---
    add("--mixed-precision", action="store_true", help="use mixed precision training")
    add("--world-size", default=1, type=int, help="number of distributed processes")
    add("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    add("--device", default="cuda", type=str, help="device to use for training")

    # --- outputs ---
    add("--save-images", action="store_true", help="save images of the predictions")
    add("--padder-type", default="kitti", choices=["kitti", "sintel"], type=str, help="padder type")

    return cli
def cascade_inference(model, image_left, image_right, iterations, cascades):
    """Run coarse-to-fine (cascaded) inference over an image pyramid.

    A pyramid with ``cascades`` levels is built by bilinearly resizing the
    full-resolution inputs by factors of 2, 4, ... The model is then run from the
    coarsest level up to the full-resolution one; the last prediction of each level
    is handed to the next level as its initial flow.

    Args:
        model: Callable invoked as ``model(left, right, flow_init, num_iters=iterations)``.
            It must return an indexable sequence whose last element is the final prediction.
        image_left: Left image tensor; the last two dimensions are used as height and width.
        image_right: Right image tensor; validated the same way as ``image_left``.
        iterations: Value forwarded to the model as ``num_iters``.
        cascades: Number of pyramid levels (1 means a single full-resolution pass).

    Returns:
        The last element of the model output at full resolution.

    Raises:
        ValueError: If the height or width of either image is not divisible by
            ``16 * 2 ** (cascades - 1)``.
    """
    # The coarsest level is the input shrunk by 2 ** (cascades - 1); the check keeps
    # that level divisible by 16.
    # NOTE(review): the factor 16 comes from the original comment and error messages
    # (presumably the model's total stride) -- the old condition omitted it.
    divisor = 16 * (2 ** (cascades - 1))
    for image in (image_left, image_right):
        height, width = image.shape[-2], image.shape[-1]
        if height % divisor != 0:
            raise ValueError(f"image height is not divisible by {divisor}. Image shape: {height}")
        if width % divisor != 0:
            # report the width itself (the old message printed the height here)
            raise ValueError(f"image width is not divisible by {divisor}. Image shape: {width}")

    left_image_pyramid = [image_left]
    right_image_pyramid = [image_right]
    for idx in range(cascades - 1):
        ds_factor = int(2 ** (idx + 1))
        # every level is resized from the full-resolution input, not from the previous level
        ds_shape = (image_left.shape[-2] // ds_factor, image_left.shape[-1] // ds_factor)
        left_image_pyramid.append(F.interpolate(image_left, size=ds_shape, mode="bilinear", align_corners=True))
        right_image_pyramid.append(F.interpolate(image_right, size=ds_shape, mode="bilinear", align_corners=True))

    flow_init = None
    # coarsest level first; each level's final prediction seeds the next one
    for left_image, right_image in zip(reversed(left_image_pyramid), reversed(right_image_pyramid)):
        flow_pred = model(left_image, right_image, flow_init, num_iters=iterations)
        # flow pred is a list
        flow_init = flow_pred[-1]

    return flow_init
@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=10,
    cascades=1,
    batch_size=None,
    header=None,
    save_images=False,
    save_path="",
):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        args: Namespace read for ``device``, ``metrics``, ``mixed_precision`` and ``world_size``;
            ``distributed``/``rank`` are read when saving images or writing scalars, and
            ``recurrent_updates`` is read only when ``iterations`` is falsy.
        val_loader: Iterable of 4-tuples ``(image_left, image_right, disp_gt, valid_disp_mask)``
            exposing a ``dataset`` attribute.
        padder_mode: Forwarded as ``mode`` to ``utils.InputPadder``.
        print_freq: Forwarded to ``MetricLogger.log_every``.
        writer: Optional object with ``add_scalar``; used only when ``args.rank == 0``.
        step: Step value passed to ``writer.add_scalar``.
        iterations: Number of recurrent iterations passed to ``cascade_inference``.
        cascades: Number of cascades passed to ``cascade_inference``.
        batch_size: Accepted for call-site compatibility; not used in this function.
        header: Prefix for the progress log and scalar names; defaults to ``"Test:"``.
        save_images: When true, a side-by-side prediction image is produced for every batch.
        save_path: Destination passed to ``make_prediction_image_side_to_side``.

    Returns:
        Dict mapping each meter name to its ``global_avg``.
    """
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)
    metric_logger = utils.MetricLogger(delimiter=" ")

    iterations = iterations or args.recurrent_updates

    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    # "fl-all" is always tracked, even when it was not requested explicitly
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")

    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        for batch_idx, blob in enumerate(metric_logger.log_every(val_loader, print_freq, header)):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)

            disp_pred = cascade_inference(model, image_left, image_right, iterations, cascades)
            # keep only the first channel of the prediction, then undo the padding
            disp_pred = disp_pred[:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)

            if save_images:
                rank_prefix = args.rank if args.distributed else 0
                make_prediction_image_side_to_side(
                    disp_pred, disp_gt, valid_disp_mask, save_path, prefix=f"batch_{rank_prefix}_{batch_idx}"
                )

            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)

    num_processed_samples = utils.reduce_across_processes(num_processed_samples) / args.world_size

    print("Num_processed_samples: ", num_processed_samples)
    dataset = val_loader.dataset
    if hasattr(dataset, "__len__") and len(dataset) != num_processed_samples:
        # get_rank() raises when no process group exists, so a plain
        # single-process run is treated as rank 0 instead of crashing here.
        dist = torch.distributed
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        if rank == 0:
            warnings.warn(
                f"Number of processed samples {num_processed_samples} is different "
                f"from the dataset size {len(dataset)}. This may happen if "
                "the dataset is not divisible by the batch size. Try lowering the batch size for more accurate results."
            )

    if writer is not None and args.rank == 0:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            writer.add_scalar(scalar_name, meter_value.avg, step)

    logger.synchronize_between_processes()
    print(header, logger)

    logger_metrics = {k: v.global_avg for k, v in logger.meters.items()}
    return logger_metrics
def evaluate(model, loader, args, writer=None, step=None):
    """Evaluate ``model`` for every (n_cascades, n_iterations) combination and log the results.

    For each combination a folder
    ``<img_folder>/<checkpoint_name>/<dataset>/<n_cascades>c_<n_iters>i`` is created and
    ``_evaluate`` is run with that folder as ``save_path``. The collected metrics are printed and
    written to ``<checkpoint_name without '.pth'>_eval.log``; that file name is not joined with
    any folder, so it is resolved against the current working directory.

    Args:
        model: Model forwarded to ``_evaluate``.
        loader: Evaluation data loader forwarded to ``_evaluate``.
        args: Namespace read for ``img_folder``, ``checkpoint``, ``weights``, ``dataset``,
            ``n_cascades``, ``n_iterations``, ``padder_type``, ``eval_size``, ``batch_size``
            and ``save_images``.
        writer: Optional scalar writer forwarded to ``_evaluate``.
        step: Step value forwarded to ``_evaluate``.
    """
    os.makedirs(args.img_folder, exist_ok=True)
    # fall back to the weights identifier when no checkpoint path was given
    checkpoint_name = os.path.basename(args.checkpoint) or args.weights
    image_checkpoint_folder = os.path.join(args.img_folder, checkpoint_name)

    base_image_folder = os.path.join(image_checkpoint_folder, args.dataset)
    os.makedirs(base_image_folder, exist_ok=True)

    metrics = {}
    for n_cascades in args.n_cascades:
        for n_iters in args.n_iterations:
            config = f"{n_cascades}c_{n_iters}i"
            config_image_folder = os.path.join(base_image_folder, config)
            os.makedirs(config_image_folder, exist_ok=True)

            metrics[config] = _evaluate(
                model,
                args,
                loader,
                padder_mode=args.padder_type,
                header=f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{n_cascades} n_iters:{n_iters}",
                batch_size=args.batch_size,
                writer=writer,
                step=step,
                iterations=n_iters,
                cascades=n_cascades,
                save_path=config_image_folder,
                save_images=args.save_images,
            )

    # regroup the flat "<c>c_<i>i" keys as {cascades: {iterations: metrics}}
    metric_log_dict = {}
    # print the final results
    for config, config_metrics in metrics.items():
        cascades_token, iters_token = config.split("_")[:2]
        config_cascades = cascades_token[:-1]  # strip the trailing "c"
        config_iters = iters_token[:-1]  # strip the trailing "i"
        metric_log_dict.setdefault(config_cascades, {})[config_iters] = config_metrics

        evaluation_str = f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{config_cascades} recurrent_updates:{config_iters}"
        metrics_str = f"Metrics: {config_metrics}"
        print(evaluation_str)
        print(metrics_str)

    eval_log_name = f"{checkpoint_name.replace('.pth', '')}_eval.log"
    print("Saving eval log to: ", eval_log_name)
    with open(eval_log_name, "w") as f:
        f.write(f"Dataset: {args.dataset} @size: {args.eval_size}:\n")
        # write the dict line by line for each key, and each value in the keys;
        # the outer brace is opened once so it matches the single closing brace below
        f.write("{\n")
        for config_cascades, per_iterations in metric_log_dict.items():
            f.write(f"\t{config_cascades}: {{\n")
            for config_iters, config_metrics in per_iterations.items():
                # round every metric to 3 decimal places
                rounded = {k: float(f"{v:.3f}") for k, v in config_metrics.items()}
                f.write(f"\t\t{config_iters}: {rounded}\n")
            f.write("\t},\n")
        f.write("}\n")
def load_checkpoint(args):
    """Build the stereo model, load its weights and move it to the target device.

    When ``args.weights`` is falsy the state dict is read from ``args.checkpoint``: a training
    checkpoint (a dict with a ``"model"`` entry) picks the architecture from the stored
    ``args.model``, otherwise the file is treated as a bare state dict for ``args.model``.
    When ``args.weights`` is set, the model is built directly with those weights.

    Args:
        args: Namespace read for ``weights``, ``checkpoint``, ``model``, ``device``,
            ``distributed`` and ``gpu``. ``utils.setup_ddp(args)`` is called first
            (presumably it fills in the distributed fields -- confirm in ``utils``).

    Returns:
        The model on ``args.device``, wrapped in ``DistributedDataParallel`` when
        ``args.distributed`` is true.

    Raises:
        ValueError: If distributed mode is requested together with ``--device cpu``.
    """
    utils.setup_ddp(args)

    model_builders = torchvision.prototype.models.depth.stereo.__dict__
    if not args.weights:
        checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"), weights_only=True)
        if "model" in checkpoint:
            # training checkpoint: use the architecture it was trained with
            experiment_args = checkpoint["args"]
            model = model_builders[experiment_args.model](weights=None)
            model.load_state_dict(checkpoint["model"])
        else:
            model = model_builders[args.model](weights=None)
            model.load_state_dict(checkpoint)
    else:
        model = model_builders[args.model](weights=args.weights)

    # set the appropriate devices -- done for both loading paths so that `device`
    # is always bound (it used to be undefined when --weights was given)
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    else:
        model.to(device)

    return model
def main(args):
    """Load the model, build the evaluation loader and run the full evaluation.

    Args:
        args: Parsed command-line namespace, as produced by ``get_args_parser()``.
    """
    stereo_model = load_checkpoint(args)
    eval_loader = make_eval_loader(args.dataset, args)
    evaluate(model=stereo_model, loader=eval_loader, args=args)
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the evaluation.
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    main(args)