616 lines
25 KiB
Python
616 lines
25 KiB
Python
import os
|
|
import re
|
|
import sys
|
|
import tempfile
|
|
from io import BytesIO
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
import torchvision.transforms.functional as F
|
|
import torchvision.utils as utils
|
|
from common_utils import assert_equal, cpu_and_cuda
|
|
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
|
|
from torchvision.transforms.v2.functional import to_dtype
|
|
|
|
|
|
def _parse_pillow_version(version):
    """Convert a Pillow version string into a comparable tuple of ints.

    Only the leading numeric release components are kept. Development or
    pre-release suffixes therefore do not break parsing:
    ``"10.1.0"`` -> ``(10, 1, 0)``, ``"11.0.0.dev0"`` -> ``(11, 0, 0)``,
    ``"10.2.0rc1"`` -> ``(10, 2, 0)``.
    """
    components = []
    for part in version.split("."):
        match = re.match(r"\d+", part)
        if match is None:
            # Purely non-numeric component such as "dev0": the release number has ended.
            break
        components.append(int(match.group()))
        if match.end() != len(part):
            # Component such as "0rc1": keep the number, drop the suffix and stop.
            break
    return tuple(components)


PILLOW_VERSION = _parse_pillow_version(PILLOW_VERSION)
|
|
|
|
# Shared box fixture: four boxes given as two corners each, presumably
# (xmin, ymin, xmax, ymax). The second box is degenerate (all zeros).
boxes = torch.as_tensor(
    [
        [0, 0, 20, 20],
        [0, 0, 0, 0],
        [10, 15, 30, 35],
        [23, 35, 93, 95],
    ],
    dtype=torch.float32,
)
|
|
# Shared rotated-box fixture. Each row lists four corner points, presumably
# (x, y) pairs, flattened into 8 values per box.
rotated_boxes = torch.tensor(
    [
        [(100, 150), (150, 150), (150, 250), (100, 250)],
        [(200, 350), (250, 350), (250, 250), (200, 250)],
        [(300, 200), (200, 200), (200, 250), (300, 250)],
        # Not really a rectangle, but it doesn't matter
        [(100, 100), (200, 50), (290, 350), (200, 400)],
    ],
    dtype=torch.float32,
).reshape(-1, 8)
|
|
# Shared keypoint fixture: 2 instances x 3 keypoints x 2 coordinates.
keypoints = torch.tensor(
    [
        [[10, 10], [5, 5], [2, 2]],
        [[20, 20], [30, 30], [3, 3]],
    ],
    dtype=torch.float32,
)
|
|
|
|
|
|
def test_make_grid_not_inplace():
    """make_grid must leave its input untouched under every normalization mode."""
    batch = torch.rand(5, 3, 10, 10)
    reference = batch.clone()

    modes = (
        {"normalize": False},
        {"normalize": True, "scale_each": False},
        {"normalize": True, "scale_each": True},
    )
    for kwargs in modes:
        utils.make_grid(batch, **kwargs)
        assert_equal(batch, reference, msg="make_grid modified tensor in-place")
|
|
|
|
|
|
def test_normalize_in_make_grid():
    """With normalize=True, the grid values should span [0, 1] to one decimal place."""
    batch = torch.rand(5, 3, 10, 10) * 255
    expected_max = torch.tensor(1.0)
    expected_min = torch.tensor(0.0)

    grid = utils.make_grid(batch, normalize=True)
    peak = torch.max(grid)
    floor = torch.min(grid)

    # Round to one decimal place before comparing, to absorb floating-point noise.
    scale = 10**1
    rounded_peak = torch.round(peak * scale) / scale
    rounded_floor = torch.round(floor * scale) / scale

    assert_equal(expected_max, rounded_peak, msg="Normalized max is not equal to 1")
    assert_equal(expected_min, rounded_floor, msg="Normalized min is not equal to 0")
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image():
    """save_image writes a batch of images to a file path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        target = tmp.name
        utils.save_image(torch.rand(2, 3, 64, 64), target)
        assert os.path.exists(target), "The image is not present after save"
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel():
    """save_image handles a 1x1 image."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        target = tmp.name
        utils.save_image(torch.rand(1, 3, 1, 1), target)
        assert os.path.exists(target), "The pixel image is not present after save"
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_file_object():
    """Saving to an in-memory buffer yields the same pixels as saving to a path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        images = torch.rand(2, 3, 64, 64)

        utils.save_image(images, tmp.name)
        from_path = Image.open(tmp.name)

        buffer = BytesIO()
        utils.save_image(images, buffer, format="png")
        from_buffer = Image.open(buffer)

        # Compare while the temporary file still exists: PIL reads pixel data lazily.
        assert_equal(
            F.pil_to_tensor(from_path),
            F.pil_to_tensor(from_buffer),
            msg="Image not stored in file object",
        )
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel_file_object():
    """A 1x1 image saved to a buffer matches the same image saved to a path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        pixel = torch.rand(1, 3, 1, 1)

        utils.save_image(pixel, tmp.name)
        from_path = Image.open(tmp.name)

        buffer = BytesIO()
        utils.save_image(pixel, buffer, format="png")
        from_buffer = Image.open(buffer)

        # Compare while the temporary file still exists: PIL reads pixel data lazily.
        assert_equal(
            F.pil_to_tensor(from_path),
            F.pil_to_tensor(from_buffer),
            msg="Image not stored in file object",
        )
