diff --git a/CHANGES/2567.bugfix b/CHANGES/2567.bugfix new file mode 100644 index 00000000000..b6c235d87e9 --- /dev/null +++ b/CHANGES/2567.bugfix @@ -0,0 +1 @@ +Return client connection back to free pool on error in `connector.connect()`. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index c67ad78bf92..8fbc0a152d0 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -283,6 +283,15 @@ def _cleanup(self): self._cleanup_handle = helpers.weakref_handle( self, '_cleanup', timeout, self._loop) + def _drop_acquired_per_host(self, key, val): + acquired_per_host = self._acquired_per_host + if key not in acquired_per_host: + return + conns = acquired_per_host[key] + conns.remove(val) + if not conns: + del self._acquired_per_host[key] + def _cleanup_closed(self): """Double confirmation for transport close. Some broken ssl servers may leave socket open without proper close. @@ -354,7 +363,7 @@ def connect(self, req): if self._limit: # total calc available connections - available = self._limit - len(self._waiters) - len(self._acquired) + available = self._limit - len(self._acquired) # check limit per host if (self._limit_per_host and available > 0 and @@ -396,15 +405,16 @@ def connect(self, req): raise ClientConnectionError("Connector is closed.") except: # signal to waiter - for waiter in self._waiters[key]: - if not waiter.done(): - waiter.set_result(None) - break + if key in self._waiters: + for waiter in self._waiters[key]: + if not waiter.done(): + waiter.set_result(None) + break raise finally: if not self._closed: self._acquired.remove(placeholder) - self._acquired_per_host[key].remove(placeholder) + self._drop_acquired_per_host(key, placeholder) self._acquired.add(proto) self._acquired_per_host[key].add(proto) @@ -463,9 +473,7 @@ def _release_acquired(self, key, proto): try: self._acquired.remove(proto) - self._acquired_per_host[key].remove(proto) - if not self._acquired_per_host[key]: - del self._acquired_per_host[key] + self._drop_acquired_per_host(key, proto) except KeyError: # pragma: no cover # this may be result of undetermenistic order of objects # finalization due garbage collection. diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 2335de83065..66d1fe672b3 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -2288,3 +2288,63 @@ def close(self): assert resp.status == 200 finally: yield from client.close() + + +@asyncio.coroutine +def test_error_in_performing_request(loop, ssl_ctx, + test_client, test_server): + @asyncio.coroutine + def handler(request): + return web.Response() + + app = web.Application() + app.router.add_route('GET', '/', handler) + + server = yield from test_server(app, ssl=ssl_ctx) + + conn = aiohttp.TCPConnector(limit=1, loop=loop) + client = yield from test_client(server, connector=conn) + + with pytest.raises(aiohttp.ClientConnectionError): + yield from client.get('/') + + # second try should not hang + with pytest.raises(aiohttp.ClientConnectionError): + yield from client.get('/') + + +@asyncio.coroutine +def test_await_after_cancelling(loop, test_client): + @asyncio.coroutine + def handler(request): + return web.Response() + + app = web.Application() + app.router.add_route('GET', '/', handler) + + client = yield from test_client(app) + + fut1 = create_future(loop) + fut2 = create_future(loop) + + @asyncio.coroutine + def fetch1(): + resp = yield from client.get('/') + assert resp.status == 200 + fut1.set_result(None) + with pytest.raises(asyncio.CancelledError): + yield from fut2 + resp.release() + + @asyncio.coroutine + def fetch2(): + yield from fut1 + resp = yield from client.get('/') + assert resp.status == 200 + + @asyncio.coroutine + def canceller(): + yield from fut1 + fut2.cancel() + + yield from asyncio.gather(fetch1(), fetch2(), canceller(), loop=loop) diff --git a/tests/test_connector.py b/tests/test_connector.py index b85d72d5d8b..c2cfab833fb 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -343,6 +343,28 @@ def test_release_close(loop): assert proto.close.called +def test__drop_acquire_per_host1(loop): + conn = aiohttp.BaseConnector(loop=loop) + conn._drop_acquired_per_host(123, 456) + assert len(conn._acquired_per_host) == 0 + + +def test__drop_acquire_per_host2(loop): + conn = aiohttp.BaseConnector(loop=loop) + conn._acquired_per_host[123].add(456) + conn._drop_acquired_per_host(123, 456) + assert len(conn._acquired_per_host) == 0 + + +def test__drop_acquire_per_host3(loop): + conn = aiohttp.BaseConnector(loop=loop) + conn._acquired_per_host[123].add(456) + conn._acquired_per_host[123].add(789) + conn._drop_acquired_per_host(123, 456) + assert len(conn._acquired_per_host) == 1 + assert conn._acquired_per_host[123] == {789} + + @asyncio.coroutine def test_tcp_connector_certificate_error(loop): req = ClientRequest('GET', URL('https://127.0.0.1:443'), loop=loop)