diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index d1971cd88..8686ff260 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -18,6 +18,7 @@ import functools import logging +import time try: import requests @@ -64,6 +65,33 @@ def data(self): return self._response.content +class TimeoutGuard(object): + """A context manager raising an error if the suite execution took too long. + """ + + def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout): + self._timeout = timeout + self.remaining_timeout = timeout + self._timeout_error_type = timeout_error_type + + def __enter__(self): + self._start = time.time() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_value: + return # let the error bubble up automatically + + if self._timeout is None: + return # nothing to do, the timeout was not specified + + elapsed = time.time() - self._start + self.remaining_timeout = self._timeout - elapsed + + if self.remaining_timeout <= 0: + raise self._timeout_error_type() + + class Request(transport.Request): """Requests request adapter. @@ -193,8 +221,12 @@ def __init__( # credentials.refresh). self._auth_request = auth_request - def request(self, method, url, data=None, headers=None, **kwargs): - """Implementation of Requests' request.""" + def request(self, method, url, data=None, headers=None, timeout=None, **kwargs): + """Implementation of Requests' request. + + The ``timeout`` argument is interpreted as the approximate total time + of **all** requests that are made under the hood. + """ # pylint: disable=arguments-differ # Requests has a ton of arguments to request, but only two # (method, url) are required. We pass through all of the other @@ -208,13 +240,28 @@ def request(self, method, url, data=None, headers=None, **kwargs): # and we want to pass the original headers if we recurse. request_headers = headers.copy() if headers is not None else {} - self.credentials.before_request( - self._auth_request, method, url, request_headers + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) ) - response = super(AuthorizedSession, self).request( - method, url, data=data, headers=request_headers, **kwargs - ) + with TimeoutGuard(timeout) as guard: + self.credentials.before_request(auth_request, method, url, request_headers) + timeout = guard.remaining_timeout + + with TimeoutGuard(timeout) as guard: + response = super(AuthorizedSession, self).request( + method, + url, + data=data, + headers=request_headers, + timeout=timeout, + **kwargs + ) + timeout = guard.remaining_timeout # If the response indicated that the credentials needed to be # refreshed, then refresh the credentials and re-attempt the @@ -233,17 +280,31 @@ def request(self, method, url, data=None, headers=None, **kwargs): self._max_refresh_attempts, ) - auth_request_with_timeout = functools.partial( - self._auth_request, timeout=self._refresh_timeout + if timeout is not None and self._refresh_timeout is not None: + timeout = min(timeout, self._refresh_timeout) + elif self._refresh_timeout is not None: + timeout = self._refresh_timeout + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) ) - self.credentials.refresh(auth_request_with_timeout) - # Recurse. Pass in the original headers, not our modified set. + with TimeoutGuard(timeout) as guard: + self.credentials.refresh(auth_request) + timeout = guard.remaining_timeout + + # Recurse. Pass in the original headers, not our modified set, but + # do pass the adjusted timeout (i.e. the remaining time). return self.request( method, url, data=data, headers=headers, + timeout=timeout, _credential_refresh_attempt=_credential_refresh_attempt + 1, **kwargs ) diff --git a/noxfile.py b/noxfile.py index aaf1bc57d..e170ee51d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -16,6 +16,7 @@ TEST_DEPENDENCIES = [ "flask", + "freezegun", "mock", "oauth2client", "pytest", diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index 0e165ac54..252e4a67e 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime +import functools + +import freezegun import mock +import pytest import requests import requests.adapters from six.moves import http_client @@ -22,6 +27,12 @@ from tests.transport import compliance +@pytest.fixture +def frozen_time(): + with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen: + yield frozen + + class TestRequestResponse(compliance.RequestResponseTests): def make_request(self): return google.auth.transport.requests.Request() @@ -34,6 +45,41 @@ def test_timeout(self): assert http.request.call_args[1]["timeout"] == 5 +class TestTimeoutGuard(object): + def make_guard(self, *args, **kwargs): + return google.auth.transport.requests.TimeoutGuard(*args, **kwargs) + + def test_tracks_elapsed_time(self, frozen_time): + with self.make_guard(timeout=10) as guard: + frozen_time.tick(delta=3.8) + assert guard.remaining_timeout == 6.2 + + def test_noop_if_no_timeout(self, frozen_time): + with self.make_guard(timeout=None) as guard: + frozen_time.tick(delta=datetime.timedelta(days=3650)) + # NOTE: no timeout error raised, despite years have passed + assert guard.remaining_timeout is None + + def test_error_on_timeout(self, frozen_time): + with pytest.raises(requests.exceptions.Timeout): + with self.make_guard(timeout=10) as guard: + frozen_time.tick(delta=10.001) + assert guard.remaining_timeout == pytest.approx(-0.001) + + def test_custom_timeout_error_type(self, frozen_time): + class FooError(Exception): + pass + + with pytest.raises(FooError): + with self.make_guard(timeout=1, timeout_error_type=FooError): + frozen_time.tick(2) + + def test_lets_errors_bubble_up(self, frozen_time): + with pytest.raises(IndexError): + with self.make_guard(timeout=1): + [1, 2, 3][3] + + class CredentialsStub(google.auth.credentials.Credentials): def __init__(self, token="token"): super(CredentialsStub, self).__init__() @@ -49,6 +95,18 @@ def refresh(self, request): self.token += "1" +class TimeTickCredentialsStub(CredentialsStub): + """Credentials that spend some (mocked) time when refreshing a token.""" + + def __init__(self, time_tick, token="token"): + self._time_tick = time_tick + super(TimeTickCredentialsStub, self).__init__(token=token) + + def refresh(self, request): + self._time_tick() + super(TimeTickCredentialsStub, self).refresh(requests) + + class AdapterStub(requests.adapters.BaseAdapter): def __init__(self, responses, headers=None): super(AdapterStub, self).__init__() @@ -69,6 +127,18 @@ def close(self): # pragma: NO COVER return +class TimeTickAdapterStub(AdapterStub): + """Adapter that spends some (mocked) time when making a request.""" + + def __init__(self, time_tick, responses, headers=None): + self._time_tick = time_tick + super(TimeTickAdapterStub, self).__init__(responses, headers=headers) + + def send(self, request, **kwargs): + self._time_tick() + return super(TimeTickAdapterStub, self).send(request, **kwargs) + + def make_response(status=http_client.OK, data=None): response = requests.Response() response.status_code = status @@ -121,7 +191,9 @@ def test_request_refresh(self): [make_response(status=http_client.UNAUTHORIZED), final_response] ) - authed_session = google.auth.transport.requests.AuthorizedSession(credentials) + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=60 + ) authed_session.mount(self.TEST_URL, adapter) result = authed_session.request("GET", self.TEST_URL) @@ -136,3 +208,44 @@ def test_request_refresh(self): assert adapter.requests[1].url == self.TEST_URL assert adapter.requests[1].headers["authorization"] == "token1" + + def test_request_timout(self, frozen_time): + tick_one_second = functools.partial(frozen_time.tick, delta=1.0) + + credentials = mock.Mock( + wraps=TimeTickCredentialsStub(time_tick=tick_one_second) + ) + adapter = TimeTickAdapterStub( + time_tick=tick_one_second, + responses=[ + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.OK), + ], + ) + + authed_session = google.auth.transport.requests.AuthorizedSession(credentials) + authed_session.mount(self.TEST_URL, adapter) + + # Because at least two requests have to be made, and each takes one + # second, the total timeout specified will be exceeded. + with pytest.raises(requests.exceptions.Timeout): + authed_session.request("GET", self.TEST_URL, timeout=1.9) + + def test_request_timeout_w_refresh_timeout(self, frozen_time): + credentials = mock.Mock(wraps=CredentialsStub()) + adapter = TimeTickAdapterStub( + time_tick=functools.partial(frozen_time.tick, delta=1.0), # one second + responses=[ + make_response(status=http_client.UNAUTHORIZED), + make_response(status=http_client.OK), + ], + ) + + authed_session = google.auth.transport.requests.AuthorizedSession( + credentials, refresh_timeout=0.9 + ) + authed_session.mount(self.TEST_URL, adapter) + + # The timeout is long, but the short refresh timeout will prevail. + with pytest.raises(requests.exceptions.Timeout): + authed_session.request("GET", self.TEST_URL, timeout=60)