diff --git a/.github/failed_schedule_issue_template.md b/.github/failed_schedule_issue_template.md
index 8ec971ffe47..5e2d77550ac 100644
--- a/.github/failed_schedule_issue_template.md
+++ b/.github/failed_schedule_issue_template.md
@@ -1,6 +1,8 @@
 ---
-title: Scheduled workflow {{ env.WORKFLOW }}/{{ env.JOB }} failed
-labels: bug, module: datasets
+title: Scheduled workflow failed
+labels:
+  - bug
+  - "module: datasets"
 ---
 
 Oh no, something went wrong in the scheduled workflow {{ env.WORKFLOW }}/{{ env.JOB }}.
diff --git a/.github/workflows/tests-schedule.yml b/.github/workflows/tests-schedule.yml
index 63919ed3bd6..135bd575cec 100644
--- a/.github/workflows/tests-schedule.yml
+++ b/.github/workflows/tests-schedule.yml
@@ -2,9 +2,10 @@ name: tests
 
 on:
   pull_request:
-    - "test/test_datasets_download.py"
-    - ".github/failed_schedule_issue_template.md"
-    - ".github/workflows/tests-schedule.yml"
+    paths:
+      - "test/test_datasets_download.py"
+      - ".github/failed_schedule_issue_template.md"
+      - ".github/workflows/tests-schedule.yml"
 
   schedule:
     - cron: "0 9 * * *"
@@ -22,20 +23,23 @@ jobs:
       - name: Upgrade pip
         run: python -m pip install --upgrade pip
 
+      - name: Checkout repository
+        uses: actions/checkout@v2
+
       - name: Install PyTorch from the nightlies
         run: |
           pip install numpy
           pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 
       - name: Install tests requirements
-        run: pip install pytest pytest-subtests
+        run: pip install pytest
 
       - name: Run tests
-        run: pytest test/test_datasets_download.py
+        run: pytest --durations=20 -ra test/test_datasets_download.py
 
       - uses: JasonEtco/create-an-issue@v2.4.0
         name: Create issue if download tests failed
-        if: failure()
+        if: failure() && github.event_name == 'schedule'
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          REPO: ${{ github.repository }}
diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py
index 7b189368dda..9b040edb1c1 100644
--- a/test/test_datasets_download.py
+++ b/test/test_datasets_download.py
@@ -1,13 +1,14 @@
 import contextlib
 import itertools
 import time
-import unittest
 import unittest.mock
 from datetime import datetime
 from os import path
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request
 
+import pytest
+
 from torchvision import datasets
 from torchvision.datasets.utils import download_url, check_integrity
 
@@ -43,89 +44,94 @@ def inner_wrapper(request, *args, **kwargs):
 urlopen = limit_requests_per_time()(urlopen)
 
 
-class DownloadTester(unittest.TestCase):
-    @staticmethod
-    @contextlib.contextmanager
-    def log_download_attempts(patch=True):
-        urls_and_md5s = set()
-        with unittest.mock.patch(
-            "torchvision.datasets.utils.download_url", wraps=None if patch else download_url
-        ) as mock:
-            try:
-                yield urls_and_md5s
-            finally:
-                for args, kwargs in mock.call_args_list:
-                    url = args[0]
-                    md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
-                    urls_and_md5s.add((url, md5))
-
-    @staticmethod
-    def retry(fn, times=1, wait=5.0):
-        msgs = []
-        for _ in range(times + 1):
-            try:
-                return fn()
-            except AssertionError as error:
-                msgs.append(str(error))
-                time.sleep(wait)
-        else:
-            raise AssertionError(
-                "\n".join(
-                    (
-                        f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
-                        *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
-                    )
+@contextlib.contextmanager
+def log_download_attempts(patch=True):
+    urls_and_md5s = set()
+    with unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) as mock:
+        try:
+            yield urls_and_md5s
+        finally:
+            for args, kwargs in mock.call_args_list:
+                url = args[0]
+                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))
+
+
+def retry(fn, times=1, wait=5.0):
+    msgs = []
+    for _ in range(times + 1):
+        try:
+            return fn()
+        except AssertionError as error:
+            msgs.append(str(error))
+            time.sleep(wait)
+    else:
+        raise AssertionError(
+            "\n".join(
+                (
+                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
+                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                 )
             )
+        )
-
-    @staticmethod
-    def assert_response_ok(response, url=None, ok=200):
-        msg = f"The server returned status code {response.code}"
-        if url is not None:
-            msg += f"for the the URL {url}"
-        assert response.code == ok, msg
-
-    @staticmethod
-    def assert_is_downloadable(url):
-        request = Request(url, headers=dict(method="HEAD"))
-        response = urlopen(request)
-        DownloadTester.assert_response_ok(response, url)
-
-    @staticmethod
-    def assert_downloads_correctly(url, md5):
-        with get_tmp_dir() as root:
-            file = path.join(root, path.basename(url))
-            with urlopen(url) as response, open(file, "wb") as fh:
-                DownloadTester.assert_response_ok(response, url)
-                fh.write(response.read())
-
-            assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
-
-    def test_download(self):
-        assert_fn = (
-            lambda url, _: self.assert_is_downloadable(url)
-            if self.only_test_downloadability
-            else self.assert_downloads_correctly
-        )
-        for url, md5 in self.collect_urls_and_md5s():
-            with self.subTest(url=url, md5=md5):
-                self.retry(lambda: assert_fn(url, md5))
-
-    def collect_urls_and_md5s(self):
-        raise NotImplementedError
-
-    @property
-    def only_test_downloadability(self):
-        return True
+
+
+def assert_server_response_ok(response, url=None):
+    msg = f"The server returned status code {response.code}"
+    if url is not None:
+        msg += f" for the URL {url}"
+    assert 200 <= response.code < 300, msg
+
+
+def assert_url_is_accessible(url):
+    request = Request(url, headers=dict(method="HEAD"))
+    response = urlopen(request)
+    assert_server_response_ok(response, url)
+
+
+def assert_file_downloads_correctly(url, md5):
+    with get_tmp_dir() as root:
+        file = path.join(root, path.basename(url))
+        with urlopen(url) as response, open(file, "wb") as fh:
+            assert_server_response_ok(response, url)
+            fh.write(response.read())
+
+        assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
+
+
+class DownloadConfig:
+    def __init__(self, url, md5=None, id=None):
+        self.url = url
+        self.md5 = md5
+        self.id = id or url
+
+
+def make_parametrize_kwargs(download_configs):
+    argvalues = []
+    ids = []
+    for config in download_configs:
+        argvalues.append((config.url, config.md5))
+        ids.append(config.id)
+
+    return dict(argnames="url, md5", argvalues=argvalues, ids=ids)
+
+
+def places365():
+    with log_download_attempts(patch=False) as urls_and_md5s:
+        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
+            with places365_root(split=split, small=small) as places365:
+                root, data = places365
+
+                datasets.Places365(root, split=split, small=small, download=True)
+
+    return [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s]
 
 
-class Places365Tester(DownloadTester):
-    def collect_urls_and_md5s(self):
-        with self.log_download_attempts(patch=False) as urls_and_md5s:
-            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
-                with places365_root(split=split, small=small) as places365:
-                    root, data = places365
 
+@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(),)))
+def test_url_is_accessible(url, md5):
+    retry(lambda: assert_url_is_accessible(url))
 
-                    datasets.Places365(root, split=split, small=small, download=True)
 
-        return urls_and_md5s
+
+@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
+def test_file_downloads_correctly(url, md5):
+    retry(lambda: assert_file_downloads_correctly(url, md5))
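
A rough sketch of how a further dataset could be wired into the new parametrization on top of this patch. The caltech101 collection function below and its URL and MD5 values are hypothetical placeholders; only DownloadConfig, make_parametrize_kwargs, places365, retry, assert_url_is_accessible, and the pytest parametrization itself come from the patch:

    def caltech101():
        # Hypothetical example only: the URL and the checksum are placeholders, not a real mirror.
        return [
            DownloadConfig(
                "https://example.org/caltech101/101_ObjectCategories.tar.gz",
                md5="0" * 32,
                id="Caltech101, images",
            ),
        ]


    # The extra configs are then chained into the existing parametrization:
    @pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101())))
    def test_url_is_accessible(url, md5):
        retry(lambda: assert_url_is_accessible(url))

Because every (url, md5) pair becomes its own test id, a single dataset can be checked in isolation with pytest's keyword filter, for example: pytest --durations=20 -ra test/test_datasets_download.py -k Places365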