413 lines
15 KiB
Python
413 lines
15 KiB
Python
from __future__ import absolute_import
|
|
|
|
import hashlib
|
|
import os
|
|
|
|
import requests
|
|
import torch
|
|
import torch.nn
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.autograd import Variable
|
|
from tqdm import tqdm
|
|
|
|
from .pretrained_networks import alexnet, squeezenet, vgg16
|
|
from .pwcnet import Network as PWCNet
|
|
from .utils import *
|
|
|
|
# Download location of the pretrained LPIPS linear-head weights, per backbone.
URL_MAP = {"alex": "https://raw.githubusercontent.com/danier97/flolpips/main/weights/v0.1/alex.pth"}

# Local checkpoint filename for each supported backbone.
CKPT_MAP = {"alex": "alex.pth"}

# Expected MD5 checksum used to validate a downloaded checkpoint.
MD5_MAP = {"alex": "9642209e2b57a85d20f86d812320f9e6"}
|
|
|
|
|
|
def spatial_average(in_tens, keepdim=True):
    """Average a (B, C, H, W) tensor over its spatial dimensions."""
    return torch.mean(in_tens, dim=[2, 3], keepdim=keepdim)
|
|
|
|
|
|
def mw_spatial_average(in_tens, flow, keepdim=True):
    """Spatial average of `in_tens` weighted by optical-flow magnitude.

    The 2-channel flow field is resized to the feature resolution, its
    per-pixel magnitude is normalised to sum to one per sample, and the
    weighted sum over H and W is returned.
    """
    _, _, height, width = in_tens.shape
    resized = F.interpolate(flow, (height, width), align_corners=False, mode="bilinear")
    magnitude = torch.sqrt(resized[:, 0:1] ** 2 + resized[:, 1:2] ** 2)
    weights = magnitude / torch.sum(magnitude, dim=[1, 2, 3], keepdim=True)
    return torch.sum(in_tens * weights, dim=[2, 3], keepdim=keepdim)
|
|
|
|
|
|
def mtw_spatial_average(in_tens, flow, texture, keepdim=True):
    """Spatial average weighted by flow magnitude divided by texture.

    Flow magnitude and texture are min-max normalised over the WHOLE
    tensor (not per sample), shifted by 1e-6 to avoid division by zero,
    and the resulting weight map is normalised to sum to one over all
    elements.  NOTE(review): a spatially constant flow (or texture)
    makes max == min, producing NaNs — presumably inputs are assumed
    non-constant; confirm with callers.
    """
    _, _, height, width = in_tens.shape
    resized_flow = F.interpolate(flow, (height, width), align_corners=False, mode="bilinear")
    resized_tex = F.interpolate(texture, (height, width), align_corners=False, mode="bilinear")
    magnitude = torch.sqrt(resized_flow[:, 0:1] ** 2 + resized_flow[:, 1:2] ** 2)
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min()) + 1e-6
    resized_tex = (resized_tex - resized_tex.min()) / (resized_tex.max() - resized_tex.min()) + 1e-6
    weight = magnitude / resized_tex
    weight = weight / torch.sum(weight)
    return torch.sum(in_tens * weight, dim=[2, 3], keepdim=keepdim)
|
|
|
|
|
|
def m2w_spatial_average(in_tens, flow, keepdim=True):
    """Spatial average weighted by SQUARED flow magnitude.

    NOTE(review): the squared magnitude is normalised by its sum over
    the entire tensor (all samples), unlike mw_spatial_average which
    normalises per sample — presumably intended for batch size 1.
    """
    _, _, height, width = in_tens.shape
    resized = F.interpolate(flow, (height, width), align_corners=False, mode="bilinear")
    weight = resized[:, 0:1] ** 2 + resized[:, 1:2] ** 2  # B,1,H,W
    weight = weight / torch.sum(weight)
    return torch.sum(in_tens * weight, dim=[2, 3], keepdim=keepdim)
