From 3a6a2bb38a4ec7caf232ff90a22f2bc99109447d Mon Sep 17 00:00:00 2001 From: Justin Mayfield Date: Wed, 19 Oct 2016 02:25:11 -0600 Subject: [PATCH 1/5] API for connection limits of `Server` Add a `pause`/`resume` API to `Server` which removes/adds a server's listening sockets on its event loop selector. This facility allows DoS prevention from SYN flooding connection herds. Add `max_connections` kwarg to `loop.create_server` and `Server` which controls pause/resume behavior when not `None`. Notes: 1. Using Server.pause/resume and create_server(max_connections) are mutually exclusive. 2. The listen backlog and accept semantics are not taken into consideration. As a result the actual number of connections established will vary and should be considered platform and/or event loop dependant. --- asyncio/base_events.py | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 648b9b9b..6ae6004b 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -176,9 +176,15 @@ def _run_until_complete_cb(fut): class Server(events.AbstractServer): - def __init__(self, loop, sockets): + def __init__(self, loop, sockets, protocol_factory, ssl, backlog, *, + max_connections=None): self._loop = loop self.sockets = sockets + self._protocol_factory = protocol_factory + self._ssl = ssl + self._backlog = backlog + self._max_connections = max_connections + self._paused = False self._active_count = 0 self._waiters = [] @@ -188,14 +194,37 @@ def __repr__(self): def _attach(self): assert self.sockets is not None self._active_count += 1 + if self._max_connections is not None and \ + not self._paused and \ + self._active_count >= self._max_connections: + self.pause() def _detach(self): assert self._active_count > 0 self._active_count -= 1 if self._active_count == 0 and self.sockets is None: self._wakeup() + elif self._paused and self._max_connections is not None and \ + self._active_count < self._max_connections: + self.resume() + + def pause(self): + """Pause future calls to accept().""" + assert not self._paused + self._paused = True + for sock in self.sockets: + self._loop.remove_reader(sock.fileno()) + + def resume(self): + """Resume use of accept() on listening socket(s).""" + assert self._paused + self._paused = False + for sock in self.sockets: + self._loop._start_serving(self._protocol_factory, sock, self._ssl, + self, self._backlog) def close(self): + self._protocol_factory = None sockets = self.sockets if sockets is None: return @@ -943,7 +972,8 @@ def create_server(self, protocol_factory, host=None, port=None, backlog=100, ssl=None, reuse_address=None, - reuse_port=None): + reuse_port=None, + max_connections=None): """Create a TCP server. The host parameter can be a string, in that case the TCP server is bound @@ -1026,7 +1056,8 @@ def create_server(self, protocol_factory, host=None, port=None, raise ValueError('Neither host/port nor sock were specified') sockets = [sock] - server = Server(self, sockets) + server = Server(self, sockets, protocol_factory, ssl, backlog, + max_connections=max_connections) for sock in sockets: sock.listen(backlog) sock.setblocking(False) From 526b74d37483dda0230adba46b396f46c2a86ef9 Mon Sep 17 00:00:00 2001 From: Justin Mayfield Date: Wed, 19 Oct 2016 02:46:19 -0600 Subject: [PATCH 2/5] Add max_connections to create_unix_server. Fix Server create for unix sockets and add support for max_connections. --- asyncio/unix_events.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 65b61db6..22b8e6fb 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -247,7 +247,8 @@ def create_unix_connection(self, protocol_factory, path, *, @coroutine def create_unix_server(self, protocol_factory, path=None, *, - sock=None, backlog=100, ssl=None): + sock=None, backlog=100, ssl=None, + max_connections=None): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') @@ -294,7 +295,8 @@ def create_unix_server(self, protocol_factory, path=None, *, 'A UNIX Domain Stream Socket was expected, got {!r}' .format(sock)) - server = base_events.Server(self, [sock]) + server = base_events.Server(self, [sock], protocol_factory, ssl, + backlog, max_connections=max_connections) sock.listen(backlog) sock.setblocking(False) self._start_serving(protocol_factory, sock, ssl, server) From 56d0d0bb9c922a47645fcebff9e40df4cad36fe4 Mon Sep 17 00:00:00 2001 From: Justin Mayfield Date: Wed, 19 Oct 2016 15:38:38 -0600 Subject: [PATCH 3/5] Add tests for Server pause behavior. Tests for create_server(max_connections=number) and Server.pause/resume. --- tests/test_events.py | 70 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 7df926f1..8f70d8ae 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1312,6 +1312,76 @@ def connection_made(self, transport): server.close() + def test_create_server_max_connections(self): + protos = [] + on_data = asyncio.Event(loop=self.loop) + + class MaxConnTestProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + protos.append(self) + def data_received(self, data): + super().data_received(data) + on_data.set() + + f = self.loop.create_server(lambda: MaxConnTestProto(loop=self.loop), + '0.0.0.0', 0, max_connections=2) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + port = sock.getsockname()[1] + + # Low water.. + c1 = socket.socket() + c1.connect(('127.0.0.1', port)) + c1.sendall(b'x') + self.loop.run_until_complete(on_data.wait()) + on_data.clear() + self.assertFalse(server._paused) + self.loop._selector.get_key(sock.fileno()) # has reader + + # High water.. + c2 = socket.socket() + c2.connect(('127.0.0.1', port)) + c2.sendall(b'x') + self.loop.run_until_complete(on_data.wait()) + on_data.clear() + self.assertEqual(server._active_count, 2) + self.assertTrue(server._paused) + self.assertRaises(KeyError, self.loop._selector.get_key, sock.fileno()) + + # Low water again.. + p = protos.pop(0) + p.transport.close() + self.loop.run_until_complete(p.done) + self.assertFalse(server._paused) + self.loop._selector.get_key(sock.fileno()) # has reader + + # cleanup + p = protos.pop(0) + p.transport.close() + self.loop.run_until_complete(p.done) + c1.close() + c2.close() + server.close() + self.assertFalse(protos) + + def test_create_server_pause_resume(self): + f = self.loop.create_server(lambda: None, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock_fd = server.sockets[0].fileno() + + server.pause() + self.assertTrue(server._paused) + self.assertRaises(KeyError, self.loop._selector.get_key, sock_fd) + self.assertRaises(AssertionError, server.pause) + + server.resume() + self.assertFalse(server._paused) + self.loop._selector.get_key(sock_fd) # has reader + self.assertRaises(AssertionError, server.resume) + + server.close() + def test_server_close(self): f = self.loop.create_server(MyProto, '0.0.0.0', 0) server = self.loop.run_until_complete(f) From f157ec71358cc7c5cf3def672b976e9ae5938d6a Mon Sep 17 00:00:00 2001 From: Justin Mayfield Date: Wed, 19 Oct 2016 16:16:55 -0600 Subject: [PATCH 4/5] Add Server pause tests for unix servers. Refactored a bit to share most the logic with the TCP server. --- tests/test_events.py | 51 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 8f70d8ae..f2cab282 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1327,34 +1327,59 @@ def data_received(self, data): f = self.loop.create_server(lambda: MaxConnTestProto(loop=self.loop), '0.0.0.0', 0, max_connections=2) server = self.loop.run_until_complete(f) - sock = server.sockets[0] - port = sock.getsockname()[1] + port = server.sockets[0].getsockname()[1] + self._test_create_server_max_connections(server, socket.socket, + ('127.0.0.1', port), + protos, on_data) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_max_connections(self): + protos = [] + on_data = asyncio.Event(loop=self.loop) + + class MaxConnTestProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + protos.append(self) + def data_received(self, data): + super().data_received(data) + on_data.set() + + factory = lambda: MaxConnTestProto(loop=self.loop) + server, path = self._make_unix_server(factory, max_connections=2) + socket_factory = lambda: socket.socket(socket.AF_UNIX) + self._test_create_server_max_connections(server, socket_factory, path, + protos, on_data) + + def _test_create_server_max_connections(self, server, socket_factory, + connect_to, protos, on_data): + sock_fd = server.sockets[0].fileno() # Low water.. - c1 = socket.socket() - c1.connect(('127.0.0.1', port)) + c1 = socket_factory() + c1.connect(connect_to) c1.sendall(b'x') self.loop.run_until_complete(on_data.wait()) on_data.clear() self.assertFalse(server._paused) - self.loop._selector.get_key(sock.fileno()) # has reader + self.loop._selector.get_key(sock_fd) # has reader # High water.. - c2 = socket.socket() - c2.connect(('127.0.0.1', port)) + c2 = socket_factory() + c2.connect(connect_to) c2.sendall(b'x') self.loop.run_until_complete(on_data.wait()) on_data.clear() self.assertEqual(server._active_count, 2) self.assertTrue(server._paused) - self.assertRaises(KeyError, self.loop._selector.get_key, sock.fileno()) + self.assertRaises(KeyError, self.loop._selector.get_key, sock_fd) # Low water again.. p = protos.pop(0) p.transport.close() self.loop.run_until_complete(p.done) self.assertFalse(server._paused) - self.loop._selector.get_key(sock.fileno()) # has reader + self.loop._selector.get_key(sock_fd) # has reader # cleanup p = protos.pop(0) @@ -1369,7 +1394,15 @@ def test_create_server_pause_resume(self): f = self.loop.create_server(lambda: None, '0.0.0.0', 0) server = self.loop.run_until_complete(f) sock_fd = server.sockets[0].fileno() + self._test_create_server_pause_resume(server, sock_fd) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_pause_resume(self): + server, path = self._make_unix_server(lambda: None) + sock_fd = server.sockets[0].fileno() + self._test_create_server_pause_resume(server, sock_fd) + def _test_create_server_pause_resume(self, server, sock_fd): server.pause() self.assertTrue(server._paused) self.assertRaises(KeyError, self.loop._selector.get_key, sock_fd) From 1a38d87b6c8a4172b18001d293c5b1b17a6e87ae Mon Sep 17 00:00:00 2001 From: Justin Mayfield Date: Wed, 19 Oct 2016 16:32:51 -0600 Subject: [PATCH 5/5] Skip Server pause tests with Proactor event loop --- tests/test_events.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index f2cab282..4b667f15 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -2265,6 +2265,12 @@ def test_create_datagram_endpoint(self): def test_remove_fds_after_closing(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_create_server_max_connections(self): + raise unittest.SkipTest("IocpEventLoop incompatible with max_connections") + + def test_create_server_pause_resume(self): + raise unittest.SkipTest("IocpEventLoop incompatible with Server pause") else: from asyncio import selectors