1131 lines
43 KiB
Python
1131 lines
43 KiB
Python
import concurrent.futures
|
|
import contextlib
|
|
import glob
|
|
import io
|
|
import os
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import requests
|
|
import torch
|
|
import torchvision.transforms.v2.functional as F
|
|
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
|
|
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
|
|
from torchvision.io.image import (
|
|
decode_avif,
|
|
decode_gif,
|
|
decode_heic,
|
|
decode_image,
|
|
decode_jpeg,
|
|
decode_png,
|
|
decode_webp,
|
|
encode_jpeg,
|
|
encode_png,
|
|
ImageReadMode,
|
|
read_file,
|
|
read_image,
|
|
write_file,
|
|
write_jpeg,
|
|
write_png,
|
|
)
|
|
|
|
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")

IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"
IS_LINUX = sys.platform == "linux"

# Parse Pillow's version string into a tuple of ints for comparisons like ``>= (8, 3)``.
# Only the leading numeric release part is kept: pre-release builds report strings such
# as "11.1.0.dev0", and calling int() on every dot-separated piece would crash at import.
PILLOW_VERSION = tuple(int(part) for part in re.match(r"\d+(?:\.\d+)*", PILLOW_VERSION).group().split("."))

WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")

# See https://github.com/pytorch/vision/pull/8724#issuecomment-2503964558
HEIC_AVIF_MESSAGE = "AVIF and HEIF only available on linux."
|
|
|
|
|
|
def _get_safe_image_name(name):
    """Return only the file-name component of an image path, for use as a pytest id.

    Used when we need to change the pytest "id" for an "image path" parameter. Otherwise
    the test id (i.e. its name) would contain the whole, machine-specific path to the
    image. That creates issues when the test runs on a different machine than the one
    where it was collected (typically, in fb internal infra).

    ``os.path.basename`` is used instead of splitting on ``os.path.sep`` so that paths
    which mix separators (e.g. glob results built from "/"-joined patterns on Windows)
    are still reduced to the bare file name.
    """
    return os.path.basename(name)
|
|
|
|
|
|
def get_images(directory, img_ext):
    """Yield paths of all files under ``directory`` (recursively) ending with ``img_ext``.

    Paths are yielded in sorted order. Files whose immediate parent directory is named
    ``damaged_jpeg`` or ``jpeg_write`` are skipped (``damaged_jpeg`` holds the deliberately
    corrupted files exercised by ``test_damaged_corrupt_images``).
    """
    assert os.path.isdir(directory)
    excluded_parents = {"damaged_jpeg", "jpeg_write"}
    pattern = os.path.join(directory, "**", f"*{img_ext}")
    # glob's result order depends on the filesystem. Sort so that parametrized test ids
    # and ``next(get_images(...))`` are deterministic across machines.
    for path in sorted(glob.glob(pattern, recursive=True)):
        if os.path.basename(os.path.dirname(path)) not in excluded_parents:
            yield path
|
|
|
|
|
|
def pil_read_image(img_path):
    """Decode ``img_path`` with Pillow and return its pixels as a tensor (Pillow's channel-last layout)."""
    with Image.open(img_path) as pil_img:
        pixels = np.array(pil_img)
    return torch.from_numpy(pixels)
|
|
|
|
|
|
def normalize_dimensions(img_pil):
    """Move a Pillow-derived tensor to torchvision's channels-first layout.

    A 3-d ``(H, W, C)`` tensor is permuted to ``(C, H, W)``. Any other tensor (a 2-d
    ``(H, W)`` grayscale image) gets a leading channel dimension of size 1.
    """
    has_channel_axis = img_pil.ndim == 3
    return img_pil.permute(2, 0, 1) if has_channel_axis else img_pil.unsqueeze(0)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):
    """Compare torchvision's JPEG decoding against Pillow for each test JPEG and read mode."""
    with Image.open(img_path) as pil_img:
        source_is_cmyk = pil_img.mode == "CMYK"
        if pil_mode is not None:
            pil_img = pil_img.convert(pil_mode)
        expected = torch.from_numpy(np.array(pil_img))
        if source_is_cmyk and mode == ImageReadMode.UNCHANGED:
            # flip the colors to match libjpeg
            expected = 255 - expected
    expected = normalize_dimensions(expected)

    encoded = read_file(img_path)
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun
    actual = decoder(encoded, mode=mode)

    # Permit a small variation on pixel values to account for implementation
    # differences between Pillow and LibJPEG.
    abs_mean_diff = (actual.type(torch.float32) - expected).abs().mean().item()
    assert abs_mean_diff < 2
|
|
|
|
|
|
@pytest.mark.parametrize("codec", ["png", "jpeg"])
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_with_exif_orientation(tmpdir, codec, orientation):
    """decode_image(apply_exif_orientation=True) must match PIL's ImageOps.exif_transpose."""
    fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.{codec}")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    exif = im.getexif()
    exif[0x0112] = orientation  # set exif orientation
    im.save(fp, codec.upper(), exif=exif.tobytes())

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Open the reference image as a context manager so its file handle is released
    # (a bare Image.open() leaves it open and can trigger "ResourceWarning: unclosed file").
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))
    torch.testing.assert_close(expected, output)
