Skip to content

Commit

Permalink
feat: add timeout to AuthorisedSession.request()
Browse files Browse the repository at this point in the history
  • Loading branch information
plamut committed Dec 6, 2019
1 parent ab3dc1e commit 94276e8
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 12 deletions.
83 changes: 72 additions & 11 deletions google/auth/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import functools
import logging
import time

try:
import requests
Expand Down Expand Up @@ -64,6 +65,33 @@ def data(self):
return self._response.content


class TimeoutGuard(object):
"""A context manager raising an error if the suite execution took too long.
"""

def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout):
self._timeout = timeout
self.remaining_timeout = timeout
self._timeout_error_type = timeout_error_type

def __enter__(self):
self._start = time.time()
return self

def __exit__(self, exc_type, exc_value, traceback):
if exc_value:
return # let the error bubble up automatically

if self._timeout is None:
return # nothing to do, the timeout was not specified

elapsed = time.time() - self._start
self.remaining_timeout = self._timeout - elapsed

if self.remaining_timeout <= 0:
raise self._timeout_error_type()


class Request(transport.Request):
"""Requests request adapter.
Expand Down Expand Up @@ -193,8 +221,12 @@ def __init__(
# credentials.refresh).
self._auth_request = auth_request

def request(self, method, url, data=None, headers=None, **kwargs):
"""Implementation of Requests' request."""
def request(self, method, url, data=None, headers=None, timeout=None, **kwargs):
"""Implementation of Requests' request.
The ``timeout`` argument is interpreted as the approximate total time
of **all** requests that are made under the hood.
"""
# pylint: disable=arguments-differ
# Requests has a ton of arguments to request, but only two
# (method, url) are required. We pass through all of the other
Expand All @@ -208,13 +240,28 @@ def request(self, method, url, data=None, headers=None, **kwargs):
# and we want to pass the original headers if we recurse.
request_headers = headers.copy() if headers is not None else {}

self.credentials.before_request(
self._auth_request, method, url, request_headers
# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
auth_request = (
self._auth_request
if timeout is None
else functools.partial(self._auth_request, timeout=timeout)
)

response = super(AuthorizedSession, self).request(
method, url, data=data, headers=request_headers, **kwargs
)
with TimeoutGuard(timeout) as guard:
self.credentials.before_request(auth_request, method, url, request_headers)
timeout = guard.remaining_timeout

with TimeoutGuard(timeout) as guard:
response = super(AuthorizedSession, self).request(
method,
url,
data=data,
headers=request_headers,
timeout=timeout,
**kwargs
)
timeout = guard.remaining_timeout

# If the response indicated that the credentials needed to be
# refreshed, then refresh the credentials and re-attempt the
Expand All @@ -233,17 +280,31 @@ def request(self, method, url, data=None, headers=None, **kwargs):
self._max_refresh_attempts,
)

