Skip to content

Commit

Permalink
fix redirection in download tests (#3568)
Browse files Browse the repository at this point in the history
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
  • Loading branch information
pmeier and fmassa authored Mar 17, 2021
1 parent 8ee6339 commit 00c460a
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT
from torchvision.datasets.utils import (
download_url,
check_integrity,
download_file_from_google_drive,
_get_redirect_url,
USER_AGENT,
)

from common_utils import get_tmp_dir
from fakedata_generation import places365_root
Expand Down Expand Up @@ -48,22 +54,28 @@ def inner_wrapper(request, *args, **kwargs):
urlopen = limit_requests_per_time()(urlopen)


def resolve_redirects(max_redirects=3):
def resolve_redirects(max_hops=3):
def outer_wrapper(fn):
def inner_wrapper(request, *args, **kwargs):
url = initial_url = request.full_url if isinstance(request, Request) else request
initial_url = request.full_url if isinstance(request, Request) else request
url = _get_redirect_url(initial_url, max_hops=max_hops)

for _ in range(max_redirects + 1):
response = fn(request, *args, **kwargs)
if url == initial_url:
return fn(request, *args, **kwargs)

if response.url == url or response.url is None:
if url != initial_url:
warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")
return response
warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

url = response.url
else:
raise RecursionError(f"Request to {initial_url} exceeded {max_redirects} redirects.")
if not isinstance(request, Request):
return fn(url, *args, **kwargs)

request_attrs = {
attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
}
# the 'method' attribute does only exist if the request was created with it
if hasattr(request, "method"):
request_attrs["method"] = request.method

return fn(Request(url, **request_attrs), *args, **kwargs)

return inner_wrapper

Expand Down Expand Up @@ -150,7 +162,7 @@ def assert_server_response_ok():


def assert_url_is_accessible(url, timeout=5.0):
request = Request(url, headers={"method": "HEAD", "User-Agent": USER_AGENT})
request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
with assert_server_response_ok():
urlopen(request, timeout=timeout)

Expand Down

0 comments on commit 00c460a

Please sign in to comment.