|
|
|
|
|
|
def upsample(in_tens, out_HW=(64, 64)):  # assumes scale factor is same for H and W
    """Bilinearly resize a (B, C, H, W) tensor to spatial size `out_HW`.

    Fixes two issues in the original: the unused locals `in_H`/`in_W`
    are dropped, and the functional API is used directly instead of
    constructing a throwaway nn.Upsample module on every call.
    """
    return F.interpolate(in_tens, size=out_HW, mode="bilinear", align_corners=False)
|
|
|
|
|
|
def md5_hash(path):
    """Return the hexadecimal MD5 digest of the file at `path`.

    Hashes the file incrementally in 1 MiB chunks so large checkpoint
    files do not have to be loaded into memory all at once (the original
    read the entire file with a single f.read()).
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
|
|
def download(url, local_path, chunk_size=1024):
    """Stream `url` to `local_path`, creating parent directories as needed.

    Fixes two defects in the original: an HTTP error response (404/500)
    was silently written to disk as if it were the checkpoint (now we
    raise via raise_for_status), and the progress bar advanced by
    `chunk_size` even for the final, shorter chunk, overstating progress
    (now it advances by the actual number of bytes written).
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()  # fail fast instead of saving an error page
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(len(data))
|
|
|
|
|
|
def get_ckpt_path(name, root, check=False):
    """Return the local checkpoint path for `name`, downloading it first
    if it is missing or (when `check` is set) fails MD5 validation."""
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    needs_download = not os.path.exists(path)
    if not needs_download and check:
        needs_download = md5_hash(path) != MD5_MAP[name]
    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        # Validate the fresh download against the known checksum.
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
|
|
|
|
|
|
# Learned perceptual metric
|
|
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS) metric.

    Feeds both inputs through a classification trunk (alexnet / vgg16 /
    squeezenet), unit-normalises the features tapped at several layers,
    and scores the squared feature differences.  With ``lpips=True`` the
    per-channel differences are weighted by learned 1x1 conv layers
    ("lin" layers) whose weights are loaded from a downloaded checkpoint.
    """

    def __init__(
        self,
        pretrained=True,   # load the learned linear-head weights
        net="alex",        # trunk: "alex", "vgg"/"vgg16", or "squeeze"
        version="0.1",     # "0.1" applies input scaling; other versions do not
        lpips=True,        # True: learned calibration; False: plain-averaging baseline
        spatial=False,     # True: return per-pixel maps via upsample instead of averages
        pnet_rand=False,   # True: randomly initialised trunk weights
        pnet_tune=False,   # True: trunk parameters require grad
        use_dropout=True,  # dropout inside the lin layers
        model_path=None,   # NOTE(review): not used for loading; see load_from_pretrained
        eval_mode=True,    # switch the module to eval() after construction
        verbose=False,
    ):
        # lpips - [True] means with linear calibration on top of base network
        # pretrained - [True] means load linear weights

        super(LPIPS, self).__init__()
        if verbose:
            print(
                "Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]"
                % ("LPIPS" if lpips else "baseline", net, version, "on" if spatial else "off")
            )

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()

        # Channel widths of the feature layers tapped from each trunk.
        if self.pnet_type in ["vgg", "vgg16"]:
            net_type = vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == "alex":
            net_type = alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == "squeeze":
            net_type = squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            # One learned 1x1 conv ("lin" layer) per tapped trunk layer.  The
            # lin0..lin6 attribute names match the keys in the pretrained
            # checkpoint, so they must not be renamed.
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if self.pnet_type == "squeeze":  # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins += [self.lin5, self.lin6]
            self.lins = nn.ModuleList(self.lins)

        if pretrained:
            self.load_from_pretrained(version, net)
            if verbose:
                # NOTE(review): model_path is not used by load_from_pretrained,
                # so this may print "None" even though a checkpoint was loaded.
                print("Loaded model from: %s" % model_path)

        if eval_mode:
            self.eval()

    def load_from_pretrained(self, version, net):
        """Download (if needed) and load the calibration weights for `net`/`version`."""
        ckpt = get_ckpt_path(net, "pretrained_models/flolpips/weights/v%s" % (version))
        # strict=False: the checkpoint only covers the lin-layer weights.
        self.load_state_dict(torch.load(ckpt, map_location="cpu"), strict=False)

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        """Return the LPIPS distance between image batches `in0` and `in1`.

        normalize=True rescales inputs from [0, 1] to the [-1, 1] range the
        trunk expects.  NOTE(review): retPerLayer is currently ignored —
        the per-layer return branch below is commented out.
        """
        if normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (
            (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == "0.1" else (in0, in1)
        )
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        # Squared difference of channel-unit-normalised features, per layer.
        for kk in range(self.L):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        if self.lpips:
            # Learned per-channel weighting via the lin layers.
            if self.spatial:
                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            # Baseline: no learned weighting, just sum over channels.
            if self.spatial:
                res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]

        # val = res[0]
        # for l in range(1,self.L):
        # val += res[l]
        # print(val)

        # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
        # b = torch.max(self.lins[kk](feats0[kk]**2))
        # for kk in range(self.L):
        # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
        # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
        # a = a/self.L
        # from IPython import embed
        # embed()
        # return 10*torch.log10(b/a)

        # if(retPerLayer):
        # return (val, res)
        # else:
        # Sum the per-layer scores into one scalar per batch element.
        return torch.sum(torch.cat(res, 1), dim=(1, 2, 3), keepdims=False)
|
|
|
|
|
|
class ScalingLayer(nn.Module):
    """Normalise an image tensor with the fixed LPIPS channel statistics."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        # Per-channel constants stored as (1, 3, 1, 1) buffers so they
        # follow the module across devices and into state_dict.
        shift = torch.tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
        scale = torch.tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        # (x - shift) / scale, broadcast over batch and spatial dims.
        return (inp - self.shift) / self.scale
