# mysora/eval/vae/flolpips/flolpips.py
from __future__ import absolute_import
import hashlib
import os
import requests
import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm
from .pretrained_networks import alexnet, squeezenet, vgg16
from .pwcnet import Network as PWCNet
from .utils import *
# Remote location of the learned linear-calibration weights for each trunk.
URL_MAP = {"alex": "https://raw.githubusercontent.com/danier97/flolpips/main/weights/v0.1/alex.pth"}
# Local filename each checkpoint is saved under (relative to the cache root).
CKPT_MAP = {"alex": "alex.pth"}
# Expected MD5 digest, used to validate a (re-)downloaded checkpoint.
MD5_MAP = {"alex": "9642209e2b57a85d20f86d812320f9e6"}
def spatial_average(in_tens, keepdim=True):
    """Average a NCHW tensor over its spatial (H, W) dimensions."""
    return in_tens.mean(dim=[2, 3], keepdim=keepdim)
def mw_spatial_average(in_tens, flow, keepdim=True):
    """Spatial average of ``in_tens`` weighted by optical-flow magnitude.

    ``flow`` (assumed B,2,h,w — first two channels are u/v) is bilinearly
    resized to the spatial size of ``in_tens``; its per-pixel magnitude,
    normalized to sum to 1 per sample, is the weighting map.
    """
    height, width = in_tens.shape[2], in_tens.shape[3]
    resized = F.interpolate(flow, (height, width), align_corners=False, mode="bilinear")
    magnitude = torch.sqrt(resized[:, 0:1] ** 2 + resized[:, 1:2] ** 2)
    magnitude = magnitude / magnitude.sum(dim=[1, 2, 3], keepdim=True)
    return (in_tens * magnitude).sum(dim=[2, 3], keepdim=keepdim)
def mtw_spatial_average(in_tens, flow, texture, keepdim=True):
    """Spatial average weighted by flow magnitude divided by texture.

    ``flow`` and ``texture`` are resized to the spatial size of ``in_tens``.
    Flow magnitude and texture are each min-max normalized over the whole
    batch (+1e-6 to avoid division by zero); the weight map magnitude/texture
    is renormalized to sum to 1 across the entire batch.
    """
    height, width = in_tens.shape[2], in_tens.shape[3]
    flow_rs = F.interpolate(flow, (height, width), align_corners=False, mode="bilinear")
    tex_rs = F.interpolate(texture, (height, width), align_corners=False, mode="bilinear")
    magnitude = torch.sqrt(flow_rs[:, 0:1] ** 2 + flow_rs[:, 1:2] ** 2)
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min()) + 1e-6
    tex_rs = (tex_rs - tex_rs.min()) / (tex_rs.max() - tex_rs.min()) + 1e-6
    weight = magnitude / tex_rs
    weight = weight / torch.sum(weight)
    return (in_tens * weight).sum(dim=[2, 3], keepdim=keepdim)
def m2w_spatial_average(in_tens, flow, keepdim=True):
    """Spatial average weighted by squared flow magnitude.

    The squared magnitude map (B,1,H,W) is normalized by its sum over the
    whole batch before weighting ``in_tens``.
    """
    height, width = in_tens.shape[2], in_tens.shape[3]
    resized = F.interpolate(flow, (height, width), align_corners=False, mode="bilinear")
    sq_mag = resized[:, 0:1] ** 2 + resized[:, 1:2] ** 2  # B,1,H,W
    sq_mag = sq_mag / sq_mag.sum()
    return (in_tens * sq_mag).sum(dim=[2, 3], keepdim=keepdim)
def upsample(in_tens, out_HW=(64, 64)):  # assumes scale factor is same for H and W
    """Bilinearly resize a NCHW tensor to spatial size ``out_HW`` (H, W).

    Fixes: the original computed the input H/W into locals it never used
    (dead code, removed) and built a throwaway ``nn.Upsample`` module per
    call; ``F.interpolate`` is the direct functional equivalent with
    identical output (same mode and ``align_corners=False``).
    """
    return F.interpolate(in_tens, size=out_HW, mode="bilinear", align_corners=False)