|
|
|
|
|
|
@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
def test_invalid_exif(tmpdir, size):
    """Malformed EXIF payloads must be handled the same way PIL's exif_transpose handles them."""
    # Inspired from a PIL test:
    # https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
    fp = os.path.join(tmpdir, "invalid_exif.jpg")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    im.save(fp, "JPEG", exif=b"1" * size)

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Use a context manager so the reference image's file handle is closed.
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))
    torch.testing.assert_close(expected, output)
|
|
|
|
|
|
def test_decode_bad_huffman_images():
    """Sanity check: a JPEG with a bad Huffman encoding still decodes without raising."""
    bad_huffman_path = os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")
    decode_jpeg(read_file(bad_huffman_path))
|
|
|
|
|
|
# glob's result order depends on the filesystem; sort for a deterministic parametrization.
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
        for truncated_image in sorted(glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg")))
    ],
)
def test_damaged_corrupt_images(img_path):
    """Corrupt or truncated JPEGs must make decode_jpeg raise a RuntimeError."""
    data = read_file(img_path)
    # corrupt34 is expected to fail as a truncated image; the others on an unsupported marker.
    if "corrupt34" in img_path:
        match_message = "Image is incomplete or truncated"
    else:
        match_message = "Unsupported marker type"
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
    """Compare torchvision's PNG decoding against Pillow across read modes (incl. 16-bit PNGs)."""
    if scripted:
        decode_fun = torch.jit.script(decode_fun)

    with Image.open(img_path) as pil_img:
        if pil_mode is not None:
            pil_img = pil_img.convert(pil_mode)
        expected = torch.from_numpy(np.array(pil_img))
    expected = normalize_dimensions(expected)

    actual = decode_fun(read_file(img_path), mode=mode)
    if img_path.endswith("16.png"):
        assert actual.dtype == torch.uint16
        # PIL converts 16 bits pngs to uint8
        actual = F.to_dtype(actual, torch.uint8, scale=True)

    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        actual, expected = actual[0], expected[0]

    torch.testing.assert_close(actual, expected, atol=tol, rtol=0)
|
|
|
|
|
|
def test_decode_png_errors():
    """decode_png must reject a malformed PNG and a too-small PNG with descriptive errors."""
    failing_cases = [
        (os.path.join(DAMAGED_PNG, "sigsegv.png"), "Out of bound read in decode_png"),
        (os.path.join(TOOSMALL_PNG, "heapbof.png"), "Content is too small for png"),
    ]
    for png_path, expected_message in failing_cases:
        with pytest.raises(RuntimeError, match=expected_message):
            decode_png(read_file(png_path))
|
|
|
|
|
|
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_png(img_path, scripted):
    """Round-trip: encode_png output, decoded by Pillow, must equal the original pixels."""
    # Both PIL images are opened as context managers so their handles are released.
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1)

    encode = torch.jit.script(encode_png) if scripted else encode_png
    png_buf = encode(img_pil, compression_level=6)

    with Image.open(io.BytesIO(bytes(png_buf.tolist()))) as rec_pil:
        rec_img = torch.from_numpy(np.array(rec_pil)).permute(2, 0, 1)

    assert_equal(img_pil, rec_img)
|
|
|
|
|
|
def test_encode_png_errors():
    """encode_png must validate dtype, compression level, and channel count."""
    uint8_img = torch.empty((3, 100, 100), dtype=torch.uint8)

    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

    for bad_level in (-1, 10):
        with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
            encode_png(uint8_img, compression_level=bad_level)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
|
|
|
|
|
|
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_png(img_path, tmpdir, scripted):
    """write_png output, re-read with Pillow, must equal the original pixels."""
    # Both PIL images are opened as context managers so their file handles are closed.
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write = torch.jit.script(write_png) if scripted else write_png
    write(img_pil, torch_png, compression_level=6)

    with Image.open(torch_png) as saved_pil:
        saved_image = torch.from_numpy(np.array(saved_pil)).permute(2, 0, 1)

    assert_equal(img_pil, saved_image)
|
|
|
|
|
|
def test_read_image():
    """The TorchScript-compiled read_image must match the eager version exactly."""
    # Only TorchScript is exercised here; decoding itself is covered by other tests.
    first_jpeg = next(get_images(IMAGE_ROOT, ".jpg"))
    eager_out = read_image(first_jpeg)
    scripted_out = torch.jit.script(read_image)(first_jpeg)
    torch.testing.assert_close(eager_out, scripted_out, atol=0, rtol=0)
