# mysora/tools/scoring/optical_flow/inference.py

import cv2 # isort:skip
import argparse
import gc
import os
from datetime import timedelta
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from PIL import Image
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms.functional import pil_to_tensor
from tqdm import tqdm
# from tools.datasets.utils import extract_frames
from tools.scoring.optical_flow.unimatch import UniMatch
# torch.backends.cudnn.enabled = False # This line enables large batch, but the speed is similar
def extract_frames(
    video_path,
    frame_inds=None,
    points=None,
    backend="opencv",
    return_length=False,
    num_frames=None,
):
    """
    Read selected frames from a video with OpenCV and return them as RGB PIL images.

    Args:
        video_path (str): path to video
        frame_inds (List[int]): indices of frames to extract; indices at or past the
            frame count are clamped to the last frame
        points (List[float]): values within [0, 1); multiply #frames to get frame indices
            (mutually exclusive with ``frame_inds``)
        backend (str): only "opencv" is implemented; "av"/"decord" fail the assertion
        return_length (bool): if True, also return the frame count used for clamping
        num_frames (int, optional): overrides the container-reported frame count

    Return:
        List[PIL.Image], or (List[PIL.Image], int) when ``return_length`` is True.
        Indices whose frame cannot be decoded are skipped; extraction stops at the
        first index that cannot be seeked to, so fewer frames than requested may be
        returned.
    """
    assert backend in ["av", "opencv", "decord"]
    assert (frame_inds is None) or (points is None)
    assert backend == "opencv"

    cap = cv2.VideoCapture(video_path)
    try:
        if num_frames is not None:
            total_frames = num_frames
        else:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if points is not None:
            frame_inds = [int(p * total_frames) for p in points]

        frames = []
        for idx in frame_inds:
            if idx >= total_frames:
                # Clamp to the last frame, but never seek below frame 0
                # (e.g. when total_frames == 0).
                idx = max(total_frames - 1, 0)
            if not cap.set(cv2.CAP_PROP_POS_FRAMES, idx):
                # Seeking failed; later indices are not attempted.
                break
            try:
                ret, frame = cap.read()
                if not ret or frame is None:
                    # Decode failure for this index: skip it (best effort).
                    continue
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            except cv2.error:
                continue
            frames.append(Image.fromarray(frame))
    finally:
        # Always release the capture handle, even if decoding raises.
        cap.release()

    if return_length:
        return frames, total_frames
    return frames
def merge_scores(gathered_list: list, meta: pd.DataFrame, column) -> pd.DataFrame:
    """
    Merge per-rank ``(indices, scores)`` results into a copy of ``meta``.

    Args:
        gathered_list: one ``(indices, scores)`` pair per rank, where ``scores[i]``
            belongs to the row labelled ``indices[i]`` in ``meta``.
        meta: source table. It is NOT modified.
        column: name of the column that receives the scores (created if missing).

    Returns:
        A new DataFrame restricted to the rows that received a score, in ascending
        label order, with ``column`` filled in. If a label occurs more than once,
        the first occurrence in the interleaved order wins.
    """
    indices_per_rank = [pair[0] for pair in gathered_list]
    scores_per_rank = [pair[1] for pair in gathered_list]
    # Interleave element-wise across ranks: (r0[0], r1[0], ..., r0[1], r1[1], ...).
    # zip() stops at the shortest rank list, so ranks should hold equally many items.
    flat_indices = np.array([i for group in zip(*indices_per_rank) for i in group])
    flat_scores = np.array([s for group in zip(*scores_per_rank) for s in group])
    # Samples may repeat (e.g. a DistributedSampler with drop_last=False pads ranks
    # to equal length); keep the first score seen for each label.
    unique_indices, first_pos = np.unique(flat_indices, return_index=True)
    # Work on a copy so the caller's DataFrame is left untouched.
    meta = meta.copy()
    meta.loc[unique_indices, column] = flat_scores[first_pos]
    # Drop rows that were not scored.
    return meta.loc[unique_indices]
class VideoTextDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields resized frame stacks for videos listed in a CSV.

    The CSV at ``meta_path`` must have a ``path`` column. Each item is a dict
    ``{"index": index, "images": tensor}``, where ``tensor`` is float with shape
    ``[N, C, 320, 576]`` and values on the 0-255 scale of the original uint8 pixels.
    N is however many frames ``extract_frames`` returned for ``frame_inds``.
    """

    def __init__(self, meta_path, frame_inds=None):
        self.meta_path = meta_path
        self.frame_inds = frame_inds
        self.meta = pd.read_csv(meta_path)

    def __getitem__(self, index):
        video_path = self.meta.iloc[index]["path"]
        pil_frames = extract_frames(video_path, frame_inds=self.frame_inds, backend="opencv")
        # [N, C, H, W] uint8 -> float so it can be interpolated.
        clip = torch.stack(list(map(pil_to_tensor, pil_frames))).float()
        height, width = clip.shape[-2:]
        # Portrait clips are transposed to landscape before the fixed-size resize,
        # presumably to limit aspect-ratio distortion.
        if height > width:
            clip = rearrange(clip, "N C H W -> N C W H")
        clip = F.interpolate(clip, size=(320, 576), mode="bilinear", align_corners=True)
        return {"index": index, "images": clip}

    def __len__(self):
        return self.meta.shape[0]
def parse_args():
    """Parse command-line options for optical-flow scoring.

    Returns:
        argparse.Namespace with ``meta_path``, ``bs``, ``num_workers`` and
        ``skip_if_existing``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("meta_path", type=str, help="Path to the input CSV file")
    # Note: UniMatch should not be run with a large batch size (--bs).
    for flag, kind, default, text in (
        ("--bs", int, 1, "Batch size"),
        ("--num_workers", int, 16, "Number of workers"),
    ):
        cli.add_argument(flag, type=kind, default=default, help=text)
    cli.add_argument("--skip_if_existing", action="store_true")
    return cli.parse_args()
