diff --git a/.gitignore b/.gitignore
index 146736ff..116bc786 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,4 @@ build/
 coverage.xml
 docs/_themes
 docs/_build
+.venv/
diff --git a/src/waitress/channel.py b/src/waitress/channel.py
index eb59dd3f..976bff8d 100644
--- a/src/waitress/channel.py
+++ b/src/waitress/channel.py
@@ -45,6 +45,7 @@ class HTTPChannel(wasyncore.dispatcher):
     last_activity = 0  # Time of last activity
     will_close = False  # set to True to close the socket.
     close_when_flushed = False  # set to True to close the socket when flushed
+    closed = False  # set to True once handle_close() has run (connected may be False from the start)
     sent_continue = False  # used as a latch after sending 100 continue
     total_outbufs_len = 0  # total bytes ready to send
     current_outbuf_count = 0  # total bytes written to current outbuf
@@ -67,6 +68,9 @@ def __init__(self, server, sock, addr, adj, map=None):
         self.outbuf_lock = threading.Condition()
 
         wasyncore.dispatcher.__init__(self, sock, map=map)
+        if not self.connected:
+            # The peer can disconnect so quickly that getpeername fails.
+            self.handle_close()
 
         # Don't let wasyncore.dispatcher throttle self.addr on us.
         self.addr = addr
@@ -86,15 +90,15 @@ def writable(self):
         # the channel (possibly by our server maintenance logic), run
         # handle_write
 
         return self.total_outbufs_len or self.will_close or self.close_when_flushed
 
     def handle_write(self):
         # Precondition: there's data in the out buffer to be sent, or
         # there's a pending will_close request
 
-        if not self.connected:
-            # we dont want to close the channel twice
-
+        if self.closed:
+            # We don't want to close the channel twice,
+            # but we need to let the channel close if it's marked to close.
             return
 
         # try to flush any pending output
@@ -150,7 +154,6 @@ def readable(self):
         # 3. There are not too many tasks already queued
         # 4. There's no data in the output buffer that needs to be sent
         #    before we potentially create a new task.
