Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mock redirection logic for tests #4197

Merged
merged 7 commits into from
Jul 22, 2021
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions test/test_internet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,61 @@
import pytest
import warnings
from urllib.error import URLError
from urllib.request import Request

import torchvision.datasets.utils as utils
from common_utils import get_tmp_dir


class TestDatasetUtils:
@pytest.fixture
pmeier marked this conversation as resolved.
Show resolved Hide resolved
def patch_url_redirection(mocker):
class Response:
pmeier marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, url):
self.url = url

def factory(*urls):
class PatchedOpener:
def __init__(self, request_or_url, *args, **kwargs):
self._request_or_url = request_or_url

def __enter__(self):
url = (
self._request_or_url.full_url
if isinstance(self._request_or_url.full_url, Request)
else self._request_or_url.full_url
)
pmeier marked this conversation as resolved.
Show resolved Hide resolved

if url == urls[-1]:
redirect_url = url
else:
redirect_url = urls[urls.index(url) + 1]
pmeier marked this conversation as resolved.
Show resolved Hide resolved

return Response(redirect_url)

def __exit__(self, exc_type, exc_val, exc_tb):
pass

mocker.patch("torchvision.datasets.utils.urllib.request.urlopen", new=PatchedOpener)

def test_get_redirect_url(self):
return factory


class TestDatasetUtils:
def test_get_redirect_url(self, patch_url_redirection):
url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
expected_redirected_url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"

patch_url_redirection(url, expected_redirected_url)

actual = utils._get_redirect_url(url)
assert actual == expected
assert actual == expected_redirected_url

def test_get_redirect_url_max_hops_exceeded(self):
def test_get_redirect_url_max_hops_exceeded(self, patch_url_redirection):
url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
redirected_url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"

patch_url_redirection(url, redirected_url)

with pytest.raises(RecursionError):
utils._get_redirect_url(url, max_hops=0)

Expand Down