|
|
|
|
|
|
@pytest.mark.parametrize("scripted", (True, False))
def test_read_file(tmpdir, scripted):
    """read_file (eager or scripted) returns a file's raw bytes as a uint8 tensor."""
    content = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "test1.bin")
    with open(fpath, "wb") as f:
        f.write(content)

    reader = torch.jit.script(read_file) if scripted else read_file
    data = reader(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")
|
|
|
|
|
|
def test_read_file_non_ascii(tmpdir):
    """read_file must handle file names containing non-ASCII characters."""
    content = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "日本語(Japanese).bin")
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    os.unlink(fpath)
    assert_equal(data, torch.tensor(list(content), dtype=torch.uint8))
|
|
|
|
|
|
@pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
    """write_file (eager or scripted) writes a uint8 tensor's bytes verbatim."""
    content = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "test1.bin")
    writer = torch.jit.script(write_file) if scripted else write_file
    writer(fpath, torch.tensor(list(content), dtype=torch.uint8))

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content
|
|
|
|
|
|
def test_write_file_non_ascii(tmpdir):
    """write_file must handle file names containing non-ASCII characters."""
    content = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "日本語(Japanese).bin")
    write_file(fpath, torch.tensor(list(content), dtype=torch.uint8))

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content
|
|
|
|
|
|
@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
def test_read_1_bit_png(shape, tmpdir):
    """A PNG saved from a boolean array must decode to a single-channel 0/255 uint8 image."""
    rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = rng.rand(*shape) > 0.5
    Image.fromarray(pixels).save(image_path)

    decoded = read_image(image_path)
    expected = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
    assert_equal(decoded, expected)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
@pytest.mark.parametrize(
    "mode",
    [
        ImageReadMode.UNCHANGED,
        ImageReadMode.GRAY,
    ],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
    """Reading the same boolean-array PNG twice with the same mode must give identical tensors."""
    rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    Image.fromarray(rng.rand(*shape) > 0.5).save(image_path)

    first_read = read_image(image_path, mode)
    second_read = read_image(image_path, mode)
    assert_equal(first_read, second_read)
|
|
|
|
|
|
def test_read_interlaced_png():
    """Two PNGs differing in their "interlace" info must decode to identical tensors."""
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        # Compare the values with != rather than identity (`is not`). With `is not`,
        # two equal values held in distinct objects would wrongly pass.
        assert im1.info.get("interlace") != im2.info.get("interlace")
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)
|
|
|
|
|
|
@needs_cuda
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpegs_cuda(mode, scripted):
    """Batched CUDA JPEG decoding, run from several threads, must stay close to CPU decoding."""
    encoded_images = [
        read_file(jpeg_path) for jpeg_path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in jpeg_path
    ]
    decoded_images_cpu = decode_jpeg(encoded_images, mode=mode)
    decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg

    # test multithreaded decoding
    # in the current version we prevent this by using a lock but we still want to test it
    num_workers = 10

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)]
        decoded_images_threaded = [future.result() for future in futures]
    assert len(decoded_images_threaded) == num_workers

    for decoded_images in decoded_images_threaded:
        assert len(decoded_images) == len(encoded_images)
        for on_cuda, on_cpu in zip(decoded_images, decoded_images_cpu):
            assert on_cuda.shape == on_cpu.shape
            assert on_cuda.dtype == on_cpu.dtype == torch.uint8
            assert (on_cuda.cpu().float() - on_cpu.cpu().float()).abs().mean() < 2
|
|
|
|
|
|
@needs_cuda
def test_decode_image_cuda_raises():
    """decode_image must raise when given an input tensor that lives on the GPU."""
    gpu_bytes = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(gpu_bytes)
|
|
|
|
|
|
@needs_cuda
def test_decode_jpeg_cuda_device_param():
    """Every CUDA device spelling yields the same decode; the current device and stream are left untouched."""
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_file(path)
    current_device = torch.cuda.current_device()
    current_stream = torch.cuda.current_stream()
    num_devices = torch.cuda.device_count()

    devices = ["cuda", torch.device("cuda")]
    devices.extend(torch.device(f"cuda:{i}") for i in range(num_devices))
    results = [decode_jpeg(data, device=device) for device in devices]

    assert len(results) == len(devices)
    reference = results[0].cpu()
    for result in results:
        assert torch.all(result.cpu() == reference)
    assert current_device == torch.cuda.current_device()
    assert current_stream == torch.cuda.current_stream()
