diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index 1caec0b47..ce78e63a8 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -239,18 +239,38 @@ def __init__( # credentials.refresh). self._auth_request = auth_request - def request(self, method, url, data=None, headers=None, timeout=None, **kwargs): + def request( + self, + method, + url, + data=None, + headers=None, + max_allowed_time=None, + timeout=None, + **kwargs + ): """Implementation of Requests' request. Args: - timeout (Optional[Union[float, Tuple[float, float]]]): The number - of seconds to wait before raising a ``Timeout`` exception. If - multiple requests are made under the hood, ``timeout`` is - interpreted as the approximate total time of **all** requests. - - If passed as a tuple ``(connect_timeout, read_timeout)``, the - smaller of the values is taken as the total allowed time across - all requests. + timeout (Optional[Union[float, Tuple[float, float]]]): + The amount of time in seconds to wait for the server response + with each individual request. + + Can also be passed as a tuple (connect_timeout, read_timeout). + See :meth:`requests.Session.request` documentation for details. + + max_allowed_time (Optional[float]): + If the method runs longer than this, a ``Timeout`` exception is + automatically raised. Unlike the ``timeout` parameter, this + value applies to the total method execution time, even if + multiple requests are made under the hood. + + Mind that it is not guaranteed that the timeout error is raised + at ``max_allowed_time`. It might take longer, for example, if + an underlying request takes a lot of time, but the request + itself does not timeout, e.g. if a large file is being + transmitted. The timout error will be raised after such + request completes. """ # pylint: disable=arguments-differ # Requests has a ton of arguments to request, but only two @@ -273,11 +293,13 @@ def request(self, method, url, data=None, headers=None, timeout=None, **kwargs): else functools.partial(self._auth_request, timeout=timeout) ) - with TimeoutGuard(timeout) as guard: + remaining_time = max_allowed_time + + with TimeoutGuard(remaining_time) as guard: self.credentials.before_request(auth_request, method, url, request_headers) - timeout = guard.remaining_timeout + remaining_time = guard.remaining_timeout - with TimeoutGuard(timeout) as guard: + with TimeoutGuard(remaining_time) as guard: response = super(AuthorizedSession, self).request( method, url, @@ -286,7 +308,7 @@ def request(self, method, url, data=None, headers=None, timeout=None, **kwargs): timeout=timeout, **kwargs ) - timeout = guard.remaining_timeout + remaining_time = guard.remaining_timeout # If the response indicated that the credentials needed to be # refreshed, then refresh the credentials and re-attempt the @@ -305,14 +327,6 @@ def request(self, method, url, data=None, headers=None, timeout=None, **kwargs): self._max_refresh_attempts, ) - if self._refresh_timeout is not None: - if timeout is None: - timeout = self._refresh_timeout - elif isinstance(timeout, numbers.Number): - timeout = min(timeout, self._refresh_timeout) - else: - timeout = tuple(min(x, self._refresh_timeout) for x in timeout) - # Do not apply the timeout unconditionally in order to not override the # _auth_request's default timeout. auth_request = ( @@ -321,17 +335,18 @@ def request(self, method, url, data=None, headers=None, timeout=None, **kwargs): else functools.partial(self._auth_request, timeout=timeout) ) - with TimeoutGuard(timeout) as guard: + with TimeoutGuard(remaining_time) as guard: self.credentials.refresh(auth_request) - timeout = guard.remaining_timeout + remaining_time = guard.remaining_timeout # Recurse. Pass in the original headers, not our modified set, but - # do pass the adjusted timeout (i.e. the remaining time). + # do pass the adjusted max allowed time (i.e. the remaining total time). return self.request( method, url, data=data, headers=headers, + max_allowed_time=remaining_time, timeout=timeout, _credential_refresh_attempt=_credential_refresh_attempt + 1, **kwargs diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index 00269740f..8f73d4bd5 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -220,7 +220,25 @@ def test_request_refresh(self): assert adapter.requests[1].url == self.TEST_URL assert adapter.requests[1].headers["authorization"] == "token1" - def test_request_timeout(self, frozen_time): + def test_request_max_allowed_time_timeout_error(self, frozen_time): + tick_one_second = functools.partial(frozen_time.tick, delta=1.0) + + credentials = mock.Mock( + wraps=TimeTickCredentialsStub(time_tick=tick_one_second) + ) + adapter = TimeTickAdapterStub( + time_tick=tick_one_second, responses=[make_response(status=http_client.OK)] + ) + + authed_session = google.auth.transport.requests.AuthorizedSession(credentials) + authed_session.mount(self.TEST_URL, adapter) + + # Because a request takes a full mocked second, max_allowed_time shorter + # than that will cause a timeout error. + with pytest.raises(requests.exceptions.Timeout): + authed_session.request("GET", self.TEST_URL, max_allowed_time=0.9) + + def test_request_max_allowed_time_w_transport_timeout_no_error(self, frozen_time): tick_one_second = functools.partial(frozen_time.tick, delta=1.0) credentials = mock.Mock( @@ -237,12 +255,12 @@ def test_request_timeout(self, frozen_time): authed_session = google.auth.transport.requests.AuthorizedSession(credentials) authed_session.mount(self.TEST_URL, adapter) - # Because at least two requests have to be made, and each takes one - # second, the total timeout specified will be exceeded. - with pytest.raises(requests.exceptions.Timeout): - authed_session.request("GET", self.TEST_URL, timeout=1.9) + # A short configured transport timeout does not affect max_allowed_time. + # The latter is not adjusted to it and is only concerned with the actual + # execution time. The call below should thus not raise a timeout error. + authed_session.request("GET", self.TEST_URL, timeout=0.5, max_allowed_time=3.1) - def test_request_timeout_w_refresh_timeout(self, frozen_time): + def test_request_max_allowed_time_w_refresh_timeout_no_error(self, frozen_time): tick_one_second = functools.partial(frozen_time.tick, delta=1.0) credentials = mock.Mock( @@ -257,15 +275,17 @@ def test_request_timeout_w_refresh_timeout(self, frozen_time): ) authed_session = google.auth.transport.requests.AuthorizedSession( - credentials, refresh_timeout=1.9 + credentials, refresh_timeout=1.1 ) authed_session.mount(self.TEST_URL, adapter) - # The timeout is long, but the short refresh timeout will prevail. - with pytest.raises(requests.exceptions.Timeout): - authed_session.request("GET", self.TEST_URL, timeout=60) + # A short configured refresh timeout does not affect max_allowed_time. + # The latter is not adjusted to it and is only concerned with the actual + # execution time. The call below should thus not raise a timeout error + # (and `timeout` does not come into play either, as it's very long). + authed_session.request("GET", self.TEST_URL, timeout=60, max_allowed_time=3.1) - def test_request_timeout_w_refresh_timeout_and_tuple_timeout(self, frozen_time): + def test_request_timeout_w_refresh_timeout_timeout_error(self, frozen_time): tick_one_second = functools.partial(frozen_time.tick, delta=1.0) credentials = mock.Mock( @@ -284,7 +304,10 @@ def test_request_timeout_w_refresh_timeout_and_tuple_timeout(self, frozen_time): ) authed_session.mount(self.TEST_URL, adapter) - # The shortest timeout will prevail and cause a Timeout error, despite - # other timeouts being quite long. + # An UNAUTHORIZED response triggers a refresh (an extra request), thus + # the final request that otherwise succeeds results in a timeout error + # (all three requests together last 3 mocked seconds). with pytest.raises(requests.exceptions.Timeout): - authed_session.request("GET", self.TEST_URL, timeout=(100, 2.9)) + authed_session.request( + "GET", self.TEST_URL, timeout=60, max_allowed_time=2.9 + )