def md5_hash(path):
    """Return the hexadecimal MD5 digest of the file at ``path``."""
    with open(path, "rb") as fh:
        return hashlib.md5(fh.read()).hexdigest()
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` with a tqdm progress bar.

    Fixes: the progress bar is now advanced by the number of bytes actually
    received (``len(data)``) instead of a flat ``chunk_size``, which
    overshot on the final partial chunk; and the parent directory is only
    created when the path has one (``os.makedirs("")`` raises).
    """
    directory = os.path.dirname(local_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with requests.get(url, stream=True) as r:
        # content-length may be absent (chunked transfer); tqdm handles total=0.
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:  # skip keep-alive chunks
                        f.write(data)
                        pbar.update(len(data))
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name`` under ``root``.

    Downloads the file when it is missing, or when ``check`` is True and the
    existing file fails MD5 verification; a fresh download is always
    verified against ``MD5_MAP``.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    stale = check and md5_hash(path) != MD5_MAP[name] if os.path.exists(path) else True
    if stale:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        digest = md5_hash(path)
        assert digest == MD5_MAP[name], digest
    return path
# Learned perceptual metric
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity distance.

    Runs both inputs through a frozen classification trunk (alex / vgg /
    squeeze), squares the difference of unit-normalized per-layer features,
    and (when ``lpips`` is True) weights each layer with a learned 1x1
    "lin" layer before spatial pooling. Returns one scalar per batch item.
    """

    def __init__(
        self,
        pretrained=True,
        net="alex",
        version="0.1",
        lpips=True,
        spatial=False,
        pnet_rand=False,
        pnet_tune=False,
        use_dropout=True,
        model_path=None,
        eval_mode=True,
        verbose=False,
    ):
        # lpips - [True] means with linear calibration on top of base network
        # pretrained - [True] means load linear weights
        # NOTE(review): ``model_path`` is only echoed in the verbose print below;
        # loading always goes through get_ckpt_path — confirm this is intended.
        super(LPIPS, self).__init__()
        if verbose:
            print(
                "Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]"
                % ("LPIPS" if lpips else "baseline", net, version, "on" if spatial else "off")
            )
        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()
        # Per-trunk feature channel counts, one entry per tapped layer.
        if self.pnet_type in ["vgg", "vgg16"]:
            net_type = vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == "alex":
            net_type = alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == "squeeze":
            net_type = squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        self.L = len(self.chns)
        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
        if lpips:
            # One learned 1x1 conv per tapped trunk layer.
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if self.pnet_type == "squeeze":  # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins += [self.lin5, self.lin6]
            self.lins = nn.ModuleList(self.lins)
            if pretrained:
                self.load_from_pretrained(version, net)
                if verbose:
                    print("Loaded model from: %s" % model_path)
        if eval_mode:
            self.eval()

    def load_from_pretrained(self, version, net):
        """Fetch (if needed) and load the calibration weights for ``net``/``version``.

        ``strict=False`` because the checkpoint only covers the lin layers,
        not the frozen trunk.
        """
        ckpt = get_ckpt_path(net, "pretrained_models/flolpips/weights/v%s" % (version))
        self.load_state_dict(torch.load(ckpt, map_location="cpu"), strict=False)

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        """Return the LPIPS distance between ``in0`` and ``in1`` (one value per batch item).

        Inputs are expected in [-1, 1]; pass ``normalize=True`` for [0, 1]
        inputs. ``retPerLayer`` is kept for interface compatibility but is
        currently unused (the original per-layer return path was dead,
        commented-out code and has been removed).
        """
        if normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1
        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (
            (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == "0.1" else (in0, in1)
        )
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        for kk in range(self.L):
            # Unit-normalize channel vectors before differencing (standard LPIPS).
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
        if self.lpips:
            if self.spatial:
                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if self.spatial:
                res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
        # Sum the per-layer contributions. Fixed to the canonical ``keepdim``
        # kwarg (the original used the numpy-style ``keepdims`` alias).
        return torch.sum(torch.cat(res, 1), dim=(1, 2, 3), keepdim=False)
class ScalingLayer(nn.Module):
    """Shift-and-scale RGB normalization used by the original LPIPS release.

    The statistics are stored as buffers so they follow the module across
    devices and dtype casts.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        # Per-channel affine normalization of a NCHW tensor.
        return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv (optionally preceded by dropout)."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        modules = []
        if use_dropout:
            modules.append(nn.Dropout())
        # Bias-free 1x1 conv: a learned per-channel weighting.
        modules.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
class Dist2LogitLayer(nn.Module):
    """takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True)"""

    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()
        stack = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, d0, d1, eps=0.1):
        # Five hand-crafted features: both distances, their difference, both ratios.
        features = torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)
        return self.model(features)
class BCERankingLoss(nn.Module):
    """BCE loss on a learned ranking logit between two distance maps.

    ``judge`` in [-1, 1] is mapped to a target probability in [0, 1]; the
    last computed logit is kept on ``self.logit`` for inspection.
    """

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        target = (judge + 1.0) / 2.0
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, target)
# L2, DSSIM metrics
class FakeNet(nn.Module):
    """Shared base for the non-learned metrics; stores device/colorspace configuration."""

    def __init__(self, use_gpu=True, colorspace="Lab"):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace = colorspace
class L2(FakeNet):
    """Per-image mean squared error (RGB in torch, Lab via numpy helpers)."""

    def forward(self, in0, in1, retPerLayer=None):
        assert in0.size()[0] == 1  # currently only supports batchSize 1
        if self.colorspace == "RGB":
            (N, C, X, Y) = in0.size()
            # Sequential means over C, then X, then Y (kept in this exact
            # order to match the reference implementation bit-for-bit).
            value = torch.mean(
                torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), dim=3
            ).view(N)
            return value
        elif self.colorspace == "Lab":
            score = l2(
                tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
            result = Variable(torch.Tensor((score,)))
            if self.use_gpu:
                result = result.cuda()
            return result
class DSSIM(FakeNet):
    """Structural dissimilarity metric computed via the numpy ``dssim`` helper."""

    def forward(self, in0, in1, retPerLayer=None):
        assert in0.size()[0] == 1  # currently only supports batchSize 1
        if self.colorspace == "RGB":
            score = dssim(1.0 * tensor2im(in0.data), 1.0 * tensor2im(in1.data), range=255.0).astype("float")
        elif self.colorspace == "Lab":
            score = dssim(
                tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
        result = Variable(torch.Tensor((score,)))
        if self.use_gpu:
            result = result.cuda()
        return result
def print_network(net):
    """Print a module's structure followed by its total parameter count."""
    total = sum(param.numel() for param in net.parameters())
    print("Network", net)
    print("Total number of parameters: %d" % total)
class FloLPIPS(LPIPS):
    """LPIPS variant whose spatial pooling is weighted by optical-flow magnitude.

    Same trunk and learned lin layers as :class:`LPIPS`; only the pooling of
    the per-layer difference maps differs (``mw_spatial_average`` with the
    supplied ``flow``). NOTE(review): ``forward`` always uses ``self.lins``,
    so this class assumes construction with ``lpips=True`` — confirm callers
    never pass ``lpips=False``.
    """

    def __init__(
        self,
        pretrained=True,
        net="alex",
        version="0.1",
        lpips=True,
        spatial=False,
        pnet_rand=False,
        pnet_tune=False,
        use_dropout=True,
        model_path=None,
        eval_mode=True,
        verbose=False,
    ):
        super(FloLPIPS, self).__init__(
            pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose
        )

    def forward(self, in0, in1, flow, retPerLayer=False, normalize=False):
        """Flow-weighted LPIPS distance between ``in0`` and ``in1``.

        ``flow`` weights the spatial pooling of every layer's difference map.
        Inputs are expected in [-1, 1]; pass ``normalize=True`` for [0, 1].
        ``retPerLayer`` is kept for interface compatibility but unused.
        """
        if normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1
        in0_input, in1_input = (
            (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == "0.1" else (in0, in1)
        )
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        for kk in range(self.L):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
        res = [mw_spatial_average(self.lins[kk](diffs[kk]), flow, keepdim=True) for kk in range(self.L)]
        # Fixed to the canonical ``keepdim`` kwarg (the original used the
        # numpy-style ``keepdims`` alias), consistent with LPIPS.forward.
        return torch.sum(torch.cat(res, 1), dim=(1, 2, 3), keepdim=False)
class Flolpips(nn.Module):
    """Flow-weighted LPIPS score for frame-interpolation quality assessment.

    Combines a PWC-Net flow estimator with :class:`FloLPIPS`, scoring the
    predicted intermediate frame against the ground truth with respect to
    both neighboring frames.
    """

    def __init__(self):
        super(Flolpips, self).__init__()
        self.loss_fn = FloLPIPS(net="alex", version="0.1")
        self.flownet = PWCNet()

    @torch.no_grad()
    def forward(self, I0, I1, frame_dis, frame_ref):
        """
        args:
            I0: first frame of the triplet, shape: [B, C, H, W]
            I1: third frame of the triplet, shape: [B, C, H, W]
            frame_dis: prediction of the intermediate frame, shape: [B, C, H, W]
            frame_ref: ground-truth of the intermediate frame, shape: [B, C, H, W]
        """
        assert (
            I0.size() == I1.size() == frame_dis.size() == frame_ref.size()
        ), "the 4 input tensors should have same size"

        def _score_wrt(anchor):
            # Flow difference between GT->anchor and prediction->anchor motion
            # drives the spatial weighting inside FloLPIPS.
            flow_diff = self.flownet(frame_ref, anchor) - self.flownet(frame_dis, anchor)
            return self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True)

        # Average the score taken with respect to each neighboring frame.
        return (_score_wrt(I0) + _score_wrt(I1)) / 2