sglang_v0.5.2/vision_0.22.1/references/optical_flow/train.py

390 lines
15 KiB
Python

import argparse
import warnings
from math import ceil
from pathlib import Path
import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain
from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
def get_train_dataset(stage, dataset_root):
if stage == "chairs":
transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
elif stage == "things":
transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
elif stage == "sintel_SKH": # S + K + H as from paper
crop_size = (368, 768)
transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)
things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)
kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)
hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)
# As future improvement, we could probably be using a distributed sampler here
# The distribution is S(.71), T(.135), K(.135), H(.02)
return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
elif stage == "kitti":
transforms = OpticalFlowPresetTrain(
# resize and crop params
crop_size=(288, 960),
min_scale=-0.2,
max_scale=0.4,
stretch_prob=0,
# flip params
do_flip=False,
# jitter params
brightness=0.3,
contrast=0.3,
saturation=0.3,
hue=0.3 / 3.14,
asymmetric_jitter_prob=0,
)
return KittiFlow(root=dataset_root, split="train", transforms=transforms)
else:
raise ValueError(f"Unknown stage {stage}")
@torch.no_grad()
def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
We process as many samples as possible with ddp, and process the rest on a single worker.
"""
batch_size = batch_size or args.batch_size
device = torch.device(args.device)
model.eval()
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
else:
sampler = torch.utils.data.SequentialSampler(val_dataset)
val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
batch_size=batch_size,
pin_memory=True,
num_workers=args.workers,
)
num_flow_updates = num_flow_updates or args.num_flow_updates
def inner_loop(blob):
if blob[0].dim() == 3:
# input is not batched, so we add an extra dim for consistency
blob = [x[None, :, :, :] if x is not None else None for x in blob]
image1, image2, flow_gt = blob[:3]
valid_flow_mask = None if len(blob) == 3 else blob[-1]
image1, image2 = image1.to(device), image2.to(device)
padder = utils.InputPadder(image1.shape, mode=padder_mode)
image1, image2 = padder.pad(image1, image2)
flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
flow_pred = flow_predictions[-1]
flow_pred = padder.unpad(flow_pred).cpu()
metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)
# We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
# per-pixel epe: average epe of all pixels of all images
# per-image epe: average epe on each image independently, then average over images
for name in ("epe", "1px", "3px", "5px", "f1"): # f1 is called f1-all in paper
logger.meters[name].update(metrics[name], n=num_pixels_tot)
logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)
logger = utils.MetricLogger()
for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
logger.add_meter(meter_name, fmt="{global_avg:.4f}")
num_processed_samples = 0
for blob in logger.log_every(val_loader, header=header, print_freq=None):
inner_loop(blob)
num_processed_samples += blob[0].shape[0] # batch size
if args.distributed:
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print(
f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
"Going to process the remaining samples individually, if any."
)
if args.rank == 0: # we only need to process the rest on a single worker
for i in range(num_processed_samples, len(val_dataset)):
inner_loop(val_dataset[i])
logger.synchronize_between_processes()
print(header, logger)
def evaluate(model, args):
val_datasets = args.val_dataset or []
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
def preprocessing(img1, img2, flow, valid_flow_mask):
img1, img2 = trans(img1, img2)
if flow is not None and not isinstance(flow, torch.Tensor):
flow = torch.from_numpy(flow)
if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor):
valid_flow_mask = torch.from_numpy(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
else:
preprocessing = OpticalFlowPresetEval()
for name in val_datasets:
if name == "kitti":
# Kitti has different image sizes, so we need to individually pad them, we can't batch.
# see comment in InputPadder
if args.batch_size != 1 and (not args.distributed or args.rank == 0):
warnings.warn(
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)
val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
_evaluate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
)
_evaluate(
model,
args,
val_dataset,
num_flow_updates=32,
padder_mode="sintel",
header=f"Sintel val {pass_name}",
)
else:
warnings.warn(f"Can't validate on {val_dataset}, skipping.")
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
device = torch.device(args.device)
for data_blob in logger.log_every(train_loader):
optimizer.zero_grad()
image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)
metrics.pop("f1")
logger.update(loss=loss, **metrics)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
optimizer.step()
scheduler.step()
def main(args):
utils.setup_ddp(args)
args.test_only = args.train_dataset is None
if args.distributed and args.device == "cpu":
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)
if args.use_deterministic_algorithms:
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True
model = torchvision.models.get_model(args.model, weights=args.weights)
if args.distributed:
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
model_without_ddp = model.module
else:
model.to(device)
model_without_ddp = model
if args.resume is not None:
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
if args.test_only:
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
evaluate(model, args)
return
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=args.lr,
epochs=args.epochs,
steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
)
if args.resume is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
else:
args.start_epoch = 0
torch.backends.cudnn.benchmark = True
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
else:
sampler = torch.utils.data.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.workers,
)
logger = utils.MetricLogger()
done = False
for epoch in range(args.start_epoch, args.epochs):
print(f"EPOCH {epoch}")
if args.distributed:
# needed on distributed mode, otherwise the data loading order would be the same for all epochs
sampler.set_epoch(epoch)
train_one_epoch(
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
logger=logger,
args=args,
)
# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {epoch} done. ", logger)
if not args.distributed or args.rank == 0:
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"epoch": epoch,
"args": args,
}
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
if epoch % args.val_freq == 0 or done:
evaluate(model, args)
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
parser.add_argument(
"--name",
default="raft",
type=str,
help="The name of the experiment - determines the name of the files where weights are saved.",
)
parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
parser.add_argument(
"--resume",
type=str,
help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
)
parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")
parser.add_argument(
"--train-dataset",
type=str,
help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
)
parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")
parser.add_argument(
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
)
parser.add_argument(
"--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
)
# TODO: resume and weights should be in an exclusive arg group
parser.add_argument(
"--num_flow_updates",
type=int,
default=12,
help="number of updates (or 'iters') in the update operator of the model.",
)
parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")
parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")
parser.add_argument(
"--dataset-root",
help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
required=True,
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
Path(args.output_dir).mkdir(exist_ok=True)
main(args)