108 lines
3.5 KiB
Python
108 lines
3.5 KiB
Python
import cv2
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def normalize_tensor(in_feat, eps=1e-10):
    """Scale ``in_feat`` to unit L2 norm along dimension 1.

    ``eps`` is added to the norm before dividing so that slices whose norm is
    zero do not cause a division by zero.
    """
    squared = in_feat**2
    l2_norm = torch.sqrt(torch.sum(squared, dim=1, keepdim=True))
    denominator = l2_norm + eps
    return in_feat / denominator
|
|
|
|
|
|
def l2(p0, p1, range=255.0):
    """Return half the mean squared difference of ``p0`` and ``p1``.

    Both inputs are divided by ``range`` before the difference is taken.
    """
    scaled_diff = p0 / range - p1 / range
    mean_sq = np.mean(scaled_diff**2)
    return 0.5 * mean_sq
|
|
|
|
|
|
def dssim(p0, p1, range=255.0):
    """Return the structural dissimilarity ``(1 - SSIM) / 2`` of two images.

    ``p0`` and ``p1`` are passed to scikit-image's SSIM as multichannel images
    (channels on the last axis) with ``data_range=range``.

    The SSIM function has been renamed and its channel keyword changed across
    scikit-image releases, so both the function and the keyword are resolved
    at call time to keep this working on old and new installations.
    """
    import inspect

    try:
        # Current location of the SSIM implementation.
        from skimage.metrics import structural_similarity as ssim_fn
    except ImportError:
        # Legacy name, removed from recent scikit-image releases.
        from skimage.measure import compare_ssim as ssim_fn

    # Older versions silently swallow unknown keywords via **kwargs, so the
    # supported keyword is detected from the signature rather than by
    # catching TypeError.
    if "channel_axis" in inspect.signature(ssim_fn).parameters:
        channel_kwargs = {"channel_axis": -1}
    else:
        channel_kwargs = {"multichannel": True}

    ssim = ssim_fn(p0, p1, data_range=range, **channel_kwargs)
    return (1 - ssim) / 2.0
|
|
|
|
|
|
def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
    """Convert the first item of ``image_tensor`` into a channels-last image array.

    The selected item is moved to the CPU, cast to float, transposed with axes
    ``(1, 2, 0)``, shifted by ``cent``, scaled by ``factor`` and finally cast
    to ``imtype``.
    """
    first = image_tensor[0].cpu().float().numpy()
    channels_last = first.transpose((1, 2, 0))
    scaled = (channels_last + cent) * factor
    return scaled.astype(imtype)
|
|
|
|
|
|
def tensor2np(tensor_obj):
    """Return the first item of ``tensor_obj`` as a float numpy array.

    The item's axes are reordered with ``(1, 2, 0)``, i.e. the leading axis of
    the selected item becomes the last axis of the result.
    """
    first_item = tensor_obj[0].cpu().float()
    return np.transpose(first_item.numpy(), (1, 2, 0))
|
|
|
|
|
|
def np2tensor(np_obj):
    """Wrap a 3-D numpy array as a 4-D ``torch.Tensor`` with a leading singleton axis.

    A new trailing axis is appended and the axes are then reordered with
    ``(3, 2, 0, 1)``, so the input's last axis becomes axis 1 of the result.
    """
    with_batch_axis = np_obj[:, :, :, np.newaxis]
    reordered = np.transpose(with_batch_axis, (3, 2, 0, 1))
    return torch.Tensor(reordered)
|
|
|
|
|
|
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
    """Convert an image tensor to a Lab tensor.

    The tensor is turned into an image with ``tensor2im``, converted with
    ``skimage.color.rgb2lab`` and wrapped back into a tensor with ``np2tensor``.

    * ``mc_only``: subtract 50 from the L channel only (no scaling).
    * ``to_norm`` (and not ``mc_only``): subtract 50 from the L channel, then
      divide all channels by 100.
    """
    from skimage import color

    lab = color.rgb2lab(tensor2im(image_tensor))

    # Full normalisation only applies when mean-centring alone was not requested.
    normalize = to_norm and not mc_only

    if mc_only or normalize:
        lab[:, :, 0] = lab[:, :, 0] - 50
    if normalize:
        lab = lab / 100.0

    return np2tensor(lab)
|
|
|
|
|
|
def _read_plane(stream, dtype, rows, cols):
    """Read one ``rows x cols`` plane of ``dtype`` samples from the current stream position."""
    return np.fromfile(stream, dtype=dtype, count=rows * cols).reshape((rows, cols))


def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt="420"):
    """Read frame ``iFrame`` from a raw planar YUV stream and return it as RGB.

    Each frame is read as a full-resolution Y plane followed by a U plane and
    a V plane. For ``pix_fmt="420"`` the chroma planes have half the width and
    height; for ``"444"`` they are full resolution.

    Args:
        stream: Open binary file object positioned anywhere; it is seeked to
            the start of the requested frame.
        width: Luma width in samples.
        height: Luma height in samples.
        iFrame: Zero-based index of the frame to read.
        bit_depth: 8 reads one byte per sample; any other value reads 16-bit
            samples, which are rescaled to 8 bits using ``2**bit_depth - 1``
            as the maximum value.
        pix_fmt: ``"420"`` or ``"444"``.

    Returns:
        The array produced by ``cv2.cvtColor``, or ``None`` (after printing a
        message) when ``pix_fmt`` is not supported.
    """
    if pix_fmt == "420":
        uv_factor = 2
    elif pix_fmt == "444":
        uv_factor = 1
    else:
        print("Pixel format {} is not supported".format(pix_fmt))
        return None

    datatype = np.uint8 if bit_depth == 8 else np.uint16
    chroma_h = height // uv_factor
    chroma_w = width // uv_factor

    # seek() requires an integer; derive the byte offset from the exact number
    # of samples read per frame instead of a float factor such as 1.5.
    samples_per_frame = width * height + 2 * chroma_w * chroma_h
    stream.seek(int(iFrame * samples_per_frame * np.dtype(datatype).itemsize))

    Y = _read_plane(stream, datatype, height, width)
    U = _read_plane(stream, datatype, chroma_h, chroma_w)
    V = _read_plane(stream, datatype, chroma_h, chroma_w)

    if pix_fmt == "420":
        # I420 layout expected by OpenCV: Y, U and V stored back to back and
        # viewed as a single-channel image that is `width` samples wide.
        # Packing the flattened planes avoids requiring height % 4 == 0.
        yuv = np.concatenate([Y.ravel(), U.ravel(), V.ravel()]).reshape(-1, width)
        if bit_depth != 8:
            yuv = (yuv / (2**bit_depth - 1) * 255).astype(np.uint8)
        rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)
    else:
        # OpenCV's YCrCb channel order is Y, Cr (V), Cb (U).
        yvu = np.stack([Y, V, U], axis=2)
        if bit_depth != 8:
            yvu = (yvu / (2**bit_depth - 1) * 255).astype(np.uint8)
        rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB)

    return rgb
|