-
         return not (
             self.will_close
             or self.close_when_flushed
@@ -314,6 +317,7 @@ def handle_close(self):
             self.total_outbufs_len = 0
             self.connected = False
             self.outbuf_lock.notify()
+            self.closed = True
         wasyncore.dispatcher.close(self)
 
     def add_channel(self, map=None):
diff --git a/tests/test_channel.py b/tests/test_channel.py
index 8467ae7a..86dacfbc 100644
--- a/tests/test_channel.py
+++ b/tests/test_channel.py
@@ -1,4 +1,6 @@
+from errno import EINVAL
 import io
+import socket
 import unittest
 
 import pytest
@@ -11,10 +13,10 @@ def _makeOne(self, sock, addr, adj, map=None):
         server = DummyServer()
         return HTTPChannel(server, sock, addr, adj=adj, map=map)
 
-    def _makeOneWithMap(self, adj=None):
+    def _makeOneWithMap(self, adj=None, sock_shutdown=False):
         if adj is None:
             adj = DummyAdjustments()
-        sock = DummySock()
+        sock = DummySock(shutdown=sock_shutdown)
         map = {}
         inst = self._makeOne(sock, "127.0.0.1", adj, map=map)
         inst.outbuf_lock = DummyLock()
@@ -65,8 +67,18 @@ def test_writable_nothing_in_outbuf_will_close(self):
     def test_handle_write_not_connected(self):
         inst, sock, map = self._makeOneWithMap()
         inst.connected = False
+        # TODO: handle_write never returns anything anyway
         self.assertFalse(inst.handle_write())
 
+    def test_handle_write_not_connected_but_will_close(self):
+        inst, sock, map = self._makeOneWithMap()
+        inst.connected = False
+        inst.will_close = True
+        # https://github.com/Pylons/waitress/issues/418
+        # Ensure we actually handle_close even if not connected
+        self.assertFalse(inst.handle_write())
+        self.assertEqual(len(map), 0)
+
     def test_handle_write_with_requests(self):
         inst, sock, map = self._makeOneWithMap()
         inst.requests = True
@@ -906,8 +918,9 @@ class DummySock:
     blocking = False
     closed = False
 
-    def __init__(self):
+    def __init__(self, shutdown=False):
         self.sent = b""
+        self.shutdown = shutdown
 
     def setblocking(self, *arg):
         self.blocking = True
@@ -916,6 +929,8 @@ def fileno(self):
         return 100
 
     def getpeername(self):
+        if self.shutdown:
+            raise OSError(EINVAL)
         return "127.0.0.1"
 
     def getsockopt(self, level, option):
diff --git a/tests/test_server.py b/tests/test_server.py
index 6edc3b24..fac986f0 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1,6 +1,12 @@
 import errno
+import select
 import socket
+import struct
+from threading import Event
+from time import sleep
+import time
 import unittest
+from waitress.channel import HTTPChannel
 
 dummy_app = object()
 
@@ -311,6 +317,98 @@ def test_create_with_one_socket_handle_accept_noerror(self):
         self.assertEqual(innersock.opts, [("level", "optname", "value")])
         self.assertEqual(L, [(inst, innersock, None, inst.adj)])
 
+    def test_error_request_quick_shutdown(self):
+        """Issue found in production that led to 100% CPU usage because getpeername failed after accept but before channel setup.
+        """
+        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+        sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)
+        sockets[0].bind(("127.0.0.1", 0))  # port 0: let the OS pick a free port
+        sockets[0].listen()
+        inst = self._makeWithSockets(_start=False, sockets=sockets)
+        channels = []
+        inst.channel_class = make_quick_shutdown_channel(client, channels)
+        inst.task_dispatcher = DummyTaskDispatcher()
+
+        # This will make getpeername fail fast with EINVAL OSError
+        client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
+        client.connect(sockets[0].getsockname())
+        client.send(b"1")  # Send our fake request before we accept and close the connection
+        inst.handle_accept()  # ShutdownChannel will close the connection after accept but before getpeername
+        self.assertRaises(OSError, sockets[0].getpeername)
+        self.assertFalse(channels[0].connected, "race condition means our socket is marked not connected")
+        self.assertNotIn(channels[0], inst._map.values(), "we should get an automatic close")
+
+    def test_error_request_no_loop(self):
+        """Issue found in production that led to 100% CPU usage because getpeername failed after accept but before channel setup.
+        """
+        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+        sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+        sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)
+        sockets[0].bind(("127.0.0.1", 0))  # port 0: let the OS pick a free port
+        sockets[0].listen()
+        inst = self._makeWithSockets(_start=False, sockets=sockets)
+        channels = []
+        inst.channel_class = make_quick_shutdown_channel(client, channels, shutdown=False)
+        inst.task_dispatcher = DummyTaskDispatcher()
+
+        # This will make getpeername fail fast with EINVAL OSError
+        client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
+        client.connect(sockets[0].getsockname())
+        client.send(b"1")  # Send our fake request before we accept the connection
+        inst.handle_accept()  # shutdown=False: ShutdownChannel is created without closing the client
+        self.assertRaises(OSError, sockets[0].getpeername)
+        self.assertEqual(len(channels), 1)
+        channels[0].connected = False  # This used to create a 100% CPU loop
+
+        server_run_for_count(inst, 1)  # Read the request
+        self.assertTrue(channels[0].requests[0].error, "for this bug we need the request to have a parsing error")
+        server_run_for_count(inst, 5)
+        # simulate thread processing the request
+        channels[0].service()
+        self.assertTrue(channels[0].close_when_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)")
+        server_run_for_count(inst, 5)  # Our loop
+        self.assertNotIn(channels[0], inst._map.values(), "broken request didn't close the channel")
+        self.assertEqual(channels[0].count_close, 1, "but also this connection never gets closed")
+        self.assertLess(channels[0].count_writes, 5, "We're supposed to be in a loop trying to write but can't")
+
+    def test_error_request_maintainace_cleanup(self):
+        """Issue found in production that led to 100% CPU usage because getpeername failed after accept but before channel setup.
+        """
+        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+        sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+        sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)
+        sockets[0].bind(("127.0.0.1", 0))  # port 0: let the OS pick a free port
+        sockets[0].listen()
+        inst = self._makeWithSockets(_start=False, sockets=sockets)
+        channels = []
+        inst.channel_class = make_quick_shutdown_channel(client, channels, shutdown=False)
+        inst.task_dispatcher = DummyTaskDispatcher()
+
+        # This will make getpeername fail fast with EINVAL OSError
+        client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
+        client.connect(sockets[0].getsockname())
+        client.send(b"1")  # Send our fake request before we accept the connection
+        inst.handle_accept()  # shutdown=False: ShutdownChannel is created without closing the client
+        self.assertRaises(OSError, sockets[0].getpeername)
+        self.assertNotEqual(channels, [])
+        channels[0].connected = False  # race condition means our socket is marked not connected
+
+        server_run_for_count(inst, 1)  # Read the request
+        # self.assertTrue(channels[0].requests[0].error, "for this bug we need the request to have a parsing error")
+        server_run_for_count(inst, 5)
+        channels[0].service()
+        self.assertTrue(channels[0].close_when_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)")
+        server_run_for_count(inst, 5)  # Our loop
+        channels[0].last_activity = 0
+        inst.maintenance(1000)
+        self.assertEqual(channels[0].will_close, 1, "maintenance will try to close it")
+        self.assertNotIn(channels[0], inst._map.values(), "broken request didn't close the channel")
+        server_run_for_count(inst, 5)  # Our loop
+        self.assertNotEqual(channels[0].count_writes, 10, "But we still get our loop")
+
 
 if hasattr(socket, "AF_UNIX"):
 
@@ -516,3 +614,61 @@ def __init__(self):
 
     def warning(self, msg, **kw):
         self.logged.append(msg)
+
+
+class ErrorRequest:
+    error = True  # We are simulating a header parsing error
+    version = 1
+    data = None
+    completed = True
+    empty = False
+    headers_finished = True
+    expect_continue = False
+    retval = None
+    connection_close = False
+
+    def __init__(self, adj):
+        pass
+
+    def received(self, data):
+        self.data = data
+        if self.retval is not None:
+            return self.retval
+        return len(data)
+
+    def close(self):
+        pass
+
+
+def make_quick_shutdown_channel(client, channels, shutdown=True):
+    class ShutdownChannel(HTTPChannel):
+        parser_class = ErrorRequest
+
+        def __init__(self, server, sock, addr, adj, map=None):
+            self.count_writes = self.count_close = self.count_wouldblock = 0
+            if shutdown:
+                client.close()
+            channels.append(self)
+            return HTTPChannel.__init__(self, server, sock, addr, adj, map)
+
+        def handle_write(self):
+            self.count_writes += 1
+            return HTTPChannel.handle_write(self)
+
+        def handle_close(self):
+            # Count closes so tests can assert how often the channel was closed.
+            self.count_close += 1
+            return HTTPChannel.handle_close(self)
+
+    return ShutdownChannel
+
+
+def server_run_for_count(inst, count=1):
+    # Modified server run to prevent infinite loop
+    inst.asyncore.loop(
+        timeout=inst.adj.asyncore_loop_timeout,
+        map=inst._map,
+        use_poll=inst.adj.asyncore_use_poll,
+        count=count
+    )
+