|
|
|
|
|
|
def test_draw_boxes():
    """Filled, labelled, multi-coloured boxes match the stored reference image."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    canvas_before = canvas.clone()
    boxes_before = boxes.clone()

    result = utils.draw_bounding_boxes(
        canvas,
        boxes,
        labels=["a", "b", "c", "d"],
        colors=["green", "#FF00FF", (0, 255, 0), "red"],
        fill=True,
    )

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(assets_dir, "draw_boxes_util.png")
    if not os.path.exists(path):
        # Bootstrap the reference image when it is missing.
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    if PILLOW_VERSION >= (10, 1):
        # The reference image is only valid for new PIL versions
        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        assert_equal(result, expected)

    # Check if modification is not in place
    assert_equal(boxes, boxes_before)
    assert_equal(canvas, canvas_before)
|
|
|
|
|
|
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
def test_draw_boxes_with_coloured_labels():
    """Label text can use colours that differ from the box colours."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    box_colors = ["green", "#FF00FF", (0, 255, 0), "red"]
    text_colors = ["green", "red", (0, 255, 0), "#FF00FF"]

    result = utils.draw_bounding_boxes(
        canvas,
        boxes,
        labels=["a", "b", "c", "d"],
        colors=box_colors,
        fill=True,
        label_colors=text_colors,
    )

    reference_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_colors.png"
    )
    with Image.open(reference_path) as reference:
        expected = torch.as_tensor(np.array(reference)).permute(2, 0, 1)
    assert_equal(result, expected)
