import contextlib import gzip import os import pathlib import re import tarfile import zipfile import pytest import torch import torchvision.datasets.utils as utils from common_utils import assert_equal from torch._utils_internal import get_file_path_2 from torchvision.datasets.folder import make_dataset from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS TEST_FILE = get_file_path_2( os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" ) def patch_url_redirection(mocker, redirect_url): class Response: def __init__(self, url): self.url = url @contextlib.contextmanager def patched_opener(*args, **kwargs): yield Response(redirect_url) return mocker.patch("torchvision.datasets.utils.urllib.request.urlopen", side_effect=patched_opener) class TestDatasetsUtils: def test_get_redirect_url(self, mocker): url = "https://url.org" expected_redirect_url = "https://redirect.url.org" mock = patch_url_redirection(mocker, expected_redirect_url) actual = utils._get_redirect_url(url) assert actual == expected_redirect_url assert mock.call_count == 2 call_args_1, call_args_2 = mock.call_args_list assert call_args_1[0][0].full_url == url assert call_args_2[0][0].full_url == expected_redirect_url def test_get_redirect_url_max_hops_exceeded(self, mocker): url = "https://url.org" redirect_url = "https://redirect.url.org" mock = patch_url_redirection(mocker, redirect_url) with pytest.raises(RecursionError): utils._get_redirect_url(url, max_hops=0) assert mock.call_count == 1 assert mock.call_args[0][0].full_url == url @pytest.mark.parametrize("use_pathlib", (True, False)) def test_check_md5(self, use_pathlib): fpath = TEST_FILE if use_pathlib: fpath = pathlib.Path(fpath) correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc" false_md5 = "" assert utils.check_md5(fpath, correct_md5) assert not utils.check_md5(fpath, false_md5) def test_check_integrity(self): existing_fpath = TEST_FILE nonexisting_fpath = "" correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc" false_md5 = "" assert utils.check_integrity(existing_fpath, correct_md5) assert not utils.check_integrity(existing_fpath, false_md5) assert utils.check_integrity(existing_fpath) assert not utils.check_integrity(nonexisting_fpath) def test_get_google_drive_file_id(self): url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view" expected = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV" actual = utils._get_google_drive_file_id(url) assert actual == expected def test_get_google_drive_file_id_invalid_url(self): url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" assert utils._get_google_drive_file_id(url) is None @pytest.mark.parametrize( "file, expected", [ ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), ("foo.tar", (".tar", ".tar", None)), ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), ("foo.tbz", (".tbz", ".tar", ".bz2")), ("foo.tbz2", (".tbz2", ".tar", ".bz2")), ("foo.tgz", (".tgz", ".tar", ".gz")), ("foo.bz2", (".bz2", None, ".bz2")), ("foo.gz", (".gz", None, ".gz")), ("foo.zip", (".zip", ".zip", None)), ("foo.xz", (".xz", None, ".xz")), ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), ("foo.bar.gz", (".gz", None, ".gz")), ("foo.bar.zip", (".zip", ".zip", None)), ], ) def test_detect_file_type(self, file, expected): assert utils._detect_file_type(file) == expected @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"]) def test_detect_file_type_incompatible(self, file): # tests detect file type for no extension, unknown compression and unknown partial extension with pytest.raises(RuntimeError): utils._detect_file_type(file) @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"]) @pytest.mark.parametrize("use_pathlib", (True, False)) def test_decompress(self, extension, tmpdir, use_pathlib): def create_compressed(root, content="this is the content"): file = os.path.join(root, "file") compressed = f"{file}{extension}" compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension] with compressed_file_opener(compressed, "wb") as fh: fh.write(content.encode()) return compressed, file, content compressed, file, content = create_compressed(tmpdir) if use_pathlib: compressed = pathlib.Path(compressed) utils._decompress(compressed) assert os.path.exists(file) with open(file) as fh: assert fh.read() == content def test_decompress_no_compression(self): with pytest.raises(RuntimeError): utils._decompress("foo.tar") @pytest.mark.parametrize("use_pathlib", (True, False)) def test_decompress_remove_finished(self, tmpdir, use_pathlib): def create_compressed(root, content="this is the content"): file = os.path.join(root, "file") compressed = f"{file}.gz" with gzip.open(compressed, "wb") as fh: fh.write(content.encode()) return compressed, file, content compressed, file, content = create_compressed(tmpdir) print(f"{type(compressed)=}") if use_pathlib: compressed = pathlib.Path(compressed) tmpdir = pathlib.Path(tmpdir) extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True) assert not os.path.exists(compressed) if use_pathlib: assert isinstance(extracted_dir, pathlib.Path) assert isinstance(compressed, pathlib.Path) else: assert isinstance(extracted_dir, str) assert isinstance(compressed, str) @pytest.mark.parametrize("extension", [".gz", ".xz"]) @pytest.mark.parametrize("remove_finished", [True, False]) def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker): filename = "foo" file = f"{filename}{extension}" mocked = mocker.patch("torchvision.datasets.utils._decompress") utils.extract_archive(file, remove_finished=remove_finished) mocked.assert_called_once_with(file, filename, remove_finished=remove_finished) @pytest.mark.parametrize("use_pathlib", (True, False)) def test_extract_zip(self, tmpdir, use_pathlib): def create_archive(root, content="this is the content"): file = os.path.join(root, "dst.txt") archive = os.path.join(root, "archive.zip") with zipfile.ZipFile(archive, "w") as zf: zf.writestr(os.path.basename(file), content) return archive, file, content if use_pathlib: tmpdir = pathlib.Path(tmpdir) archive, file, content = create_archive(tmpdir) utils.extract_archive(archive, tmpdir) assert os.path.exists(file) with open(file) as fh: assert fh.read() == content @pytest.mark.parametrize( "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")] ) @pytest.mark.parametrize("use_pathlib", (True, False)) def test_extract_tar(self, extension, mode, tmpdir, use_pathlib): def create_archive(root, extension, mode, content="this is the content"): src = os.path.join(root, "src.txt") dst = os.path.join(root, "dst.txt") archive = os.path.join(root, f"archive{extension}") with open(src, "w") as fh: fh.write(content) with tarfile.open(archive, mode=mode) as fh: fh.add(src, arcname=os.path.basename(dst)) return archive, dst, content if use_pathlib: tmpdir = pathlib.Path(tmpdir) archive, file, content = create_archive(tmpdir, extension, mode) utils.extract_archive(archive, tmpdir) assert os.path.exists(file) with open(file) as fh: assert fh.read() == content def test_verify_str_arg(self): assert "a" == utils.verify_str_arg("a", "arg", ("a",)) pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") @pytest.mark.parametrize( ("dtype", "actual_hex", "expected_hex"), [ (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"), (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"), (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"), (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"), ], ) def test_flip_byte_order(self, dtype, actual_hex, expected_hex): def to_tensor(hex): return torch.frombuffer(bytes.fromhex(hex), dtype=dtype) assert_equal( utils._flip_byte_order(to_tensor(actual_hex)), to_tensor(expected_hex), ) @pytest.mark.parametrize( ("kwargs", "expected_error_msg"), [ (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"), (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")), (dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")), ], ) def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg): tmpdir = pathlib.Path(tmpdir) (tmpdir / "a").mkdir() (tmpdir / "a" / "a.png").touch() (tmpdir / "b").mkdir() (tmpdir / "b" / "b.jpeg").touch() (tmpdir / "c").mkdir() (tmpdir / "c" / "c.unknown").touch() with pytest.raises(FileNotFoundError, match=expected_error_msg): make_dataset(str(tmpdir), **kwargs) if __name__ == "__main__": pytest.main([__file__])