Skip to content

Commit

Permalink
Fix #987: implement client timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Jul 31, 2016
1 parent 1bfc02a commit 05890bd
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 36 deletions.
16 changes: 10 additions & 6 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .client_reqrep import ClientRequest, ClientResponse
from .client_ws import ClientWebSocketResponse
from .errors import WSServerHandshakeError
from .helpers import CookieJar
from .helpers import CookieJar, Timeout

__all__ = ('ClientSession', 'request', 'get', 'options', 'head',
'delete', 'post', 'put', 'patch', 'ws_connect')
Expand Down Expand Up @@ -106,7 +106,8 @@ def request(self, method, url, *,
expect100=False,
read_until_eof=True,
proxy=None,
proxy_auth=None):
proxy_auth=None,
timeout=5*60):
"""Perform HTTP request."""

return _RequestContextManager(
Expand All @@ -127,7 +128,8 @@ def request(self, method, url, *,
expect100=expect100,
read_until_eof=read_until_eof,
proxy=proxy,
proxy_auth=proxy_auth,))
proxy_auth=proxy_auth,
timeout=timeout))

@asyncio.coroutine
def _request(self, method, url, *,
Expand All @@ -145,7 +147,8 @@ def _request(self, method, url, *,
expect100=False,
read_until_eof=True,
proxy=None,
proxy_auth=None):
proxy_auth=None,
timeout=5*60):

if version is not None:
warnings.warn("HTTP version should be specified "
Expand Down Expand Up @@ -187,9 +190,10 @@ def _request(self, method, url, *,
auth=auth, version=version, compress=compress, chunked=chunked,
expect100=expect100,
loop=self._loop, response_class=self._response_class,
proxy=proxy, proxy_auth=proxy_auth,)
proxy=proxy, proxy_auth=proxy_auth, timeout=timeout)

conn = yield from self._connector.connect(req)
with Timeout(timeout, loop=self._loop):
conn = yield from self._connector.connect(req)
try:
resp = req.send(conn.writer, conn.reader)
try:
Expand Down
21 changes: 14 additions & 7 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import aiohttp

from . import hdrs, helpers, streams
from .helpers import Timeout
from .log import client_logger
from .multipart import MultipartWriter
from .protocol import HttpMessage
Expand Down Expand Up @@ -68,7 +69,8 @@ def __init__(self, method, url, *,
version=aiohttp.HttpVersion11, compress=None,
chunked=None, expect100=False,
loop=None, response_class=None,
proxy=None, proxy_auth=None):
proxy=None, proxy_auth=None,
timeout=5*60):

if loop is None:
loop = asyncio.get_event_loop()
Expand All @@ -80,6 +82,7 @@ def __init__(self, method, url, *,
self.compress = compress
self.loop = loop
self.response_class = response_class or ClientResponse
self._timeout = timeout

if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
Expand Down Expand Up @@ -502,7 +505,8 @@ def send(self, writer, reader):

self.response = self.response_class(
self.method, self.url, self.host,
writer=self._writer, continue100=self._continue)
writer=self._writer, continue100=self._continue,
timeout=self._timeout)
self.response._post_init(self.loop)
return self.response

Expand Down Expand Up @@ -546,7 +550,8 @@ class ClientResponse:
_loop = None
_closed = True # to allow __del__ for non-initialized properly response

def __init__(self, method, url, host='', *, writer=None, continue100=None):
def __init__(self, method, url, host='', *, writer=None, continue100=None,
timeout=5*60):
super().__init__()

self.method = method
Expand All @@ -558,6 +563,7 @@ def __init__(self, method, url, host='', *, writer=None, continue100=None):
self._closed = False
self._should_close = True # override by message.should_close later
self._history = ()
self._timeout = timeout

def _post_init(self, loop):
self._loop = loop
Expand Down Expand Up @@ -609,7 +615,7 @@ def _setup_connection(self, connection):
self._reader = connection.reader
self._connection = connection
self.content = self.flow_control_class(
connection.reader, loop=connection.loop)
connection.reader, loop=connection.loop, timeout=self._timeout)

def _need_parse_response_body(self):
return (self.method.lower() != 'head' and
Expand All @@ -624,7 +630,8 @@ def start(self, connection, read_until_eof=False):
httpstream = self._reader.set_parser(self._response_parser)

# read response
message = yield from httpstream.read()
with Timeout(self._timeout, loop=self._loop):
message = yield from httpstream.read()
if message.code != 100:
break

Expand All @@ -643,11 +650,11 @@ def start(self, connection, read_until_eof=False):
self.raw_headers = tuple(message.raw_headers)

# payload
response_with_body = self._need_parse_response_body()
rwb = self._need_parse_response_body()
self._reader.set_parser(
aiohttp.HttpPayloadParser(message,
readall=read_until_eof,
response_with_body=response_with_body),
response_with_body=rwb),
self.content)

# cookies
Expand Down
53 changes: 35 additions & 18 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class StreamReader(asyncio.StreamReader, AsyncStreamReaderMixin):

