Skip to content

Commit

Permalink
fix scheduled download tests (pytorch#2706)
Browse files Browse the repository at this point in the history
* fix triggers for scheduled workflow

* more fix

* add missing repository checkout

* try fix label in template

* rewrite test infrastructure

* trigger issue generation

* try fix issue template

* try remove quotes

* remove buggy label

* try fix title

* cleanup

* add more test details

* reenable issue creation

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
  • Loading branch information
2 people authored and vfdev-5 committed Dec 4, 2020
1 parent bffb8b8 commit 423296e
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 86 deletions.
6 changes: 4 additions & 2 deletions .github/failed_schedule_issue_template.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
---
title: Scheduled workflow {{ env.WORKFLOW }}/{{ env.JOB }} failed
labels: bug, module: datasets
title: Scheduled workflow failed
labels:
- bug
- "module: datasets"
---

Oh no, something went wrong in the scheduled workflow {{ env.WORKFLOW }}/{{ env.JOB }}.
Expand Down
16 changes: 10 additions & 6 deletions .github/workflows/tests-schedule.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ name: tests

on:
pull_request:
- "test/test_datasets_download.py"
- ".github/failed_schedule_issue_template.md"
- ".github/workflows/tests-schedule.yml"
paths:
- "test/test_datasets_download.py"
- ".github/failed_schedule_issue_template.md"
- ".github/workflows/tests-schedule.yml"

schedule:
- cron: "0 9 * * *"
Expand All @@ -22,20 +23,23 @@ jobs:
- name: Upgrade pip
run: python -m pip install --upgrade pip

- name: Checkout repository
uses: actions/checkout@v2

- name: Install PyTorch from the nightlies
run: |
pip install numpy
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- name: Install tests requirements
run: pip install pytest pytest-subtests
run: pip install pytest

- name: Run tests
run: pytest test/test_datasets_download.py
run: pytest --durations=20 -ra test/test_datasets_download.py

- uses: JasonEtco/create-an-issue@v2.4.0
name: Create issue if download tests failed
if: failure()
if: failure() && github.event_name == 'schedule'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
REPO: ${{ github.repository }}
Expand Down
162 changes: 84 additions & 78 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import contextlib
import itertools
import time
import unittest
import unittest.mock
from datetime import datetime
from os import path
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity

Expand Down Expand Up @@ -43,89 +44,94 @@ def inner_wrapper(request, *args, **kwargs):
urlopen = limit_requests_per_time()(urlopen)


class DownloadTester(unittest.TestCase):
@staticmethod
@contextlib.contextmanager
def log_download_attempts(patch=True):
urls_and_md5s = set()
with unittest.mock.patch(
"torchvision.datasets.utils.download_url", wraps=None if patch else download_url
) as mock:
try:
yield urls_and_md5s
finally:
for args, kwargs in mock.call_args_list:
url = args[0]
md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
urls_and_md5s.add((url, md5))

@staticmethod
def retry(fn, times=1, wait=5.0):
msgs = []
for _ in range(times + 1):
try:
return fn()
except AssertionError as error:
msgs.append(str(error))
time.sleep(wait)
else:
raise AssertionError(
"\n".join(
(
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
*(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
)
@contextlib.contextmanager
def log_download_attempts(patch=True):
urls_and_md5s = set()
with unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) as mock:
try:
yield urls_and_md5s
finally:
for args, kwargs in mock.call_args_list:
url = args[0]
md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
urls_and_md5s.add((url, md5))


def retry(fn, times=1, wait=5.0):
msgs = []
for _ in range(times + 1):
try:
return fn()
except AssertionError as error:
msgs.append(str(error))
time.sleep(wait)
else:
raise AssertionError(
"\n".join(
(
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
*(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
)
)

@staticmethod
def assert_response_ok(response, url=None, ok=200):
msg = f"The server returned status code {response.code}"
if url is not None:
msg += f"for the the URL {url}"
assert response.code == ok, msg

@staticmethod
def assert_is_downloadable(url):
request = Request(url, headers=dict(method="HEAD"))
response = urlopen(request)
DownloadTester.assert_response_ok(response, url)

@staticmethod
def assert_downloads_correctly(url, md5):
with get_tmp_dir() as root:
file = path.join(root, path.basename(url))
with urlopen(url) as response, open(file, "wb") as fh:
DownloadTester.assert_response_ok(response, url)
fh.write(response.read())

assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

def test_download(self):
assert_fn = (
lambda url, _: self.assert_is_downloadable(url)
if self.only_test_downloadability
else self.assert_downloads_correctly
)
for url, md5 in self.collect_urls_and_md5s():
with self.subTest(url=url, md5=md5):
self.retry(lambda: assert_fn(url, md5))

def collect_urls_and_md5s(self):
raise NotImplementedError

@property
def only_test_downloadability(self):
return True
def assert_server_response_ok(response, url=None):
msg = f"The server returned status code {response.code}"
if url is not None:
msg += f"for the the URL {url}"
assert 200 <= response.code < 300, msg


def assert_url_is_accessible(url):
request = Request(url, headers=dict(method="HEAD"))
response = urlopen(request)
assert_server_response_ok(response, url)


def assert_file_downloads_correctly(url, md5):
with get_tmp_dir() as root:
file = path.join(root, path.basename(url))
with urlopen(url) as response, open(file, "wb") as fh:
assert_server_response_ok(response, url)
fh.write(response.read())

assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"


class DownloadConfig:
def __init__(self, url, md5=None, id=None):
self.url = url
self.md5 = md5
self.id = id or url


def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
for config in download_configs:
argvalues.append((config.url, config.md5))
ids.append(config.id)

return dict(argnames="url, md5", argvalues=argvalues, ids=ids)


def places365():
with log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365

datasets.Places365(root, split=split, small=small, download=True)

return [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s]


class Places365Tester(DownloadTester):
def collect_urls_and_md5s(self):
with self.log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(),)))
def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url))

datasets.Places365(root, split=split, small=small, download=True)

return urls_and_md5s
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
retry(lambda: assert_file_downloads_correctly(url, md5))

0 comments on commit 423296e

Please sign in to comment.