Skip to content

Commit

Permalink
Split off dataset download tests (#2665)
Browse files Browse the repository at this point in the history
* split off tests for dataset downloadability

* ignore download tests during normal test suite

* lint

* add retry mechanic
  • Loading branch information
pmeier authored Sep 14, 2020
1 parent a4736ea commit 3b31b72
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

python -m torch.utils.collect_env
pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test
pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
2 changes: 1 addition & 1 deletion .circleci/unittest/windows/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')"
conda activate ./env

python -m torch.utils.collect_env
pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test
pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
14 changes: 0 additions & 14 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,20 +311,6 @@ def target_transform(target):
self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)

@mock.patch("torchvision.datasets.utils.download_url")
def test_places365_downloadable(self, download_url):
    """Check that every URL requested by ``Places365(download=True)`` answers a HEAD request.

    ``download_url`` is mocked, so nothing is fetched while the datasets are
    instantiated; the mock only records which URLs would have been requested.
    Each distinct URL is then probed in its own sub-test.
    """
    splits = ("train-standard", "train-challenge", "val")
    for split in splits:
        for small in (False, True):
            with places365_root(split=split, small=small) as fake_dataset:
                root, _ = fake_dataset
                torchvision.datasets.Places365(root, split=split, small=small, download=True)

    # The URL is the first positional argument of every recorded call.
    requested = set()
    for recorded_call in download_url.call_args_list:
        requested.add(recorded_call[0][0])

    for url in requested:
        with self.subTest(url=url):
            head_request = Request(url, method="HEAD")
            response = urlopen(head_request)
            assert response.code == 200, f"Server returned status code {response.code} for {url}."

def test_places365_devkit_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
Expand Down
102 changes: 102 additions & 0 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import contextlib
import itertools
import unittest
import unittest.mock
from os import path
from time import sleep
from urllib.request import urlopen, Request

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


class DownloadTester(unittest.TestCase):
    """Base class for tests that check whether dataset files can be downloaded.

    Subclasses implement :meth:`collect_urls_and_md5s` and may override
    :attr:`only_test_downloadability`. The base class itself is abstract and
    its ``test_download`` is skipped.
    """

    @staticmethod
    @contextlib.contextmanager
    def log_download_attempts(patch=True):
        """Record the ``(url, md5)`` pairs passed to ``download_url``.

        Args:
            patch: If ``True`` the download function is replaced by a plain
                mock. If ``False`` the mock wraps the real ``download_url``,
                so calls are forwarded to it.

        Yields:
            A set that is filled with ``(url, md5)`` tuples when the context
            is left (also if the body raised).
        """
        urls_and_md5s = set()
        with unittest.mock.patch(
            "torchvision.datasets.utils.download_url", wraps=None if patch else download_url
        ) as mocked:
            try:
                yield urls_and_md5s
            finally:
                for args, kwargs in mocked.call_args_list:
                    # Accept the URL as keyword as well instead of failing with IndexError.
                    url = args[0] if args else kwargs.get("url")
                    # With four positional arguments the last one is taken as the MD5.
                    md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                    urls_and_md5s.add((url, md5))

    @staticmethod
    def retry(fn, times=1, wait=5.0):
        """Call ``fn`` until it stops raising ``AssertionError``.

        Args:
            fn: Callable without arguments.
            times: Number of retries after the first attempt.
            wait: Seconds to sleep between two attempts.

        Returns:
            Whatever ``fn`` returns on its first successful call.

        Raises:
            AssertionError: If all ``times + 1`` attempts failed. The message
                lists the message of every failed attempt.
        """
        msgs = []
        attempts = times + 1
        for attempt in range(attempts):
            try:
                return fn()
            except AssertionError as error:
                msgs.append(str(error))
            # Only wait *between* attempts; sleeping after the last one is wasted time.
            if attempt < attempts - 1:
                sleep(wait)

        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )

    @staticmethod
    def assert_response_ok(response, url=None, ok=200):
        """Assert that ``response.code`` equals ``ok``.

        Args:
            response: Object with a ``code`` attribute.
            url: Optional URL that is included in the failure message.
            ok: Status code that counts as success.
        """
        msg = f"The server returned status code {response.code}"
        if url is not None:
            msg += f" for the URL {url}"
        assert response.code == ok, msg

    @staticmethod
    def assert_is_downloadable(url):
        """Assert that a HEAD request to ``url`` is answered with the OK status code."""
        # ``method`` has to be passed as such; putting it into ``headers`` only
        # adds a meaningless header and the request would still be a GET.
        request = Request(url, method="HEAD")
        with urlopen(request) as response:
            DownloadTester.assert_response_ok(response, url)

    @staticmethod
    def assert_downloads_correctly(url, md5):
        """Download ``url`` into a temporary directory and verify its MD5 checksum."""
        with get_tmp_dir() as root:
            file = path.join(root, path.basename(url))
            with urlopen(url) as response, open(file, "wb") as fh:
                DownloadTester.assert_response_ok(response, url)
                # Stream in chunks so large archives are not held in memory at once.
                for chunk in iter(lambda: response.read(1024 * 1024), b""):
                    fh.write(chunk)

            # The file is closed (and thus flushed) before it is hashed.
            assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

    def test_download(self):
        """Run the configured check, with retry, for every collected URL."""
        if type(self) is DownloadTester:
            # The base class has no URLs to collect; without this guard the
            # test runner would report a NotImplementedError for it.
            self.skipTest("DownloadTester is an abstract base class")

        # The lambda needs its own parentheses. Otherwise the conditional
        # expression becomes part of the lambda body and the full download
        # check would be returned instead of called.
        assert_fn = (
            (lambda url, _: self.assert_is_downloadable(url))
            if self.only_test_downloadability
            else self.assert_downloads_correctly
        )
        for url, md5 in self.collect_urls_and_md5s():
            with self.subTest(url=url, md5=md5):
                self.retry(lambda: assert_fn(url, md5))
            # Short pause between requests to different URLs.
            sleep(2.0)

    def collect_urls_and_md5s(self):
        """Return an iterable of ``(url, md5)`` pairs to test. Must be overridden."""
        raise NotImplementedError

    @property
    def only_test_downloadability(self):
        """If ``True`` only a HEAD request is sent instead of downloading the file."""
        return True


class Places365Tester(DownloadTester):
    """Download test for the files requested by the Places365 dataset."""

    def collect_urls_and_md5s(self):
        """Return the ``(url, md5)`` pairs logged while creating Places365.

        The dataset is instantiated with ``download=True`` for every
        combination of split and image size inside fake dataset roots, while
        ``log_download_attempts(patch=False)`` records the download calls.
        """
        splits = ("train-standard", "train-challenge", "val")
        with self.log_download_attempts(patch=False) as recorded:
            for split in splits:
                for small in (False, True):
                    with places365_root(split=split, small=small) as fake_dataset:
                        root, _ = fake_dataset
                        datasets.Places365(root, split=split, small=small, download=True)

        # The set is only populated once the logging context has been left.
        return recorded

0 comments on commit 3b31b72

Please sign in to comment.