|
|
|
|
|
|
@needs_cuda
def test_decode_jpeg_cuda_errors():
    """decode_jpeg's CUDA path must validate input shape, dtype, device and list contents."""

    def empty(shape=(100,), dtype=torch.uint8, device=None):
        # Small helper: an uninitialised buffer used only to trigger validation errors.
        return torch.empty(shape, dtype=dtype, device=device)

    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))

    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(data.reshape(-1, 1), device="cuda")
    with pytest.raises(ValueError, match="must be tensors"):
        decode_jpeg([1, 2, 3])
    with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
        decode_jpeg(data.to("cuda"), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(data.to(torch.float), device="cuda")
    with pytest.raises(RuntimeError, match="Expected the device parameter to be a cuda device"):
        torch.ops.image.decode_jpegs_cuda([data], ImageReadMode.UNCHANGED.value, "cpu")
    with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
        decode_jpeg(empty(device="cuda"))

    with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
        decode_jpeg([empty(device="cuda"), empty(device="cuda")])

    with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
        decode_jpeg([empty(device="cuda"), empty(device="cuda")], device="cuda")

    with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
        decode_jpeg([empty(device="cpu"), empty(device="cuda")], device="cuda")

    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg([empty(), empty(dtype=torch.float32)], device="cuda")

    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg([empty(), empty(shape=(1, 100))], device="cuda")

    with pytest.raises(RuntimeError, match="Error while decoding JPEG images"):
        decode_jpeg([empty(), empty()], device="cuda")

    with pytest.raises(ValueError, match="Input list must contain at least one element"):
        decode_jpeg([], device="cuda")
|
|
|
|
|
|
def test_encode_jpeg_errors():
    """encode_jpeg must validate dtype, quality range, channel count and rank."""
    uint8_img = torch.empty((3, 100, 100), dtype=torch.uint8)

    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

    for bad_quality in (-1, 101):
        with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
            encode_jpeg(uint8_img, quality=bad_quality)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

    for bad_shape in ((1, 3, 100, 100), (100, 100)):
        with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
            encode_jpeg(torch.empty(bad_shape, dtype=torch.uint8))
|
|
|
|
|
|
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_jpeg(img_path, scripted):
    """encode_jpeg must emit the same bytes as Pillow at quality 75, for contiguous and non-contiguous input."""
    img = read_image(img_path)

    pil_buffer = io.BytesIO()
    F.to_pil_image(img).save(pil_buffer, format="JPEG", quality=75)
    encoded_jpeg_pil = torch.frombuffer(pil_buffer.getvalue(), dtype=torch.uint8)

    encode = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
    for src_img in (img, img.contiguous()):
        assert_equal(encode(src_img, quality=75), encoded_jpeg_pil)
|
|
|
|
|
|
@needs_cuda
def test_encode_jpeg_cuda_device_param():
    """Encoding on every CUDA device spelling gives identical bytes; the current device and stream are unchanged."""
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_image(path)

    current_device = torch.cuda.current_device()
    current_stream = torch.cuda.current_stream()
    num_devices = torch.cuda.device_count()

    devices = ["cuda", torch.device("cuda")]
    devices.extend(torch.device(f"cuda:{i}") for i in range(num_devices))
    results = [encode_jpeg(data.to(device=device)) for device in devices]

    assert len(results) == len(devices)
    reference = results[0].cpu()
    for result in results:
        assert torch.all(result.cpu() == reference)
    assert current_device == torch.cuda.current_device()
    assert current_stream == torch.cuda.current_stream()
|
|
|
|
|
|
@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_encode_jpeg_cuda(img_path, scripted, contiguous):
    """A JPEG encoded on CUDA, then decoded on CPU, must stay close to the source image."""
    decoded_image_tv = read_image(img_path)
    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    if "cmyk" in img_path:
        pytest.xfail("Encoding a CMYK jpeg isn't supported")
    if decoded_image_tv.shape[0] == 1:
        # The CUDA encoder rejects 1-channel input ("The number of channels should be 3";
        # see test_single_encode_jpeg_cuda_errors), so this is an encoding limitation.
        pytest.xfail("Encoding a grayscale jpeg isn't supported")
        # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=memory_format)[0]

    encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
    decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())

    # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality
    # instead, we re-decode the encoded image and compare to the original
    abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
    assert abs_mean_diff < 3
|
|
|
|
|
|
@needs_cuda
def test_encode_jpeg_cuda_sync():
    """
    Non-regression test for https://github.com/pytorch/vision/issues/8587.

    Attempts to reproduce an intermittent CUDA stream synchronization bug
    by randomly creating images and round-tripping them via encode_jpeg
    and decode_jpeg on the GPU. Fails if the mean difference in uint8 range
    exceeds 5.
    """
    torch.manual_seed(42)

    # manual testing shows this bug appearing often in iterations between 50 and 100
    # as a synchronization bug, this can't be reliably reproduced
    max_iterations = 100
    threshold = 5.0  # in [0..255]

    device = torch.device("cuda")

    for iteration in range(max_iterations):
        # Convert to Python ints so the failure message reads "4123x4567" rather than
        # "tensor(4123)xtensor(4567)".
        height, width = torch.randint(4000, 5000, size=(2,)).tolist()

        image = torch.linspace(0, 1, steps=height * width, device=device)
        image = image.view(1, height, width).expand(3, -1, -1)

        image = (image * 255).clamp(0, 255).to(torch.uint8)
        jpeg_bytes = encode_jpeg(image, quality=100)

        decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device)
        mean_difference = (image.float() - decoded_image.float()).abs().mean().item()

        assert mean_difference <= threshold, (
            f"Encode/decode mismatch at iteration={iteration}, "
            f"size={height}x{width}, mean diff={mean_difference:.2f}"
        )