total_bytes = 0

def __init__(self, limit=DEFAULT_LIMIT, loop=None):
def __init__(self, limit=DEFAULT_LIMIT, timeout=None, loop=None):
self._limit = limit
if loop is None:
loop = asyncio.get_event_loop()
Expand All @@ -93,8 +93,10 @@ def __init__(self, limit=DEFAULT_LIMIT, loop=None):
self._buffer_offset = 0
self._eof = False
self._waiter = None
self._canceller = None
self._eof_waiter = None
self._exception = None
self._timeout = timeout

def __repr__(self):
info = ['StreamReader']
Expand Down Expand Up @@ -122,6 +124,11 @@ def set_exception(self, exc):
if not waiter.cancelled():
waiter.set_exception(exc)

canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()

def feed_eof(self):
self._eof = True

Expand All @@ -131,6 +138,11 @@ def feed_eof(self):
if not waiter.cancelled():
waiter.set_result(True)

canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()

waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
Expand Down Expand Up @@ -185,15 +197,32 @@ def feed_data(self, data):
if not waiter.cancelled():
waiter.set_result(False)

def _create_waiter(self, func_name):
canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()

@asyncio.coroutine
def _wait(self, func_name):
# StreamReader uses a future to link the protocol feed_data() method
# to a read coroutine. Running two read coroutines at the same time
# would have an unexpected behaviour. It would not possible to know
# which coroutine would get the next data.
if self._waiter is not None:
raise RuntimeError('%s() called while another coroutine is '
'already waiting for incoming data' % func_name)
return helpers.create_future(self._loop)
waiter = self._waiter = helpers.create_future(self._loop)
if self._timeout:
self._canceller = self._loop.call_later(self._timeout,
self.set_exception,
asyncio.TimeoutError())
try:
yield from waiter
finally:
self._waiter = None
if self._canceller is not None:
self._canceller.cancel()
self._canceller = None

@asyncio.coroutine
def readline(self):
Expand Down Expand Up @@ -222,11 +251,7 @@ def readline(self):
break

if not_enough:
self._waiter = self._create_waiter('readline')
try:
yield from self._waiter
finally:
self._waiter = None
yield from self._wait('readline')

return b''.join(line)

Expand Down Expand Up @@ -265,11 +290,7 @@ def read(self, n=-1):
return b''.join(blocks)

if not self._buffer and not self._eof:
self._waiter = self._create_waiter('read')
try:
yield from self._waiter
finally:
self._waiter = None
yield from self._wait('read')

return self._read_nowait(n)

Expand All @@ -279,11 +300,7 @@ def readany(self):
raise self._exception

if not self._buffer and not self._eof:
self._waiter = self._create_waiter('readany')
try:
yield from self._waiter
finally:
self._waiter = None
yield from self._wait('readany')

return self._read_nowait()

Expand Down
34 changes: 34 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,37 @@ def handler(request):
resp = yield from client.delete('/')
assert resp.status == 204
yield from resp.release()


@pytest.mark.run_loop
def test_timeout_on_reading_headers(create_app_and_client, loop):

@asyncio.coroutine
def handler(request):
resp = web.StreamResponse()
yield from asyncio.sleep(0.1, loop=loop)
yield from resp.prepare(request)
return resp

app, client = yield from create_app_and_client()
app.router.add_route('GET', '/', handler)
with pytest.raises(asyncio.TimeoutError):
yield from client.get('/', timeout=0.01)


@pytest.mark.run_loop
def test_timeout_on_reading_data(create_app_and_client, loop):

@asyncio.coroutine
def handler(request):
resp = web.StreamResponse()
yield from resp.prepare(request)
yield from asyncio.sleep(0.1, loop=loop)
return resp

app, client = yield from create_app_and_client()
app.router.add_route('GET', '/', handler)
resp = yield from client.get('/', timeout=0.05)

with pytest.raises(asyncio.TimeoutError):
yield from resp.read()
4 changes: 2 additions & 2 deletions tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,11 @@ def side_effect(*args, **kwargs):

def test_override_flow_control(self):
class MyResponse(ClientResponse):
flow_control_class = aiohttp.FlowControlDataQueue
flow_control_class = aiohttp.StreamReader
response = MyResponse('get', 'http://my-cl-resp.org')
response._post_init(self.loop)
response._setup_connection(self.connection)
self.assertIsInstance(response.content, aiohttp.FlowControlDataQueue)
self.assertIsInstance(response.content, aiohttp.StreamReader)
response.close()

@mock.patch('aiohttp.client_reqrep.chardet')
Expand Down
2 changes: 2 additions & 0 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,9 @@ def create_connection(req):
assert e.strerror == err.strerror


@pytest.mark.run_loop
def test_request_ctx_manager_props(loop):
yield from asyncio.sleep(0, loop=loop) # to make it a task
with aiohttp.ClientSession(loop=loop) as client:
ctx_mgr = client.get('http://example.com')

Expand Down
Loading

0 comments on commit 05890bd

Please sign in to comment.