from __future__ import absolute_import import hashlib import os import requests import torch import torch.nn import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from tqdm import tqdm from .pretrained_networks import alexnet, squeezenet, vgg16 from .pwcnet import Network as PWCNet from .utils import * URL_MAP = {"alex": "https://raw.githubusercontent.com/danier97/flolpips/main/weights/v0.1/alex.pth"} CKPT_MAP = {"alex": "alex.pth"} MD5_MAP = {"alex": "9642209e2b57a85d20f86d812320f9e6"} def spatial_average(in_tens, keepdim=True): return in_tens.mean([2, 3], keepdim=keepdim) def mw_spatial_average(in_tens, flow, keepdim=True): _, _, h, w = in_tens.shape flow = F.interpolate(flow, (h, w), align_corners=False, mode="bilinear") flow_mag = torch.sqrt(flow[:, 0:1] ** 2 + flow[:, 1:2] ** 2) flow_mag = flow_mag / torch.sum(flow_mag, dim=[1, 2, 3], keepdim=True) return torch.sum(in_tens * flow_mag, dim=[2, 3], keepdim=keepdim) def mtw_spatial_average(in_tens, flow, texture, keepdim=True): _, _, h, w = in_tens.shape flow = F.interpolate(flow, (h, w), align_corners=False, mode="bilinear") texture = F.interpolate(texture, (h, w), align_corners=False, mode="bilinear") flow_mag = torch.sqrt(flow[:, 0:1] ** 2 + flow[:, 1:2] ** 2) flow_mag = (flow_mag - flow_mag.min()) / (flow_mag.max() - flow_mag.min()) + 1e-6 texture = (texture - texture.min()) / (texture.max() - texture.min()) + 1e-6 weight = flow_mag / texture weight /= torch.sum(weight) return torch.sum(in_tens * weight, dim=[2, 3], keepdim=keepdim) def m2w_spatial_average(in_tens, flow, keepdim=True): _, _, h, w = in_tens.shape flow = F.interpolate(flow, (h, w), align_corners=False, mode="bilinear") flow_mag = flow[:, 0:1] ** 2 + flow[:, 1:2] ** 2 # B,1,H,W flow_mag = flow_mag / torch.sum(flow_mag) return torch.sum(in_tens * flow_mag, dim=[2, 3], keepdim=keepdim) def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W in_H, in_W = in_tens.shape[2], in_tens.shape[3] return nn.Upsample(size=out_HW, mode="bilinear", align_corners=False)(in_tens) def md5_hash(path): with open(path, "rb") as f: content = f.read() return hashlib.md5(content).hexdigest() def download(url, local_path, chunk_size=1024): os.makedirs(os.path.split(local_path)[0], exist_ok=True) with requests.get(url, stream=True) as r: total_size = int(r.headers.get("content-length", 0)) with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: with open(local_path, "wb") as f: for data in r.iter_content(chunk_size=chunk_size): if data: f.write(data) pbar.update(chunk_size) def get_ckpt_path(name, root, check=False): assert name in URL_MAP path = os.path.join(root, CKPT_MAP[name]) if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) download(URL_MAP[name], path) md5 = md5_hash(path) assert md5 == MD5_MAP[name], md5 return path # Learned perceptual metric class LPIPS(nn.Module): def __init__( self, pretrained=True, net="alex", version="0.1", lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False, ): # lpips - [True] means with linear calibration on top of base network # pretrained - [True] means load linear weights super(LPIPS, self).__init__() if verbose: print( "Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]" % ("LPIPS" if lpips else "baseline", net, version, "on" if spatial else "off") ) self.pnet_type = net self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial self.lpips = lpips # false means baseline of just averaging all layers self.version = version self.scaling_layer = ScalingLayer() if self.pnet_type in ["vgg", "vgg16"]: net_type = vgg16 self.chns = [64, 128, 256, 512, 512] elif self.pnet_type == "alex": net_type = alexnet self.chns = [64, 192, 384, 256, 256] elif self.pnet_type == "squeeze": net_type = squeezenet self.chns = [64, 128, 256, 384, 384, 512, 512] self.L = len(self.chns) self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) if lpips: self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] if self.pnet_type == "squeeze": # 7 layers for squeezenet self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) self.lins += [self.lin5, self.lin6] self.lins = nn.ModuleList(self.lins) if pretrained: self.load_from_pretrained(version, net) if verbose: print("Loaded model from: %s" % model_path) if eval_mode: self.eval() def load_from_pretrained(self, version, net): ckpt = get_ckpt_path(net, "pretrained_models/flolpips/weights/v%s" % (version)) self.load_state_dict(torch.load(ckpt, map_location="cpu"), strict=False) def forward(self, in0, in1, retPerLayer=False, normalize=False): if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 # v0.0 - original release had a bug, where input was not scaled in0_input, in1_input = ( (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == "0.1" else (in0, in1) ) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 if self.lpips: if self.spatial: res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] else: if self.spatial: res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] # val = res[0] # for l in range(1,self.L): # val += res[l] # print(val) # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) # b = torch.max(self.lins[kk](feats0[kk]**2)) # for kk in range(self.L): # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) # a = a/self.L # from IPython import embed # embed() # return 10*torch.log10(b/a) # if(retPerLayer): # return (val, res) # else: return torch.sum(torch.cat(res, 1), dim=(1, 2, 3), keepdims=False) class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): """A single linear layer which does a 1x1 conv""" def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = ( [ nn.Dropout(), ] if (use_dropout) else [] ) layers += [ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class Dist2LogitLayer(nn.Module): """takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True)""" def __init__(self, chn_mid=32, use_sigmoid=True): super(Dist2LogitLayer, self).__init__() layers = [ nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ] layers += [ nn.LeakyReLU(0.2, True), ] layers += [ nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ] layers += [ nn.LeakyReLU(0.2, True), ] layers += [ nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ] if use_sigmoid: layers += [ nn.Sigmoid(), ] self.model = nn.Sequential(*layers) def forward(self, d0, d1, eps=0.1): return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) class BCERankingLoss(nn.Module): def __init__(self, chn_mid=32): super(BCERankingLoss, self).__init__() self.net = Dist2LogitLayer(chn_mid=chn_mid) # self.parameters = list(self.net.parameters()) self.loss = torch.nn.BCELoss() def forward(self, d0, d1, judge): per = (judge + 1.0) / 2.0 self.logit = self.net.forward(d0, d1) return self.loss(self.logit, per) # L2, DSSIM metrics class FakeNet(nn.Module): def __init__(self, use_gpu=True, colorspace="Lab"): super(FakeNet, self).__init__() self.use_gpu = use_gpu self.colorspace = colorspace class L2(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert in0.size()[0] == 1 # currently only supports batchSize 1 if self.colorspace == "RGB": (N, C, X, Y) = in0.size() value = torch.mean( torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), dim=3 ).view(N) return value elif self.colorspace == "Lab": value = l2( tensor2np(tensor2tensorlab(in0.data, to_norm=False)), tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.0, ).astype("float") ret_var = Variable(torch.Tensor((value,))) if self.use_gpu: ret_var = ret_var.cuda() return ret_var class DSSIM(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert in0.size()[0] == 1 # currently only supports batchSize 1 if self.colorspace == "RGB": value = dssim(1.0 * tensor2im(in0.data), 1.0 * tensor2im(in1.data), range=255.0).astype("float") elif self.colorspace == "Lab": value = dssim( tensor2np(tensor2tensorlab(in0.data, to_norm=False)), tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.0, ).astype("float") ret_var = Variable(torch.Tensor((value,))) if self.use_gpu: ret_var = ret_var.cuda() return ret_var def print_network(net): num_params = 0 for param in net.parameters(): num_params += param.numel() print("Network", net) print("Total number of parameters: %d" % num_params) class FloLPIPS(LPIPS): def __init__( self, pretrained=True, net="alex", version="0.1", lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False, ): super(FloLPIPS, self).__init__( pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose ) def forward(self, in0, in1, flow, retPerLayer=False, normalize=False): if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 in0_input, in1_input = ( (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == "0.1" else (in0, in1) ) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 res = [mw_spatial_average(self.lins[kk](diffs[kk]), flow, keepdim=True) for kk in range(self.L)] return torch.sum(torch.cat(res, 1), dim=(1, 2, 3), keepdims=False) class Flolpips(nn.Module): def __init__(self): super(Flolpips, self).__init__() self.loss_fn = FloLPIPS(net="alex", version="0.1") self.flownet = PWCNet() @torch.no_grad() def forward(self, I0, I1, frame_dis, frame_ref): """ args: I0: first frame of the triplet, shape: [B, C, H, W] I1: third frame of the triplet, shape: [B, C, H, W] frame_dis: prediction of the intermediate frame, shape: [B, C, H, W] frame_ref: ground-truth of the intermediate frame, shape: [B, C, H, W] """ assert ( I0.size() == I1.size() == frame_dis.size() == frame_ref.size() ), "the 4 input tensors should have same size" flow_ref = self.flownet(frame_ref, I0) flow_dis = self.flownet(frame_dis, I0) flow_diff = flow_ref - flow_dis flolpips_wrt_I0 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True) flow_ref = self.flownet(frame_ref, I1) flow_dis = self.flownet(frame_dis, I1) flow_diff = flow_ref - flow_dis flolpips_wrt_I1 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True) flolpips = (flolpips_wrt_I0 + flolpips_wrt_I1) / 2 return flolpips