|
|
|
|
|
|
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scripted", (True, False))
@pytest.mark.parametrize("contiguous", (True, False))
def test_encode_jpegs_batch(scripted, contiguous, device):
    """Batched encode_jpeg on ``device`` round-trips close to the originals, also when called from many threads."""
    if device == "cpu" and IS_MACOS:
        pytest.skip("https://github.com/pytorch/vision/issues/8031")
    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_images_tv = []
    for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
        if "cmyk" in jpeg_path:
            continue
        decoded_image = read_image(jpeg_path)
        if decoded_image.shape[0] == 1:
            continue
        decoded_images_tv.append(decoded_image[None].contiguous(memory_format=memory_format)[0])

    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv]
    encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75)
    # Round-trip: decode each encoded JPEG on CPU and compare it to its source image.
    round_tripped = [decode_jpeg(encoded.cpu()) for encoded in encoded_jpegs_tv_device]

    for original, encoded_decoded in zip(decoded_images_tv, round_tripped):
        abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item()
        assert abs_mean_diff < 3

    # Test multithreaded *encoding*: every thread must produce identical bytes, and each
    # result must still round-trip close to the source images.
    num_workers = 10
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)]
        encoded_images_threaded = [future.result() for future in futures]
    assert len(encoded_images_threaded) == num_workers
    for encoded_images in encoded_images_threaded:
        assert len(decoded_images_tv_device) == len(encoded_images)
        for i, (encoded_image, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)):
            # make sure all the threads produce identical outputs
            assert torch.all(encoded_image == encoded_images_threaded[0][i])

            # make sure the outputs are identical or close enough to baseline
            redecoded = decode_jpeg(encoded_image.cpu())
            assert redecoded.shape == decoded_image_tv.shape
            assert redecoded.dtype == decoded_image_tv.dtype
            assert (redecoded.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3
|
|
|
|
|
|
@needs_cuda
def test_single_encode_jpeg_cuda_errors():
    """encode_jpeg's CUDA path must validate dtype, channel count (exactly 3) and rank of a single tensor."""
    bad_inputs = [
        ("Input tensor dtype should be uint8", (3, 100, 100), torch.float32),
        ("The number of channels should be 3, got: 5", (5, 100, 100), torch.uint8),
        ("The number of channels should be 3, got: 1", (1, 100, 100), torch.uint8),
        ("Input data should be a 3-dimensional tensor", (1, 3, 100, 100), torch.uint8),
        ("Input data should be a 3-dimensional tensor", (100, 100), torch.uint8),
    ]
    for expected_message, shape, dtype in bad_inputs:
        with pytest.raises(RuntimeError, match=expected_message):
            encode_jpeg(torch.empty(shape, dtype=dtype, device="cuda"))
|
|
|
|
|
|
@needs_cuda
def test_batch_encode_jpegs_cuda_errors():
    """Batched encode_jpeg must validate each tensor and require a single, consistent CUDA device."""

    def image(shape=(3, 100, 100), dtype=torch.uint8, device="cuda"):
        # Uninitialised image buffer; only its metadata matters for these validation checks.
        return torch.empty(shape, dtype=dtype, device=device)

    # A valid first element followed by one that violates a single constraint.
    bad_second_elements = [
        ("Input tensor dtype should be uint8", image(dtype=torch.float32)),
        ("The number of channels should be 3, got: 5", image(shape=(5, 100, 100))),
        ("The number of channels should be 3, got: 1", image(shape=(1, 100, 100))),
        ("Input data should be a 3-dimensional tensor", image(shape=(1, 3, 100, 100))),
        ("Input data should be a 3-dimensional tensor", image(shape=(100, 100))),
    ]
    for expected_message, bad_image in bad_second_elements:
        with pytest.raises(RuntimeError, match=expected_message):
            encode_jpeg([image(), bad_image])

    with pytest.raises(RuntimeError, match="Input tensor should be on CPU"):
        encode_jpeg([image(device="cpu"), image(device="cuda")])

    with pytest.raises(
        RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
    ):
        encode_jpeg([image(device="cuda"), image(device="cpu")])

    if torch.cuda.device_count() >= 2:
        with pytest.raises(
            RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
        ):
            encode_jpeg([image(device="cuda:0"), image(device="cuda:1")])

    with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"):
        encode_jpeg([])
|
|
|
|
|
|
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_jpeg(img_path, tmpdir, scripted):
    """write_jpeg must write the same bytes as Pillow at quality 75."""
    out_dir = Path(tmpdir)
    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    torch_jpeg = str(out_dir / "torch.jpg")
    pil_jpeg = str(out_dir / "pil.jpg")

    write = torch.jit.script(write_jpeg) if scripted else write_jpeg
    write(img, torch_jpeg, quality=75)
    pil_img.save(pil_jpeg, quality=75)

    torch_bytes = Path(torch_jpeg).read_bytes()
    pil_bytes = Path(pil_jpeg).read_bytes()
    assert_equal(torch_bytes, pil_bytes)
|
|
|
|
|
|
def test_pathlib_support(tmpdir):
    """Smoke test: the file reading and writing helpers accept ``pathlib.Path`` arguments."""
    jpeg_path = Path(next(get_images(ENCODE_JPEG, ".jpg")))
    for reader in (read_file, read_image):
        reader(jpeg_path)

    write_path = Path(tmpdir) / "whatever"
    img = torch.randint(0, 10, size=(3, 4, 4), dtype=torch.uint8)
    write_file(write_path, data=img.flatten())
    write_jpeg(img, write_path)
    write_png(img, write_path)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name", ("gifgrid", "fire", "porsche", "treescap", "treescap-interlaced", "solid2", "x-trans", "earth")
)
@pytest.mark.parametrize("scripted", (True, False))
def test_decode_gif(tmpdir, name, scripted):
    """Compare decode_gif against Pillow, frame by frame, on downloaded test GIFs.

    Uses test images from GIFLIB
    (https://sourceforge.net/p/giflib/code/ci/master/tree/pic/) and asserts that the PIL
    and torchvision decoded outputs are equal. "welcome2" is not tested because PIL and
    GIFLIB disagree on what the background color should be (likely a difference in the
    way they handle transparency?).

    The 'earth' image is from Wikipedia, licensed under CC BY-SA 3.0
    (https://creativecommons.org/licenses/by-sa/3.0/). It allows properly testing
    transparency, TOP-LEFT offsets, and disposal modes.
    """
    path = tmpdir / f"{name}.gif"
    if name == "earth":
        if IN_OSS_CI:
            # TODO: Fix this... one day.
            pytest.skip("Skipping 'earth' test as it's flaky on OSS CI")
        url = "https://upload.wikimedia.org/wikipedia/commons/2/2c/Rotating_earth_%28large%29.gif"
    else:
        url = f"https://sourceforge.net/p/giflib/code/ci/master/tree/pic/{name}.gif?format=raw"

    # The timeout keeps a stalled download from hanging the whole session.
    # raise_for_status() reports HTTP errors directly, instead of saving an error page
    # that would later fail with a confusing GIF decoding error.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(path, "wb") as f:
        f.write(response.content)

    encoded_bytes = read_file(path)
    decode = torch.jit.script(decode_gif) if scripted else decode_gif
    tv_out = decode(encoded_bytes)
    if tv_out.ndim == 3:
        # Treat a 3-d result as a single frame: add a leading frame axis so frames can be iterated.
        tv_out = tv_out[None]

    assert tv_out.is_contiguous(memory_format=torch.channels_last)

    # For some reason, not using Image.open() as a CM causes "ResourceWarning: unclosed file"
    with Image.open(path) as pil_img:
        pil_seq = ImageSequence.Iterator(pil_img)

        for pil_frame, tv_frame in zip(pil_seq, tv_out):
            pil_frame = F.pil_to_tensor(pil_frame.convert("RGB"))
            torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "decode_fun, match",
    [
        (decode_png, "Content is not png"),
        (decode_jpeg, "Not a JPEG file"),
        (decode_gif, re.escape("DGifOpenFileName() failed - 103")),
        (decode_webp, "WebPGetFeatures failed."),
        pytest.param(
            decode_avif,
            "BMFF parsing failed",
            # marks=pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE)
            marks=pytest.mark.skipif(True, reason="Skipping avif/heic tests for now."),
        ),
        pytest.param(
            decode_heic,
            "Invalid input: No 'ftyp' box",
            # marks=pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE),
            marks=pytest.mark.skipif(True, reason="Skipping avif/heic tests for now."),
        ),
    ],
)
def test_decode_bad_encoded_data(decode_fun, match):
    """Each decoder rejects malformed tensors and random bytes with a RuntimeError.

    ``match`` is the decoder-specific message expected for a well-formed
    (1-D, uint8, contiguous) tensor that holds random bytes.
    """
    random_bytes = torch.randint(0, 256, (100,), dtype=torch.uint8)
    # The first three inputs violate the input-tensor requirements; the last is a
    # valid tensor whose content is simply not a decodable image.
    cases = (
        (random_bytes[None], "Input tensor must be 1-dimensional"),
        (random_bytes.float(), "Input tensor must have uint8 data type"),
        (random_bytes[::2], "Input tensor must be contiguous"),
        (random_bytes, match),
    )
    for bad_input, expected_msg in cases:
        with pytest.raises(RuntimeError, match=expected_msg):
            decode_fun(bad_input)