|
|
|
|
|
|
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        # Optional dropout followed by a bias-free 1x1 convolution.
        modules = []
        if use_dropout:
            modules.append(nn.Dropout())
        modules.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
|
|
|
|
|
|
class Dist2LogitLayer(nn.Module):
    """takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True)"""

    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()
        # Three 1x1 convs (5 -> mid -> mid -> 1) with LeakyReLU between.
        stack = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, d0, d1, eps=0.1):
        # Feed the raw distances plus their difference and (regularised)
        # ratios as a 5-channel input.
        features = torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)
        return self.model.forward(features)
|
|
|
|
|
|
class BCERankingLoss(nn.Module):
    """BCE loss between a predicted preference (from two distances) and a
    human judgement mapped from [-1, 1] to [0, 1]."""

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        # self.parameters = list(self.net.parameters())
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        # Map the judgement from [-1, 1] to a [0, 1] probability target.
        target = (judge + 1.0) / 2.0
        # Keep the raw logit around for callers that inspect it.
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, target)
|
|
|
|
|
|
# L2, DSSIM metrics
|
|
class FakeNet(nn.Module):
    """Common base for the non-learned metrics: stores device and
    colorspace configuration only."""

    def __init__(self, use_gpu=True, colorspace="Lab"):
        super(FakeNet, self).__init__()
        self.colorspace = colorspace
        self.use_gpu = use_gpu
