65 lines
2.3 KiB
Python
65 lines
2.3 KiB
Python
"""This file should contain all tests that need access to the internet (apart
|
|
from the ones in test_datasets_download.py)
|
|
|
|
We want to bundle all internet-related tests in one file, so the file can be
|
|
cleanly ignored in FB internal test infra.
|
|
"""
|
|
|
|
import os
|
|
import pathlib
|
|
from urllib.error import URLError
|
|
|
|
import pytest
|
|
import torchvision.datasets.utils as utils
|
|
|
|
|
|
class TestDatasetUtils:
|
|
@pytest.mark.parametrize("use_pathlib", (True, False))
|
|
def test_download_url(self, tmpdir, use_pathlib):
|
|
if use_pathlib:
|
|
tmpdir = pathlib.Path(tmpdir)
|
|
url = "http://github.com/pytorch/vision/archive/master.zip"
|
|
try:
|
|
utils.download_url(url, tmpdir)
|
|
assert len(os.listdir(tmpdir)) != 0
|
|
except URLError:
|
|
pytest.skip(f"could not download test file '{url}'")
|
|
|
|
@pytest.mark.parametrize("use_pathlib", (True, False))
|
|
def test_download_url_retry_http(self, tmpdir, use_pathlib):
|
|
if use_pathlib:
|
|
tmpdir = pathlib.Path(tmpdir)
|
|
url = "https://github.com/pytorch/vision/archive/master.zip"
|
|
try:
|
|
utils.download_url(url, tmpdir)
|
|
assert len(os.listdir(tmpdir)) != 0
|
|
except URLError:
|
|
pytest.skip(f"could not download test file '{url}'")
|
|
|
|
@pytest.mark.parametrize("use_pathlib", (True, False))
|
|
def test_download_url_dont_exist(self, tmpdir, use_pathlib):
|
|
if use_pathlib:
|
|
tmpdir = pathlib.Path(tmpdir)
|
|
url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
|
|
with pytest.raises(URLError):
|
|
utils.download_url(url, tmpdir)
|
|
|
|
@pytest.mark.parametrize("use_pathlib", (True, False))
|
|
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir, use_pathlib):
|
|
if use_pathlib:
|
|
tmpdir = pathlib.Path(tmpdir)
|
|
url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
|
|
|
|
id = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"
|
|
filename = "filename"
|
|
md5 = "md5"
|
|
|
|
mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive")
|
|
utils.download_url(url, tmpdir, filename, md5)
|
|
|
|
mocked.assert_called_once_with(id, os.path.expanduser(tmpdir), filename, md5)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__])
|