|
|
|
|
|
|
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_webp(decode_fun, scripted):
    """Decode the fake-data WebP image and check its shape, memory layout and buffer ownership."""
    webp_path = next(get_images(FAKEDATA_DIR, ".webp"))
    encoded_bytes = read_file(webp_path)
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun

    img = decoder(encoded_bytes)

    assert img.shape == (3, 100, 100)
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    # Writing into the result makes sure the image buffer wasn't freed by the
    # underlying decoding lib.
    img += 123
|
|
|
|
|
|
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
def test_decode_webp_grayscale(decode_fun, capfd):
    """Requesting GRAY mode for a WebP image warns and still returns a 3-channel image."""
    encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".webp")))

    # We warn at the C++ layer because for decode_image(), we don't do the image
    # type dispatch until we get to the C++ version of decode_image(). We could
    # warn at the Python layer in decode_webp(), but then users would get a
    # double warning: one from the Python layer and one from the C++ layer.
    #
    # Because we use the TORCH_WARN_ONCE macro, we need to do this dance to
    # temporarily always warn so we can test.
    @contextlib.contextmanager
    def set_always_warn():
        # "Warn always" is process-global state: restore the previous value even
        # if decoding raises, so it never leaks into other tests.
        previous = torch.is_warn_always_enabled()
        torch.set_warn_always(True)
        try:
            yield
        finally:
            torch.set_warn_always(previous)

    with set_always_warn():
        img = decode_fun(encoded_bytes, mode=ImageReadMode.GRAY)
        assert "Webp does not support grayscale conversions" in capfd.readouterr().err

    # Note that because we do not support grayscale conversions, we expect
    # that the number of color channels is still 3.
    assert img.shape == (3, 100, 100)
