-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split off dataset download tests (#2665)
* split off tests for dataset downloadability * ignore download tests during normal test suite * lint * add retry mechanism
- Loading branch information
Showing
4 changed files
with
104 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import contextlib | ||
import itertools | ||
import unittest | ||
import unittest.mock | ||
from os import path | ||
from time import sleep | ||
from urllib.request import urlopen, Request | ||
|
||
from torchvision import datasets | ||
from torchvision.datasets.utils import download_url, check_integrity | ||
|
||
from common_utils import get_tmp_dir | ||
from fakedata_generation import places365_root | ||
|
||
|
||
class DownloadTester(unittest.TestCase):
    """Base test case that checks the download URLs used by a dataset.

    Subclasses implement :meth:`collect_urls_and_md5s`. :meth:`test_download`
    then checks every collected ``(url, md5)`` pair, either with a lightweight
    request (the default) or with a full download plus checksum verification,
    depending on :attr:`only_test_downloadability`.
    """

    @staticmethod
    @contextlib.contextmanager
    def log_download_attempts(patch=True):
        """Record the calls made to ``torchvision.datasets.utils.download_url``.

        Args:
            patch: If true, ``download_url`` is replaced by a plain mock. If
                false, the mock wraps the real ``download_url`` and delegates
                every call to it.

        Yields:
            A set that is filled with ``(url, md5)`` tuples. The set is only
            populated when the context exits, so read it after the ``with``
            block has finished.
        """
        urls_and_md5s = set()
        with unittest.mock.patch(
            "torchvision.datasets.utils.download_url", wraps=None if patch else download_url
        ) as mock:
            try:
                yield urls_and_md5s
            finally:
                # Collect in ``finally`` so that attempts are recorded even if
                # the body of the ``with`` block raised.
                for args, kwargs in mock.call_args_list:
                    url = args[0]
                    # With four positional arguments the last one is taken as
                    # the checksum; otherwise it is looked up by keyword.
                    md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                    urls_and_md5s.add((url, md5))

    @staticmethod
    def retry(fn, times=1, wait=5.0):
        """Call ``fn`` until it stops raising :class:`AssertionError`.

        Args:
            fn: Callable without arguments.
            times: Number of retries, i.e. ``fn`` is called at most
                ``times + 1`` times.
            wait: Seconds to sleep between two consecutive attempts.

        Returns:
            Whatever ``fn`` returns on its first successful call.

        Raises:
            AssertionError: If every attempt failed. The message lists the
                message of each individual failure. Only ``AssertionError``
                triggers a retry; any other exception propagates immediately.
        """
        msgs = []
        attempts = times + 1
        for attempt in range(1, attempts + 1):
            try:
                return fn()
            except AssertionError as error:
                msgs.append(str(error))
                # Waiting only makes sense if another attempt follows.
                if attempt < attempts:
                    sleep(wait)

        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )

    @staticmethod
    def assert_response_ok(response, url=None, ok=200):
        """Assert that ``response.code`` equals ``ok``.

        Args:
            response: Object exposing the status code as ``code``.
            url: Optional URL that is included in the failure message.
            ok: Status code that counts as success.
        """
        msg = f"The server returned status code {response.code}"
        if url is not None:
            msg += f" for the URL {url}"
        assert response.code == ok, msg

    @staticmethod
    def assert_is_downloadable(url):
        """Assert that a ``HEAD`` request to ``url`` is answered with status 200."""
        # ``method`` has to be passed as the request method. Supplying it via
        # ``headers`` only adds a header named "method" and still sends a GET.
        request = Request(url, method="HEAD")
        with urlopen(request) as response:
            DownloadTester.assert_response_ok(response, url)

    @staticmethod
    def assert_downloads_correctly(url, md5):
        """Download ``url`` into a temporary directory and verify its checksum.

        Args:
            url: URL of the file to download.
            md5: Checksum handed to ``check_integrity`` for the downloaded file.
        """
        with get_tmp_dir() as root:
            file = path.join(root, path.basename(url))
            with urlopen(url) as response, open(file, "wb") as fh:
                DownloadTester.assert_response_ok(response, url)
                fh.write(response.read())

            assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

    def _select_assert_fn(self):
        """Return the ``(url, md5)`` check selected by ``only_test_downloadability``."""
        if self.only_test_downloadability:
            return lambda url, _md5: self.assert_is_downloadable(url)
        return self.assert_downloads_correctly

    def test_download(self):
        """Check every URL returned by :meth:`collect_urls_and_md5s`."""
        # The base class has no URLs of its own. Skip it instead of failing
        # with NotImplementedError when test discovery picks it up.
        if type(self) is DownloadTester:
            self.skipTest("DownloadTester is an abstract base class")

        assert_fn = self._select_assert_fn()
        for url, md5 in self.collect_urls_and_md5s():
            with self.subTest(url=url, md5=md5):
                self.retry(lambda: assert_fn(url, md5))
            # Pause between URLs, regardless of the outcome of the check.
            sleep(2.0)

    def collect_urls_and_md5s(self):
        """Return an iterable of ``(url, md5)`` tuples; implemented by subclasses."""
        raise NotImplementedError

    @property
    def only_test_downloadability(self):
        """Whether to only send a ``HEAD`` request instead of downloading the file."""
        return True
|
||
|
||
class Places365Tester(DownloadTester):
    """Download test case for the ``Places365`` dataset."""

    def collect_urls_and_md5s(self):
        """Collect the ``(url, md5)`` pairs requested by ``Places365``.

        The dataset is instantiated with ``download=True`` for every
        combination of split and image size while the calls to
        ``download_url`` are logged. With ``patch=False`` the logging mock
        delegates to the real ``download_url``.
        """
        splits = ("train-standard", "train-challenge", "val")
        with self.log_download_attempts(patch=False) as urls_and_md5s:
            for split in splits:
                for small in (False, True):
                    # Only the root directory of the prepared dataset folder
                    # is needed; the second item is ignored.
                    with places365_root(split=split, small=small) as (root, _):
                        datasets.Places365(root, split=split, small=small, download=True)

        # The set is filled when the logging context exits, so it is complete here.
        return urls_and_md5s