auth_request_with_timeout = functools.partial(
self._auth_request, timeout=self._refresh_timeout
if timeout is not None and self._refresh_timeout is not None:
timeout = min(timeout, self._refresh_timeout)
elif self._refresh_timeout is not None:
timeout = self._refresh_timeout

# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
auth_request = (
self._auth_request
if timeout is None
else functools.partial(self._auth_request, timeout=timeout)
)
self.credentials.refresh(auth_request_with_timeout)

# Recurse. Pass in the original headers, not our modified set.
with TimeoutGuard(timeout) as guard:
self.credentials.refresh(auth_request)
timeout = guard.remaining_timeout

# Recurse. Pass in the original headers, not our modified set, but
# do pass the adjusted timeout (i.e. the remaining time).
return self.request(
method,
url,
data=data,
headers=headers,
timeout=timeout,
_credential_refresh_attempt=_credential_refresh_attempt + 1,
**kwargs
)
Expand Down
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

TEST_DEPENDENCIES = [
"flask",
"freezegun",
"mock",
"oauth2client",
"pytest",
Expand Down
115 changes: 114 additions & 1 deletion tests/transport/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import functools

import freezegun
import mock
import pytest
import requests
import requests.adapters
from six.moves import http_client
Expand All @@ -22,6 +27,12 @@
from tests.transport import compliance


@pytest.fixture
def frozen_time():
with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen:
yield frozen


class TestRequestResponse(compliance.RequestResponseTests):
def make_request(self):
return google.auth.transport.requests.Request()
Expand All @@ -34,6 +45,41 @@ def test_timeout(self):
assert http.request.call_args[1]["timeout"] == 5


class TestTimeoutGuard(object):
def make_guard(self, *args, **kwargs):
return google.auth.transport.requests.TimeoutGuard(*args, **kwargs)

def test_tracks_elapsed_time(self, frozen_time):
with self.make_guard(timeout=10) as guard:
frozen_time.tick(delta=3.8)
assert guard.remaining_timeout == 6.2

def test_noop_if_no_timeout(self, frozen_time):
with self.make_guard(timeout=None) as guard:
frozen_time.tick(delta=datetime.timedelta(days=3650))
# NOTE: no timeout error raised, despite years have passed
assert guard.remaining_timeout is None

def test_error_on_timeout(self, frozen_time):
with pytest.raises(requests.exceptions.Timeout):
with self.make_guard(timeout=10) as guard:
frozen_time.tick(delta=10.001)
assert guard.remaining_timeout == pytest.approx(-0.001)

def test_custom_timeout_error_type(self, frozen_time):
class FooError(Exception):
pass

with pytest.raises(FooError):
with self.make_guard(timeout=1, timeout_error_type=FooError):
frozen_time.tick(2)

def test_lets_errors_bubble_up(self, frozen_time):
with pytest.raises(IndexError):
with self.make_guard(timeout=1):
[1, 2, 3][3]


class CredentialsStub(google.auth.credentials.Credentials):
def __init__(self, token="token"):
super(CredentialsStub, self).__init__()
Expand All @@ -49,6 +95,18 @@ def refresh(self, request):
self.token += "1"


class TimeTickCredentialsStub(CredentialsStub):
"""Credentials that spend some (mocked) time when refreshing a token."""

def __init__(self, time_tick, token="token"):
self._time_tick = time_tick
super(TimeTickCredentialsStub, self).__init__(token=token)

def refresh(self, request):
self._time_tick()
super(TimeTickCredentialsStub, self).refresh(requests)


class AdapterStub(requests.adapters.BaseAdapter):
def __init__(self, responses, headers=None):
super(AdapterStub, self).__init__()
Expand All @@ -69,6 +127,18 @@ def close(self): # pragma: NO COVER
return


class TimeTickAdapterStub(AdapterStub):
"""Adapter that spends some (mocked) time when making a request."""

def __init__(self, time_tick, responses, headers=None):
self._time_tick = time_tick
super(TimeTickAdapterStub, self).__init__(responses, headers=headers)

def send(self, request, **kwargs):
self._time_tick()
return super(TimeTickAdapterStub, self).send(request, **kwargs)


def make_response(status=http_client.OK, data=None):
response = requests.Response()
response.status_code = status
Expand Down Expand Up @@ -121,7 +191,9 @@ def test_request_refresh(self):
[make_response(status=http_client.UNAUTHORIZED), final_response]
)

authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
authed_session = google.auth.transport.requests.AuthorizedSession(
credentials, refresh_timeout=60
)
authed_session.mount(self.TEST_URL, adapter)

result = authed_session.request("GET", self.TEST_URL)
Expand All @@ -136,3 +208,44 @@ def test_request_refresh(self):

assert adapter.requests[1].url == self.TEST_URL
assert adapter.requests[1].headers["authorization"] == "token1"

def test_request_timout(self, frozen_time):
tick_one_second = functools.partial(frozen_time.tick, delta=1.0)

credentials = mock.Mock(
wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
)
adapter = TimeTickAdapterStub(
time_tick=tick_one_second,
responses=[
make_response(status=http_client.UNAUTHORIZED),
make_response(status=http_client.OK),
],
)

authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
authed_session.mount(self.TEST_URL, adapter)

# Because at least two requests have to be made, and each takes one
# second, the total timeout specified will be exceeded.
with pytest.raises(requests.exceptions.Timeout):
authed_session.request("GET", self.TEST_URL, timeout=1.9)

def test_request_timeout_w_refresh_timeout(self, frozen_time):
credentials = mock.Mock(wraps=CredentialsStub())
adapter = TimeTickAdapterStub(
time_tick=functools.partial(frozen_time.tick, delta=1.0), # one second
responses=[
make_response(status=http_client.UNAUTHORIZED),
make_response(status=http_client.OK),
],
)

authed_session = google.auth.transport.requests.AuthorizedSession(
credentials, refresh_timeout=0.9
)
authed_session.mount(self.TEST_URL, adapter)

# The timeout is long, but the short refresh timeout will prevail.
with pytest.raises(requests.exceptions.Timeout):
authed_session.request("GET", self.TEST_URL, timeout=60)

0 comments on commit 94276e8

Please sign in to comment.