|
|
|
|
|
|
# This test is skipped by default because it requires webp images that we're not
# including within the repo. The test images were downloaded manually from the
# different pages of https://developers.google.com/speed/webp/gallery
@pytest.mark.skipif(not WEBP_TEST_IMAGES_DIR, reason="WEBP_TEST_IMAGES_DIR is not set")
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
    "mode, pil_mode",
    (
        # Note that converting an RGBA image to RGB leads to bad results because the
        # transparent pixels aren't necessarily set to "black" or "white", they can be
        # random stuff. This is consistent with PIL results.
        (ImageReadMode.RGB, "RGB"),
        (ImageReadMode.RGB_ALPHA, "RGBA"),
        (ImageReadMode.UNCHANGED, None),
    ),
)
@pytest.mark.parametrize("filename", Path(WEBP_TEST_IMAGES_DIR).glob("*.webp"), ids=lambda p: p.name)
def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
    """Decoded WebP pixels must exactly match PIL's for each read mode."""
    encoded_bytes = read_file(filename)
    if scripted:
        decode_fun = torch.jit.script(decode_fun)
    img = decode_fun(encoded_bytes, mode=mode)
    assert img[None].is_contiguous(memory_format=torch.channels_last)

    # Open the PIL image as a context manager so the underlying file is closed;
    # otherwise a "ResourceWarning: unclosed file" is emitted.
    with Image.open(filename) as pil_img:
        from_pil = F.pil_to_tensor(pil_img.convert(pil_mode))
    assert_equal(img, from_pil)
    img += 123  # make sure image buffer wasn't freed by underlying decoding lib
|
|
|
|
|
|
# @pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE)
@pytest.mark.skipif(True, reason="Skipping avif/heic tests for now.")
@pytest.mark.parametrize("decode_fun", (decode_avif,))
def test_decode_avif(decode_fun):
    """Decode the fake-data AVIF image and check its shape, memory layout and buffer ownership."""
    avif_path = next(get_images(FAKEDATA_DIR, ".avif"))
    img = decode_fun(read_file(avif_path))

    expected_shape = (3, 100, 100)
    assert img.shape == expected_shape
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    # Writing into the result makes sure the image buffer wasn't freed by the
    # underlying decoding lib.
    img += 123
