From 00c460afbe7fd0f931fc702f46544405a792b2f9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 17 Mar 2021 15:24:48 +0100 Subject: [PATCH] fix redirection in download tests (#3568) Co-authored-by: Francisco Massa --- test/test_datasets_download.py | 38 ++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index e81677baa0d..6ff3a33bcc9 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -14,7 +14,13 @@ import pytest from torchvision import datasets -from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT +from torchvision.datasets.utils import ( + download_url, + check_integrity, + download_file_from_google_drive, + _get_redirect_url, + USER_AGENT, +) from common_utils import get_tmp_dir from fakedata_generation import places365_root @@ -48,22 +54,28 @@ def inner_wrapper(request, *args, **kwargs): urlopen = limit_requests_per_time()(urlopen) -def resolve_redirects(max_redirects=3): +def resolve_redirects(max_hops=3): def outer_wrapper(fn): def inner_wrapper(request, *args, **kwargs): - url = initial_url = request.full_url if isinstance(request, Request) else request + initial_url = request.full_url if isinstance(request, Request) else request + url = _get_redirect_url(initial_url, max_hops=max_hops) - for _ in range(max_redirects + 1): - response = fn(request, *args, **kwargs) + if url == initial_url: + return fn(request, *args, **kwargs) - if response.url == url or response.url is None: - if url != initial_url: - warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.") - return response + warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.") - url = response.url - else: - raise RecursionError(f"Request to {initial_url} exceeded {max_redirects} redirects.") + if not isinstance(request, Request): + return fn(url, *args, **kwargs) + + request_attrs = { + attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable") + } + # the 'method' attribute does only exist if the request was created with it + if hasattr(request, "method"): + request_attrs["method"] = request.method + + return fn(Request(url, **request_attrs), *args, **kwargs) return inner_wrapper @@ -150,7 +162,7 @@ def assert_server_response_ok(): def assert_url_is_accessible(url, timeout=5.0): - request = Request(url, headers={"method": "HEAD", "User-Agent": USER_AGENT}) + request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD") with assert_server_response_ok(): urlopen(request, timeout=timeout)