|
|
|
|
|
|
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
def test_draw_boxes_with_coloured_label_backgrounds():
    """With fill_labels=True, labels are drawn on a filled background."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    box_colors = ["green", "#FF00FF", (0, 255, 0), "red"]
    text_colors = ["green", "red", (0, 255, 0), "#FF00FF"]

    result = utils.draw_bounding_boxes(
        canvas,
        boxes,
        labels=["a", "b", "c", "d"],
        colors=box_colors,
        fill=True,
        label_colors=text_colors,
        fill_labels=True,
    )

    reference_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_fill_colors.png"
    )
    with Image.open(reference_path) as reference:
        expected = torch.as_tensor(np.array(reference)).permute(2, 0, 1)
    assert_equal(result, expected)
|
|
|
|
|
|
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
def test_draw_rotated_boxes():
    """Rotated boxes drawn with default fill settings match the reference image."""
    canvas = torch.full((3, 500, 500), 255, dtype=torch.uint8)
    palette = ["blue", "yellow", (0, 255, 0), "black"]

    result = utils.draw_bounding_boxes(canvas, rotated_boxes, colors=palette)

    fakedata_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    with Image.open(os.path.join(fakedata_dir, "draw_rotated_boxes.png")) as reference:
        expected = torch.as_tensor(np.array(reference)).permute(2, 0, 1)
    assert_equal(result, expected)
|
|
|
|
|
|
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
def test_draw_rotated_boxes_fill():
    """Rotated boxes drawn with fill=True match the reference image."""
    canvas = torch.full((3, 500, 500), 255, dtype=torch.uint8)
    palette = ["blue", "yellow", (0, 255, 0), "black"]

    result = utils.draw_bounding_boxes(canvas, rotated_boxes, colors=palette, fill=True)

    fakedata_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    with Image.open(os.path.join(fakedata_dir, "draw_rotated_boxes_fill.png")) as reference:
        expected = torch.as_tensor(np.array(reference)).permute(2, 0, 1)
    assert_equal(result, expected)
|
|
|
|
|
|
@pytest.mark.parametrize("fill", [True, False])
def test_draw_boxes_dtypes(fill):
    """uint8 and float inputs give equivalent drawings, each keeping its dtype family."""
    as_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    drawn_uint8 = utils.draw_bounding_boxes(as_uint8, boxes, fill=fill)
    assert as_uint8 is not drawn_uint8
    assert drawn_uint8.dtype == torch.uint8

    as_float = to_dtype(as_uint8, torch.float, scale=True)
    drawn_float = utils.draw_bounding_boxes(as_float, boxes, fill=fill)
    assert as_float is not drawn_float
    assert drawn_float.is_floating_point()

    # Allow one intensity level of difference for float -> uint8 rounding.
    rescaled = to_dtype(drawn_float, torch.uint8, scale=True)
    torch.testing.assert_close(drawn_uint8, rescaled, rtol=0, atol=1)
|
|
|
|
|
|
@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors):
    """Every supported colour specification is accepted; an empty list is rejected."""
    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    utils.draw_bounding_boxes(black, boxes, fill=False, width=7, colors=colors)

    too_few_msg = "Number of colors must be equal or larger than the number of objects"
    with pytest.raises(ValueError, match=too_few_msg):
        utils.draw_bounding_boxes(image=black, boxes=boxes, colors=[])
|
|
|
|
|
|
def test_draw_boxes_vanilla():
    """Unfilled white boxes of width 7 on black match the stored reference image."""
    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    black_before = black.clone()
    boxes_before = boxes.clone()

    result = utils.draw_bounding_boxes(black, boxes, fill=False, width=7, colors="white")

    fakedata_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(fakedata_dir, "draw_boxes_vanilla.png")
    if not os.path.exists(path):
        # Bootstrap the reference image when it is missing.
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, expected)

    # Check if modification is not in place
    assert_equal(boxes, boxes_before)
    assert_equal(black, black_before)
|
|
|
|
|
|
def test_draw_boxes_grayscale():
    """A single-channel input yields a 3-channel output."""
    gray = torch.full((1, 4, 4), fill_value=255, dtype=torch.uint8)
    single_box = torch.tensor([[0, 0, 3, 3]], dtype=torch.int64)

    drawn = utils.draw_bounding_boxes(image=gray, boxes=single_box, colors=["#1BBC9B"])

    assert drawn.size(0) == 3
|
|
|
|
|
|
def test_draw_invalid_boxes():
    """draw_bounding_boxes rejects malformed images, boxes, labels and colours."""
    not_a_tensor = ((1, 1, 1), (1, 2, 3))
    batched = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
    valid_img = torch.zeros((3, 10, 10), dtype=torch.uint8)
    valid_boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
    # In both rows the second corner comes before the first one.
    misordered_boxes = torch.tensor([[10, 10, 4, 5], [30, 20, 10, 5]], dtype=torch.float)

    # (expected exception, match pattern, positional args, keyword args)
    cases = [
        (TypeError, "Tensor expected", (not_a_tensor, valid_boxes), {}),
        (ValueError, "Pass individual images, not batches", (batched, valid_boxes), {}),
        (ValueError, "Only grayscale and RGB images are supported", (batched[0][:2], valid_boxes), {}),
        (ValueError, "Number of boxes", (valid_img, valid_boxes, ["one", "two"]), {}),
        (ValueError, "Number of colors", (valid_img, valid_boxes), {"colors": ["pink", "blue"]}),
        (ValueError, "Boxes need to be in", (valid_img, misordered_boxes), {}),
    ]
    for error_type, pattern, args, kwargs in cases:
        with pytest.raises(error_type, match=pattern):
            utils.draw_bounding_boxes(*args, **kwargs)
|
|
|
|
|
|
def test_draw_boxes_warning():
    """Passing font_size without font emits a warning that font_size is ignored."""
    white = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    expected_warning = re.escape("Argument 'font_size' will be ignored since 'font' is not set.")

    with pytest.warns(UserWarning, match=expected_warning):
        utils.draw_bounding_boxes(white, boxes, font_size=11)
|
|
|
|
|
|
def test_draw_no_boxes():
    """An empty boxes tensor triggers a warning and leaves the pixels unchanged."""
    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    no_boxes = torch.full((0, 4), 0, dtype=torch.float)
    expected_warning = re.escape("boxes doesn't contain any box. No box was drawn")

    with pytest.warns(UserWarning, match=expected_warning):
        drawn = utils.draw_bounding_boxes(black, no_boxes)
        # Check that the function didn't change the image
        assert drawn.eq(black).all()
|
|
|
|
|
|
@pytest.mark.parametrize(
    "colors",
    [
        None,
        "blue",
        "#FF00FF",
        (1, 34, 122),
        ["red", "blue"],
        ["#FF00FF", (1, 34, 122)],
    ],
)
@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks(colors, alpha, device):
    """This test makes sure that masks draw their corresponding color where they should"""
    num_masks, h, w = 2, 100, 100
    dtype = torch.uint8
    img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
    masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device)
    masks[0, 10:20, 10:20] = True
    masks[1, 15:25, 15:25] = True

    overlap = masks[0] & masks[1]

    out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
    assert out.dtype == dtype
    assert out is not img

    # Make sure the image didn't change where there's no mask
    masked_pixels = masks[0] | masks[1]
    assert_equal(img[:, ~masked_pixels], out[:, ~masked_pixels])

    if colors is None:
        colors = utils._generate_color_palette(num_masks)
    elif isinstance(colors, (str, tuple)):
        # A single colour applies to every mask, so expect it on each of them.
        # Without the repetition, zip() below would only check the first mask.
        colors = [colors] * num_masks

    # Make sure each mask draws with its own color
    for mask, color in zip(masks, colors):
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        color = torch.tensor(color, dtype=dtype, device=device)

        # Only pixels covered by exactly this mask are checked here; the overlap
        # is handled separately below.
        exclusive = mask & ~overlap
        if alpha == 1:
            assert (out[:, exclusive] == color[:, None]).all()
        elif alpha == 0:
            assert (out[:, exclusive] == img[:, exclusive]).all()

        interpolated_color = (img[:, exclusive] * (1 - alpha) + color[:, None] * alpha).to(dtype)
        torch.testing.assert_close(out[:, exclusive], interpolated_color, rtol=0.0, atol=1.0)

    # Overlapping pixels are expected to blend the image with black (no mask colour).
    interpolated_overlap = (img[:, overlap] * (1 - alpha)).to(dtype)
    torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0)
|
|
|
|
|
|
def test_draw_segmentation_masks_dtypes():
    """uint8 and float images give equivalent overlays, each keeping its dtype family."""
    num_masks, h, w = 2, 100, 100
    random_masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

    as_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
    drawn_uint8 = utils.draw_segmentation_masks(as_uint8, random_masks)
    assert as_uint8 is not drawn_uint8
    assert drawn_uint8.dtype == torch.uint8

    as_float = to_dtype(as_uint8, torch.float, scale=True)
    drawn_float = utils.draw_segmentation_masks(as_float, random_masks)
    assert as_float is not drawn_float
    assert drawn_float.is_floating_point()

    # Allow one intensity level of difference for float -> uint8 rounding.
    rescaled = to_dtype(drawn_float, torch.uint8, scale=True)
    torch.testing.assert_close(drawn_uint8, rescaled, rtol=0, atol=1)
|
|
|
|
|
|
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device):
    """draw_segmentation_masks rejects invalid images, masks and colours.

    Every tensor is created on the parametrized ``device``. Each validation
    therefore runs on that device instead of on a mix of CPU and ``device``
    tensors.
    """
    h, w = 10, 10

    masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool, device=device)
    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8, device=device)

    with pytest.raises(TypeError, match="The image must be a tensor"):
        utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
    with pytest.raises(ValueError, match="The image dtype must be"):
        img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64, device=device)
        utils.draw_segmentation_masks(image=img_bad_dtype, masks=masks)
    with pytest.raises(ValueError, match="Pass individual images, not batches"):
        batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8, device=device)
        utils.draw_segmentation_masks(image=batch, masks=masks)
    with pytest.raises(ValueError, match="Pass an RGB image"):
        one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8, device=device)
        utils.draw_segmentation_masks(image=one_channel, masks=masks)
    with pytest.raises(ValueError, match="The masks must be of dtype bool"):
        masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float, device=device)
        utils.draw_segmentation_masks(image=img, masks=masks_bad_dtype)
    with pytest.raises(ValueError, match="masks must be of shape"):
        masks_bad_shape = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool, device=device)
        utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
    with pytest.raises(ValueError, match="must have the same height and width"):
        masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool, device=device)
        utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
    with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
        utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
    with pytest.raises(ValueError, match="`colors` must be a tuple or a string, or a list thereof"):
        bad_colors = np.array(["red", "blue"])  # should be a list
        utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
    with pytest.raises(ValueError, match="If passed as tuple, colors should be an RGB triplet"):
        # A tuple must be a single RGB triplet; several colours belong in a list.
        bad_colors = ("red", "blue")
        utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
|
|
|
|
|
|
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_no_segmention_mask(device):
    """An empty masks tensor triggers a warning and leaves the pixels unchanged."""
    black = torch.full((3, 100, 100), 0, dtype=torch.uint8, device=device)
    no_masks = torch.full((0, 100, 100), 0, dtype=torch.bool, device=device)
    expected_warning = re.escape("masks doesn't contain any mask. No mask was drawn")

    with pytest.warns(UserWarning, match=expected_warning):
        drawn = utils.draw_segmentation_masks(black, no_masks)
        # Check that the function didn't change the image
        assert drawn.eq(black).all()
|
|
|
|
|
|
def test_draw_keypoints_vanilla():
    """Red keypoints connected 0->1 match the stored reference image."""
    # `keypoints` is the module-level fixture tensor.
    keypoints_before = keypoints.clone()

    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    black_before = black.clone()
    result = utils.draw_keypoints(black, keypoints, colors="red", connectivity=[(0, 1)])

    fakedata_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(fakedata_dir, "draw_keypoint_vanilla.png")
    if not os.path.exists(path):
        # Bootstrap the reference image when it is missing.
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, expected)
    # Check that keypoints are not modified inplace
    assert_equal(keypoints, keypoints_before)
    # Check that image is not modified in place
    assert_equal(black, black_before)
|
|
|
|
|
|
def test_draw_keypoins_K_equals_one():
    """Drawing works when each instance has a single keypoint (K == 1)."""
    # Non-regression test for https://github.com/pytorch/vision/pull/8439
    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    single_keypoint = torch.tensor([[[10, 10]]], dtype=torch.float)
    utils.draw_keypoints(black, single_keypoint)
|
|
|
|
|
|
@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
    """Named, hex and RGB-tuple colours give a 3-channel result without mutating inputs."""
    # `keypoints` is the module-level fixture tensor.
    keypoints_before = keypoints.clone()

    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    black_before = black.clone()
    result = utils.draw_keypoints(black, keypoints, colors=colors, connectivity=[(0, 1)])

    assert result.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(black, black_before)
|
|
|
|
|
|
@pytest.mark.parametrize("connectivity", [[(0, 1)], [(0, 1), (1, 2)]])
@pytest.mark.parametrize(
    "vis",
    [
        torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.bool),
        torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.float).unsqueeze_(-1),
    ],
)
def test_draw_keypoints_visibility(connectivity, vis):
    """Invisible keypoints and their connections are left out of the drawing.

    Both connectivity variants are compared against the same reference image.
    """
    # `keypoints` is the module-level fixture tensor.
    keypoints_before = keypoints.clone()

    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    black_before = black.clone()

    vis_before = None if vis is None else vis.clone()

    result = utils.draw_keypoints(
        image=black,
        keypoints=keypoints,
        connectivity=connectivity,
        colors="red",
        visibility=vis,
    )
    assert result.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(black, black_before)

    # compare with a fakedata image
    # connect the key points 0 to 1 for both skeletons and do not show the other key points
    fakedata_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(fakedata_dir, "draw_keypoints_visibility.png")
    if not os.path.exists(path):
        # Bootstrap the reference image when it is missing.
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, expected)

    # The visibility argument itself must be left untouched (values and dtype).
    if vis_before is None:
        assert vis is None
    else:
        assert_equal(vis, vis_before)
        assert vis.dtype == vis_before.dtype
|
|
|
|
|
|
def test_draw_keypoints_visibility_default():
    """visibility=None renders like the vanilla call with no visibility argument."""
    # `keypoints` is the module-level fixture tensor.
    keypoints_before = keypoints.clone()

    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    black_before = black.clone()

    result = utils.draw_keypoints(
        image=black,
        keypoints=keypoints,
        connectivity=[(0, 1)],
        colors="red",
        visibility=None,
    )
    assert result.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(black, black_before)

    # compare against fakedata image, which connects 0->1 for both key-point skeletons
    fakedata_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    reference_path = os.path.join(fakedata_dir, "draw_keypoint_vanilla.png")
    expected = torch.as_tensor(np.array(Image.open(reference_path))).permute(2, 0, 1)
    assert_equal(result, expected)
|
|
|
|
|
|
def test_draw_keypoints_dtypes():
    """uint8 and float images give equivalent keypoint drawings, each keeping its dtype family."""
    as_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
    as_float = to_dtype(as_uint8, torch.float, scale=True)

    drawn_uint8 = utils.draw_keypoints(as_uint8, keypoints)
    drawn_float = utils.draw_keypoints(as_float, keypoints)

    assert drawn_uint8.dtype == torch.uint8
    assert drawn_uint8 is not as_uint8

    assert drawn_float.is_floating_point()
    assert drawn_float is not as_float

    # Allow one intensity level of difference for float -> uint8 rounding.
    rescaled = to_dtype(drawn_float, torch.uint8, scale=True)
    torch.testing.assert_close(drawn_uint8, rescaled, rtol=0, atol=1)
|
|
|
|
|
|
def test_draw_keypoints_errors():
    """draw_keypoints rejects malformed images, keypoints and visibility masks."""
    h, w = 10, 10
    img = torch.full((3, 100, 100), 0, dtype=torch.uint8)

    bad_dtype_img = torch.full((3, h, w), 0, dtype=torch.int64)
    batched_img = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
    gray_img = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)

    # Invalid images, each paired with the valid module-level keypoints.
    image_cases = [
        (TypeError, "The image must be a tensor", "Not A Tensor Image"),
        (ValueError, "The image dtype must be", bad_dtype_img),
        (ValueError, "Pass individual images, not batches", batched_img),
        (ValueError, "Pass an RGB image", gray_img),
    ]
    for error_type, pattern, bad_image in image_cases:
        with pytest.raises(error_type, match=pattern):
            utils.draw_keypoints(image=bad_image, keypoints=keypoints)

    flat_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
    with pytest.raises(ValueError, match="keypoints must be of shape"):
        utils.draw_keypoints(image=img, keypoints=flat_keypoints)

    shape_msg = re.escape("visibility must be of shape (num_instances, K)")
    mismatch_msg = "keypoints and visibility must have the same dimensionality"
    visibility_cases = [
        (shape_msg, torch.tensor([True, True, True], dtype=torch.bool)),
        (shape_msg, torch.ones((2, 3, 4), dtype=torch.bool)),
        (mismatch_msg, torch.ones((3, 3), dtype=torch.bool)),
        (mismatch_msg, torch.ones((2, 4), dtype=torch.bool)),
    ]
    for pattern, bad_visibility in visibility_cases:
        with pytest.raises(ValueError, match=pattern):
            utils.draw_keypoints(image=img, keypoints=keypoints, visibility=bad_visibility)
|
|
|
|
|
|
@pytest.mark.parametrize("batch", (True, False))
def test_flow_to_image(batch):
    """flow_to_image matches the stored reference for single and batched flows."""
    h, w = 100, 100
    flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    flow = torch.stack(flow[::-1], dim=0).float()
    flow[0] -= h / 2
    flow[1] -= w / 2

    if batch:
        flow = torch.stack([flow, flow])

    img = utils.flow_to_image(flow)
    # Compute the expected shape first. `a == b if c else d` parses as
    # `(a == b) if c else d`, which would assert a truthy tuple when batch is False.
    expected_shape = (2, 3, h, w) if batch else (3, h, w)
    assert img.shape == expected_shape

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
    expected_img = torch.load(path, map_location="cpu", weights_only=True)

    if batch:
        expected_img = torch.stack([expected_img, expected_img])

    assert_equal(expected_img, img)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "input_flow, match",
    (
        (torch.zeros((3, 10, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((5, 3, 10, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((2, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((5, 2, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((2, 10, 30), dtype=torch.int), "Flow should be of dtype torch.float"),
    ),
)
def test_flow_to_image_errors(input_flow, match):
    """flow_to_image rejects flows with a wrong shape or a non-float dtype."""
    with pytest.raises(ValueError, match=match):
        utils.flow_to_image(flow=input_flow)
|
|
|
|
|
|
if __name__ == "__main__":
    # Pass pytest's exit status to sys.exit so test failures reach the calling shell.
    sys.exit(pytest.main([__file__]))
|