|
|
|
|
|
|
# Note: decode_image fails because some of these files have a (valid) signature
# we don't recognize. We should probably use libmagic....
# @pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE)
@pytest.mark.skipif(True, reason="Skipping avif/heic tests for now.")
@pytest.mark.parametrize("decode_fun", (decode_avif, decode_heic))
@pytest.mark.parametrize(
    "mode, pil_mode",
    (
        (ImageReadMode.RGB, "RGB"),
        (ImageReadMode.RGB_ALPHA, "RGBA"),
        (ImageReadMode.UNCHANGED, None),
    ),
)
# NOTE(review): this glob points at a developer-local libavif checkout; on other
# machines it matches nothing and the test collects with an empty parameter set.
@pytest.mark.parametrize(
    "filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name
)
def test_decode_avif_heic_against_pil(decode_fun, mode, pil_mode, filename):
    """Compare decode_avif / decode_heic output against Pillow (with pillow_avif).

    Files that torchvision's decoder or PIL is known to reject are skipped rather
    than failed. Set the ``AVIF_HEIC_DEBUG_DIR`` environment variable to also
    write a side-by-side grid (torchvision | PIL) of each comparison into that
    directory for visual inspection.
    """
    if "reversed_dimg_order" in str(filename):
        # Pillow properly decodes this one, but we don't (order of parts of the
        # image is wrong). This is due to a bug that was recently fixed in
        # libavif. Hopefully this test will end up passing soon with a new
        # libavif version https://github.com/AOMediaCodec/libavif/issues/2311
        pytest.xfail()
    import pillow_avif  # noqa

    encoded_bytes = read_file(filename)
    try:
        img = decode_fun(encoded_bytes, mode=mode)
    except RuntimeError as e:
        expected_errors = (
            "BMFF parsing failed",
            "avifDecoderParse failed: ",
            "file contains more than one image",
            "no 'ispe' property",
            "'iref' has double references",
            "Invalid image grid",
            "decode_heif failed: Invalid input: No 'meta' box",
        )
        if any(s in str(e) for s in expected_errors):
            pytest.skip(reason="Expected failure, that's OK")
        raise

    assert img[None].is_contiguous(memory_format=torch.channels_last)
    if mode == ImageReadMode.RGB:
        assert img.shape[0] == 3
    if mode == ImageReadMode.RGB_ALPHA:
        assert img.shape[0] == 4

    if img.dtype == torch.uint16:
        # Scale high-bit-depth output down to uint8 so it can be compared with the
        # PIL tensor (assumed to be uint8 here — TODO confirm for 16-bit sources).
        img = F.to_dtype(img, dtype=torch.uint8, scale=True)

    try:
        # Context manager closes the file handle PIL opens for `filename`.
        with Image.open(filename) as pil_img:
            from_pil = F.pil_to_tensor(pil_img.convert(pil_mode))
    except RuntimeError as e:
        if any(s in str(e) for s in ("Invalid image grid", "Failed to decode image: Not implemented")):
            pytest.skip(reason="PIL failure")
        raise

    # Opt-in debugging aid. Dumping unconditionally to a hard-coded home
    # directory would fail on any machine where that path doesn't exist.
    debug_dir = os.environ.get("AVIF_HEIC_DEBUG_DIR")
    if debug_dir:
        from torchvision.utils import make_grid

        g = make_grid([img, from_pil])
        F.to_pil_image(g).save(os.path.join(debug_dir, f"{filename.name}.{pil_mode}.png"))

    is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "decode_heic"
    if mode == ImageReadMode.RGB and not is_decode_heic:
        # We don't compare torchvision's AVIF against PIL for RGB because
        # results look pretty different on RGBA images (other images are fine).
        # The result on torchvision basically just plainly ignores the alpha
        # channel, resulting in transparent pixels looking dark. PIL seems to be
        # using a sort of k-nn thing (Take a look at the resulting images)
        return
    if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic:
        return

    torch.testing.assert_close(img, from_pil, rtol=0, atol=3)
|
|
|
|
|
|
# @pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE)
@pytest.mark.skipif(True, reason="Skipping avif/heic tests for now.")
@pytest.mark.parametrize("decode_fun", (decode_heic,))
def test_decode_heic(decode_fun):
    """Decode the fake-data HEIC image and check its shape, memory layout and buffer ownership."""
    heic_path = next(get_images(FAKEDATA_DIR, ".heic"))
    img = decode_fun(read_file(heic_path))

    expected_shape = (3, 100, 100)
    assert img.shape == expected_shape
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    # Writing into the result makes sure the image buffer wasn't freed by the
    # underlying decoding lib.
    img += 123
|
|
|
|
|
|
@pytest.mark.parametrize("input_type", ("Path", "str", "tensor"))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_image_path(input_type, scripted):
    """Check that decode_image accepts a str path, a pathlib.Path, or an encoded tensor."""
    path = next(get_images(IMAGE_ROOT, ".jpg"))

    # Each parametrized input kind maps to a function that builds it from `path`.
    builders = {
        "Path": Path,
        "str": lambda p: p,
        "tensor": read_file,
    }
    if input_type not in builders:
        raise ValueError("Oops")
    source = builders[input_type](path)

    if scripted and input_type == "Path":
        pytest.xfail(reason="Can't pass a Path when scripting")

    decoder = decode_image if not scripted else torch.jit.script(decode_image)
    decoder(source)
|
|
|
|
|
|
def test_mode_str():
    """Check that decode_image accepts case-insensitive string modes.

    We just test decode_image, not all of the decoding functions, but they
    should all support that too. Torchscript fails when passing strings,
    which is expected.
    """
    path = next(get_images(IMAGE_ROOT, ".png"))
    # Mode string -> expected number of output channels (checked in this order).
    expected_channels = {"RGB": 3, "rGb": 3, "GRAY": 1, "RGBA": 4}
    for mode, n_channels in expected_channels.items():
        assert decode_image(path, mode=mode).shape[0] == n_channels
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly: hand this module's path to pytest.
    pytest.main(args=[__file__])
|