|
|
|
|
|
|
class L2(FakeNet):
    """Plain mean-squared-error distance in RGB or Lab colorspace."""

    def forward(self, in0, in1, retPerLayer=None):
        assert in0.size()[0] == 1  # currently only supports batchSize 1

        if self.colorspace == "RGB":
            (N, C, X, Y) = in0.size()
            # Average the squared error over channels, then rows, then
            # columns — same reduction order as the original chained call.
            per_pixel = torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y)
            per_row = torch.mean(per_pixel, dim=2).view(N, 1, 1, Y)
            value = torch.mean(per_row, dim=3).view(N)
            return value
        elif self.colorspace == "Lab":
            # Convert to Lab via the utils helpers and compare as numpy.
            value = l2(
                tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
            ret_var = Variable(torch.Tensor((value,)))
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var
|
|
|
|
|
|
class DSSIM(FakeNet):
    """Structural-dissimilarity (DSSIM) distance in RGB or Lab colorspace."""

    def forward(self, in0, in1, retPerLayer=None):
        assert in0.size()[0] == 1  # currently only supports batchSize 1

        if self.colorspace == "RGB":
            value = dssim(1.0 * tensor2im(in0.data), 1.0 * tensor2im(in1.data), range=255.0).astype("float")
        elif self.colorspace == "Lab":
            # Convert to Lab via the utils helpers and compare as numpy.
            im0 = tensor2np(tensor2tensorlab(in0.data, to_norm=False))
            im1 = tensor2np(tensor2tensorlab(in1.data, to_norm=False))
            value = dssim(im0, im1, range=100.0).astype("float")
        ret_var = Variable(torch.Tensor((value,)))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var
|
|
|
|
|
|
def print_network(net):
    """Print a module's structure followed by its total parameter count.

    Uses a generator expression with sum() instead of the original
    hand-rolled accumulation loop.
    """
    num_params = sum(param.numel() for param in net.parameters())
    print("Network", net)
    print("Total number of parameters: %d" % num_params)
|
|
|
|
|
|
class FloLPIPS(LPIPS):
    """LPIPS variant whose spatial pooling is weighted by flow magnitude
    (see mw_spatial_average) instead of a uniform spatial average."""

    def __init__(
        self,
        pretrained=True,
        net="alex",
        version="0.1",
        lpips=True,
        spatial=False,
        pnet_rand=False,
        pnet_tune=False,
        use_dropout=True,
        model_path=None,
        eval_mode=True,
        verbose=False,
    ):
        super(FloLPIPS, self).__init__(
            pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose
        )

    def forward(self, in0, in1, flow, retPerLayer=False, normalize=False):
        # Map [0, 1] inputs to the [-1, 1] range the trunk expects.
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        if self.version == "0.1":
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1
        outs0 = self.net.forward(in0_input)
        outs1 = self.net.forward(in1_input)

        # Squared difference of channel-unit-normalised features per layer.
        diffs = {}
        for layer_idx in range(self.L):
            f0 = normalize_tensor(outs0[layer_idx])
            f1 = normalize_tensor(outs1[layer_idx])
            diffs[layer_idx] = (f0 - f1) ** 2

        # Flow-magnitude-weighted pooling of each calibrated layer map.
        res = [
            mw_spatial_average(self.lins[layer_idx](diffs[layer_idx]), flow, keepdim=True)
            for layer_idx in range(self.L)
        ]

        return torch.sum(torch.cat(res, 1), dim=(1, 2, 3), keepdims=False)
|
|
|
|
|
|
class Flolpips(nn.Module):
    """Full FloLPIPS video-interpolation quality metric.

    Estimates optical flow (PWCNet) from the reference and distorted
    intermediate frames towards each anchor frame, then evaluates
    FloLPIPS weighted by the flow-difference field, averaged over the
    two anchors.
    """

    def __init__(self):
        super(Flolpips, self).__init__()
        self.loss_fn = FloLPIPS(net="alex", version="0.1")
        self.flownet = PWCNet()

    @torch.no_grad()
    def forward(self, I0, I1, frame_dis, frame_ref):
        """
        args:
            I0: first frame of the triplet, shape: [B, C, H, W]
            I1: third frame of the triplet, shape: [B, C, H, W]
            frame_dis: prediction of the intermediate frame, shape: [B, C, H, W]
            frame_ref: ground-truth of the intermediate frame, shape: [B, C, H, W]
        """
        assert (
            I0.size() == I1.size() == frame_dis.size() == frame_ref.size()
        ), "the 4 input tensors should have same size"

        def _score_wrt(anchor):
            # FloLPIPS between ref and dis, weighted by how much their
            # flows towards `anchor` disagree.
            flow_difference = self.flownet(frame_ref, anchor) - self.flownet(frame_dis, anchor)
            return self.loss_fn.forward(frame_ref, frame_dis, flow_difference, normalize=True)

        return (_score_wrt(I0) + _score_wrt(I1)) / 2
|