Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix scheduled download tests #2706

Merged
merged 14 commits into from
Sep 28, 2020
4 changes: 3 additions & 1 deletion .github/failed_schedule_issue_template.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
---
title: Scheduled workflow {{ env.WORKFLOW }}/{{ env.JOB }} failed
labels: bug, module: datasets
labels:
- bug
- module: datasets
---

Oh no, something went wrong in the scheduled workflow {{ env.WORKFLOW }}/{{ env.JOB }}.
Expand Down
49 changes: 28 additions & 21 deletions .github/workflows/tests-schedule.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ name: tests

on:
pull_request:
- "test/test_datasets_download.py"
- ".github/failed_schedule_issue_template.md"
- ".github/workflows/tests-schedule.yml"
paths:
- "test/test_datasets_download.py"
- ".github/failed_schedule_issue_template.md"
- ".github/workflows/tests-schedule.yml"

schedule:
- cron: "0 9 * * *"
Expand All @@ -14,24 +15,30 @@ jobs:
runs-on: ubuntu-latest

steps:
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: 3.6

- name: Upgrade pip
run: python -m pip install --upgrade pip

- name: Install PyTorch from the nightlies
run: |
pip install numpy
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html

- name: Install tests requirements
run: pip install pytest pytest-subtests

- name: Run tests
run: pytest test/test_datasets_download.py
# - name: Set up python
# uses: actions/setup-python@v2
# with:
# python-version: 3.6
#
# - name: Upgrade pip
# run: python -m pip install --upgrade pip
#
# - name: Install PyTorch from the nightlies
# run: |
# pip install numpy
# pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
#
# - name: Install tests requirements
# run: pip install pytest pytest-subtests

- name: Checkout repository
uses: actions/checkout@v2

# - name: Run tests
# run: pytest test/test_datasets_download.py

- name: Fail workflow
run: exit 1

- uses: JasonEtco/create-an-issue@v2.4.0
name: Create issue if download tests failed
Expand Down
162 changes: 84 additions & 78 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import contextlib
import itertools
import time
import unittest
import unittest.mock
from datetime import datetime
from os import path
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity

Expand Down Expand Up @@ -43,89 +44,94 @@ def inner_wrapper(request, *args, **kwargs):
urlopen = limit_requests_per_time()(urlopen)


class DownloadTester(unittest.TestCase):
@staticmethod
@contextlib.contextmanager
def log_download_attempts(patch=True):
urls_and_md5s = set()
with unittest.mock.patch(
"torchvision.datasets.utils.download_url", wraps=None if patch else download_url
) as mock:
try:
yield urls_and_md5s
finally:
for args, kwargs in mock.call_args_list:
url = args[0]
md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
urls_and_md5s.add((url, md5))