def _build_model(device):
    """Build the UniMatch flow model, load its pretrained checkpoint, and move it to ``device``."""
    model = UniMatch(
        feature_channels=128,
        num_scales=2,
        upsample_factor=4,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        reg_refine=True,
        task="flow",
    )
    # map_location="cpu": without it, tensors saved from a GPU are restored onto
    # that same GPU in every rank's process before being moved to ``device``.
    ckpt = torch.load(
        "./pretrained_models/unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth",
        map_location="cpu",
    )
    model.load_state_dict(ckpt["model"])
    return model.to(device)


@torch.no_grad()
def _compute_flow_scores(model, dataloader, device):
    """Score every video this rank's sampler yields.

    Returns:
        (indices, scores): dataset indices and, for each, the mean L2 norm of the
        predicted flow over all consecutive frame pairs and all pixels.
    """
    indices_list = []
    scores_list = []
    model.eval()
    for batch in tqdm(dataloader, disable=dist.get_rank() != 0):
        indices = batch["index"]
        images = batch["images"].to(device)
        B = images.shape[0]
        # Pair frame t with frame t+1 for every video in the batch.
        batch_0 = rearrange(images[:, :-1], "B N C H W -> (B N) C H W").contiguous()
        batch_1 = rearrange(images[:, 1:], "B N C H W -> (B N) C H W").contiguous()
        res = model(
            batch_0,
            batch_1,
            attn_type="swin",
            attn_splits_list=[2, 8],
            corr_radius_list=[-1, 4],
            prop_radius_list=[-1, 1],
            num_reg_refine=6,
            task="flow",
            pred_bidir_flow=False,
        )
        flow_maps = res["flow_preds"][-1]  # [B * (N-1), 2, H, W]
        flow_maps = rearrange(flow_maps, "(B N) C H W -> B N H W C", B=B)
        # Per-pixel flow magnitude, averaged over pairs and pixels -> one score per video.
        flow_scores = flow_maps.norm(dim=-1).mean(dim=[1, 2, 3]).cpu()
        indices_list.extend(indices.tolist())
        scores_list.extend(flow_scores.tolist())
    return indices_list, scores_list


@torch.no_grad()
def main():
    """Compute optical-flow scores for every video in the meta CSV and save them.

    Each rank writes its partial result to ``<meta dir>/parts/<name>_flow_part_<rank><ext>``.
    Rank 0 then writes the merged table, with a ``flow`` column, to ``<name>_flow<ext>``.
    """
    args = parse_args()
    meta_path = args.meta_path
    if not os.path.exists(meta_path):
        print(f"Meta file '{meta_path}' not found. Exit.")
        return
    wo_ext, ext = os.path.splitext(meta_path)
    out_path = f"{wo_ext}_flow{ext}"
    if args.skip_if_existing and os.path.exists(out_path):
        print(f"Output meta file '{out_path}' already exists. Exit.")
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # NCCL requires GPUs; fall back to gloo so the CPU device path is actually reachable.
    use_cuda = torch.cuda.is_available()
    dist.init_process_group(backend="nccl" if use_cuda else "gloo", timeout=timedelta(hours=24))
    if use_cuda:
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # build model
    model = _build_model(device)

    # build dataset: NUM_FRAMES frames sampled every 15 frames from the start of each video
    NUM_FRAMES = 10
    frames_inds = [15 * i for i in range(NUM_FRAMES)]
    dataset = VideoTextDataset(meta_path=meta_path, frame_inds=frames_inds)
    dataloader = DataLoader(
        dataset,
        batch_size=args.bs,
        num_workers=args.num_workers,
        sampler=DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            drop_last=False,
        ),
    )

    # compute optical flow scores
    indices_list, scores_list = _compute_flow_scores(model, dataloader, device)

    # save local results
    meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column="flow")
    save_dir_local = os.path.join(os.path.dirname(out_path), "parts")
    os.makedirs(save_dir_local, exist_ok=True)
    # Build the part name from the real extension so every rank writes its own
    # file even when the meta file does not end in ".csv".
    out_stem, out_ext = os.path.splitext(os.path.basename(out_path))
    out_path_local = os.path.join(save_dir_local, f"{out_stem}_part_{rank}{out_ext}")
    meta_local.to_csv(out_path_local, index=False)

    # wait for all ranks to finish data processing
    dist.barrier()
    torch.cuda.empty_cache()
    gc.collect()

    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (indices_list, scores_list))
    if rank == 0:
        meta_new = merge_scores(gathered_list, dataset.meta, column="flow")
        meta_new.to_csv(out_path, index=False)
        print(f"New meta with optical flow scores saved to '{out_path}'.")

    # Release process-group resources on every rank.
    dist.destroy_process_group()
# Script entry point. main() calls dist.init_process_group without an explicit
# init_method, so it presumably expects a distributed launcher (e.g. torchrun)
# to provide the rendezvous environment — confirm against how this is invoked.
if __name__ == "__main__":
    main()