Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

limit requests per time in download tests #2699

Merged
merged 1 commit into from
Sep 24, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import contextlib
import itertools
import time
import unittest
import unittest.mock
from datetime import datetime
from os import path
from time import sleep
from urllib.parse import urlparse
from urllib.request import urlopen, Request

from torchvision import datasets
Expand All @@ -13,6 +15,34 @@
from fakedata_generation import places365_root


def limit_requests_per_time(min_secs_between_requests=2.0):
last_requests = {}

def outer_wrapper(fn):
def inner_wrapper(request, *args, **kwargs):
url = request.full_url if isinstance(request, Request) else request

netloc = urlparse(url).netloc
last_request = last_requests.get(netloc)
if last_request is not None:
elapsed_secs = (datetime.now() - last_request).total_seconds()
delta = min_secs_between_requests - elapsed_secs
if delta > 0:
time.sleep(delta)

response = fn(request, *args, **kwargs)
last_requests[netloc] = datetime.now()

return response

return inner_wrapper

return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)


class DownloadTester(unittest.TestCase):
@staticmethod
@contextlib.contextmanager
Expand All @@ -37,7 +67,7 @@ def retry(fn, times=1, wait=5.0):
return fn()
except AssertionError as error:
msgs.append(str(error))
sleep(wait)
time.sleep(wait)
else:
raise AssertionError(
"\n".join(
Expand Down Expand Up @@ -80,7 +110,6 @@ def test_download(self):
for url, md5 in self.collect_urls_and_md5s():
with self.subTest(url=url, md5=md5):
self.retry(lambda: assert_fn(url, md5))
sleep(2.0)

def collect_urls_and_md5s(self):
raise NotImplementedError
Expand Down