@staticmethod
def retry(fn, times=1, wait=5.0):
msgs = []
for _ in range(times + 1):
try:
return fn()
except AssertionError as error:
msgs.append(str(error))
time.sleep(wait)
else:
raise AssertionError(
"\n".join(
(
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
*(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
)
@contextlib.contextmanager
def log_download_attempts(patch=True):
    """Record the calls made to ``torchvision.datasets.utils.download_url``.

    The context manager yields a set. When the ``with`` body finishes
    (normally or through an exception) the set is filled with one
    ``(url, md5)`` tuple per recorded call, so it should be inspected only
    after the ``with`` statement has exited.

    Args:
        patch: If true, ``download_url`` is replaced by a plain mock, so calls
            are recorded without being forwarded. If false, the mock wraps
            the real ``download_url`` and every call is passed through to it.
    """
    recorded = set()
    if patch:
        passthrough = None
    else:
        passthrough = download_url

    target = "torchvision.datasets.utils.download_url"
    with unittest.mock.patch(target, wraps=passthrough) as download_mock:
        try:
            yield recorded
        finally:
            # The URL is the first positional argument; the MD5 is the fourth
            # positional argument when four are given, otherwise the ``md5``
            # keyword (``None`` when absent).
            recorded.update(
                (
                    call_args[0],
                    call_args[-1] if len(call_args) == 4 else call_kwargs.get("md5"),
                )
                for call_args, call_kwargs in download_mock.call_args_list
            )


def retry(fn, times=1, wait=5.0):
    """Call ``fn`` and retry it when it fails with an ``AssertionError``.

    Args:
        fn: Zero-argument callable to invoke.
        times: Number of retries after the first attempt, i.e. ``fn`` is
            called at most ``times + 1`` times.
        wait: Seconds to sleep between two consecutive attempts.

    Returns:
        Whatever ``fn`` returns on its first successful call.

    Raises:
        AssertionError: If every attempt failed. The message lists the
            message of each individual failure. Exceptions of any other type
            raised by ``fn`` are not caught and propagate immediately.
    """
    attempts = times + 1
    msgs = []
    for attempt in range(attempts):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))

        # Only pause when another attempt follows; sleeping after the final
        # failure would just delay the error report.
        if attempt < attempts - 1:
            time.sleep(wait)

    raise AssertionError(
        "\n".join(
            (
                f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
            )
        )
    )

@staticmethod
def assert_response_ok(response, url=None, ok=200):
msg = f"The server returned status code {response.code}"
if url is not None:
msg += f"for the the URL {url}"
assert response.code == ok, msg

@staticmethod
def assert_is_downloadable(url):
request = Request(url, headers=dict(method="HEAD"))
response = urlopen(request)
DownloadTester.assert_response_ok(response, url)

@staticmethod
def assert_downloads_correctly(url, md5):
with get_tmp_dir() as root:
file = path.join(root, path.basename(url))
with urlopen(url) as response, open(file, "wb") as fh:
DownloadTester.assert_response_ok(response, url)
fh.write(response.read())

assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

def test_download(self):
assert_fn = (
lambda url, _: self.assert_is_downloadable(url)
if self.only_test_downloadability
else self.assert_downloads_correctly
)
for url, md5 in self.collect_urls_and_md5s():
with self.subTest(url=url, md5=md5):
self.retry(lambda: assert_fn(url, md5))

def collect_urls_and_md5s(self):
raise NotImplementedError

@property
def only_test_downloadability(self):
return True
def assert_server_response_ok(response, url=None):
    """Assert that ``response`` carries a 2xx HTTP status code.

    Args:
        response: Object exposing the HTTP status as its ``code`` attribute.
        url: Optional URL the response belongs to; when given it is included
            in the failure message.

    Raises:
        AssertionError: If ``response.code`` is outside ``[200, 300)``.
    """
    msg = f"The server returned status code {response.code}"
    if url is not None:
        # The leading space keeps the status code and the rest of the
        # sentence from running together.
        msg += f" for the URL {url}"
    assert 200 <= response.code < 300, msg


def assert_url_is_accessible(url):
    """Assert that a ``HEAD`` request to ``url`` gets a successful response.

    Args:
        url: The URL to probe.

    Raises:
        AssertionError: If the server answers with a non-2xx status code
            (raised by ``assert_server_response_ok``).
    """
    # The HTTP verb has to be passed through ``method=``. Putting it into
    # ``headers`` merely sends a header named "method" and the request is
    # still issued as a GET.
    request = Request(url, method="HEAD")
    # Use the response as a context manager so the connection is released.
    with urlopen(request) as response:
        assert_server_response_ok(response, url)


def assert_file_downloads_correctly(url, md5):
    """Download ``url`` into a temporary directory and verify its MD5 checksum.

    Args:
        url: The URL of the file to download.
        md5: Expected MD5 checksum, forwarded to ``check_integrity``.

    Raises:
        AssertionError: If the server answers with a non-2xx status code or
            if ``check_integrity`` reports a mismatch for the downloaded file.
    """
    chunk_size = 1024 * 1024

    # NOTE(review): ``get_tmp_dir`` is defined outside this view; presumably
    # it yields a temporary directory that is removed on exit, so the
    # integrity check has to stay inside this ``with`` block.
    with get_tmp_dir() as root:
        file = path.join(root, path.basename(url))
        with urlopen(url) as response, open(file, "wb") as fh:
            assert_server_response_ok(response, url)
            # Stream the payload in fixed-size chunks rather than holding the
            # complete download in memory at once.
            for chunk in iter(lambda: response.read(chunk_size), b""):
                fh.write(chunk)

        assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"


class DownloadConfig:
    """Description of a single download to be tested.

    Attributes set by the constructor:
        url: The URL of the file.
        md5: The expected MD5 checksum, or ``None`` if none was given.
        id: The identifier of this configuration; falls back to ``url`` when
            no (or a falsy) ``id`` is passed.
    """

    def __init__(self, url, md5=None, id=None):
        # Any falsy ``id`` (``None``, empty string, ...) is replaced by the URL.
        label = id if id else url
        self.url, self.md5, self.id = url, md5, label


def make_parametrize_kwargs(download_configs):
    """Build the keyword arguments for ``pytest.mark.parametrize``.

    Args:
        download_configs: Iterable of objects exposing ``url``, ``md5`` and
            ``id`` attributes. It is consumed exactly once, so a one-shot
            iterator is fine.

    Returns:
        A dict with the keys ``argnames``, ``argvalues`` (one ``(url, md5)``
        tuple per configuration) and ``ids`` (one identifier per
        configuration, in the same order).
    """
    # Collect both pieces in one pass over the (possibly one-shot) iterable.
    collected = [((cfg.url, cfg.md5), cfg.id) for cfg in download_configs]
    return dict(
        argnames="url, md5",
        argvalues=[values for values, _ in collected],
        ids=[label for _, label in collected],
    )


def places365():
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To add new download tests for a dataset, simply implement a similar function and add it to L130 and optionally to L135.

with log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365

datasets.Places365(root, split=split, small=small, download=True)

return [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s]


class Places365Tester(DownloadTester):
def collect_urls_and_md5s(self):
with self.log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(),)))
def test_url_is_accessible(url, md5):
    """Check that ``url`` is accessible, going through ``retry``.

    ``md5`` is supplied by the parametrization but is not used by this test.
    """

    def check():
        return assert_url_is_accessible(url)

    retry(check)

datasets.Places365(root, split=split, small=small, download=True)

return urls_and_md5s
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    """Check that the file at ``url`` downloads and matches ``md5``.

    The check goes through ``retry``. The parametrization above chains no
    download configurations, so no test cases are generated from it as written.
    """

    def check():
        return assert_file_downloads_correctly(url, md5)

    retry(check)