From 2c8263207b221830d07be85fa6a467680fcbc17c Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Tue, 12 Sep 2023 00:04:53 +0700 Subject: [PATCH 01/20] ensure we don't keep trying to write to a channel thats not connected --- src/waitress/channel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/waitress/channel.py b/src/waitress/channel.py index eb59dd3f..2d5db8b3 100644 --- a/src/waitress/channel.py +++ b/src/waitress/channel.py @@ -94,6 +94,8 @@ def handle_write(self): if not self.connected: # we dont want to close the channel twice + # But we shouldn't be written to if we really are closed so unregister from loop + self.del_channel() return From 26c38551de5d361c86beb2446efe0f89a0fd43b7 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Tue, 12 Sep 2023 14:02:39 +0700 Subject: [PATCH 02/20] add test --- tests/test_channel.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/test_channel.py b/tests/test_channel.py index 8467ae7a..f2c568f8 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -1,4 +1,6 @@ +from errno import EINVAL import io +import socket import unittest import pytest @@ -11,10 +13,10 @@ def _makeOne(self, sock, addr, adj, map=None): server = DummyServer() return HTTPChannel(server, sock, addr, adj=adj, map=map) - def _makeOneWithMap(self, adj=None): + def _makeOneWithMap(self, adj=None, sock_shutdown=False): if adj is None: adj = DummyAdjustments() - sock = DummySock() + sock = DummySock(shutdown=sock_shutdown) map = {} inst = self._makeOne(sock, "127.0.0.1", adj, map=map) inst.outbuf_lock = DummyLock() @@ -770,6 +772,15 @@ def test_cancel_with_requests(self): inst.cancel() self.assertEqual(inst.requests, []) + def test_shutdown_quick_loop(self): + inst, sock, map = self._makeOneWithMap(sock_shutdown=True) + # if sock.shutdown(socket.SHUT_RD) creating the dispatcher we will get a connected == False + self.assertRaises(OSError, sock.getpeername) + self.assertFalse(inst.connected) + self.assertTrue(inst._map) # still processing + inst.handle_write() # but still half connected so select will say it can write + self.assertFalse(inst._map, "channel should be removed so we don't loop and select socket again") + # self.assertTrue(sock.closed, "Should be close the channel instead?") class TestHTTPChannelLookahead(TestHTTPChannel): def app_check_disconnect(self, environ, start_response): @@ -906,8 +917,9 @@ class DummySock: blocking = False closed = False - def __init__(self): + def __init__(self, shutdown=False): self.sent = b"" + self.shutdown = shutdown def setblocking(self, *arg): self.blocking = True @@ -916,6 +928,8 @@ def fileno(self): return 100 def getpeername(self): + if self.shutdown: + raise OSError(EINVAL) return "127.0.0.1" def getsockopt(self, level, option): From 1b9e35c6a7d84f06e959264f50b6dcf3a1fdece5 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Tue, 12 Sep 2023 18:07:43 +0700 Subject: [PATCH 03/20] use writable to avoid loop instead --- src/waitress/channel.py | 6 +++--- tests/test_channel.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/waitress/channel.py b/src/waitress/channel.py index 2d5db8b3..f28d83fb 100644 --- a/src/waitress/channel.py +++ b/src/waitress/channel.py @@ -86,7 +86,7 @@ def writable(self): # the channel (possibly by our server maintenance logic), run # handle_write - return self.total_outbufs_len or self.will_close or self.close_when_flushed + return self.connected and (self.total_outbufs_len or self.will_close or self.close_when_flushed) def handle_write(self): # Precondition: there's data in the out buffer to be sent, or @@ -95,8 +95,8 @@ def handle_write(self): if not self.connected: # we dont want to close the channel twice # But we shouldn't be written to if we really are closed so unregister from loop - self.del_channel() - + # self.del_channel() + #self.close_when_flushed = True return # try to flush any pending output diff --git a/tests/test_channel.py b/tests/test_channel.py index f2c568f8..95ae6433 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -778,8 +778,9 @@ def test_shutdown_quick_loop(self): self.assertRaises(OSError, sock.getpeername) self.assertFalse(inst.connected) self.assertTrue(inst._map) # still processing - inst.handle_write() # but still half connected so select will say it can write - self.assertFalse(inst._map, "channel should be removed so we don't loop and select socket again") + # inst.handle_write() # but still half connected so select will say it can write + # self.assertFalse(inst._map, "channel should be removed so we don't loop and select socket again") + self.assertTrue(all(not c.writable() for c in inst._map.values()), "if our channel is writable we can get into a loop") # self.assertTrue(sock.closed, "Should be close the channel instead?") class TestHTTPChannelLookahead(TestHTTPChannel): From 04056abdfa4f2722a6c6d4b5acc4b03592689987 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Thu, 14 Sep 2023 10:58:21 +0700 Subject: [PATCH 04/20] add test for loop due to getpearname failing --- tests/test_server.py | 45 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_server.py b/tests/test_server.py index 6edc3b24..28654748 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,5 @@ import errno +import select import socket import unittest @@ -311,6 +312,50 @@ def test_create_with_one_socket_handle_accept_noerror(self): self.assertEqual(innersock.opts, [("level", "optname", "value")]) self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + def test_quick_shutdown(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].bind(("127.0.0.1", 8000)) + sockets[0].listen() + + inst = self._makeWithSockets(_start=False, sockets=sockets) + from waitress.channel import HTTPChannel + inst.channel_class = HTTPChannel + + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client.connect(("127.0.0.1", 8000)) + client.shutdown(socket.SHUT_RD) + inst.handle_accept() + + channel = list(iter(inst._map.values()))[-1] + self.assertEqual(channel.__class__, HTTPChannel) + self.assertEqual(channel.socket.getpeername(), "") + self.assertRaises(OSError, channel.socket.getpeername) + self.assertFalse(channel.connected, "race condition means our socket is marked not connected") + + inst.task_dispatcher = DummyTaskDispatcher() + selects = 0 + orig_select = select.select + + def counting_select(r, w, e, timeout): + nonlocal selects + rr, wr, er = orig_select(r, w, e, timeout) + if rr or wr or er: + selects += 1 + return rr, wr, er + + select.select = counting_select + + # Modified server run + inst.asyncore.loop( + timeout=inst.adj.asyncore_loop_timeout, + map=inst._map, + use_poll=inst.adj.asyncore_use_poll, + count=2 + ) + select.select = orig_select + sockets[0].close() + self.assertEqual(selects, 0, "ensure we aren't in a loop trying to write but can't") + if hasattr(socket, "AF_UNIX"): From 5130bec0501567c809397cc3d8003c0af11fb46f Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Thu, 14 Sep 2023 11:24:01 +0700 Subject: [PATCH 05/20] make test shutdown right after accept --- tests/test_server.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index 28654748..12d02def 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,7 @@ import errno import select import socket +from time import sleep import unittest dummy_app = object() @@ -316,19 +317,24 @@ def test_quick_shutdown(self): sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] sockets[0].bind(("127.0.0.1", 8000)) sockets[0].listen() + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) inst = self._makeWithSockets(_start=False, sockets=sockets) from waitress.channel import HTTPChannel - inst.channel_class = HTTPChannel + class ShutdownChannel(HTTPChannel): + def __init__(self, server, sock, addr, adj, map=None): + client.shutdown(socket.SHUT_RD) + client.close() + sleep(3) + return HTTPChannel.__init__(self, server, sock, addr, adj, map) + inst.channel_class = ShutdownChannel - client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) client.connect(("127.0.0.1", 8000)) - client.shutdown(socket.SHUT_RD) inst.handle_accept() channel = list(iter(inst._map.values()))[-1] - self.assertEqual(channel.__class__, HTTPChannel) - self.assertEqual(channel.socket.getpeername(), "") + self.assertEqual(channel.__class__, ShutdownChannel) + # self.assertEqual(channel.socket.getpeername(), "") self.assertRaises(OSError, channel.socket.getpeername) self.assertFalse(channel.connected, "race condition means our socket is marked not connected") From ea3cb3a1d4b85bc9f9d3de459619824179f4f969 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Fri, 15 Sep 2023 09:36:24 +0700 Subject: [PATCH 06/20] try keepalive to make connection close quickly --- tests/test_server.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index 12d02def..9d2254d3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -316,15 +316,22 @@ def test_create_with_one_socket_handle_accept_noerror(self): def test_quick_shutdown(self): sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] sockets[0].bind(("127.0.0.1", 8000)) + sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 1) + sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) + sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) sockets[0].listen() client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + client.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) + client.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) inst = self._makeWithSockets(_start=False, sockets=sockets) from waitress.channel import HTTPChannel class ShutdownChannel(HTTPChannel): def __init__(self, server, sock, addr, adj, map=None): - client.shutdown(socket.SHUT_RD) - client.close() + client.shutdown(socket.SHUT_WR) + # client.close() sleep(3) return HTTPChannel.__init__(self, server, sock, addr, adj, map) inst.channel_class = ShutdownChannel @@ -335,7 +342,7 @@ def __init__(self, server, sock, addr, adj, map=None): channel = list(iter(inst._map.values()))[-1] self.assertEqual(channel.__class__, ShutdownChannel) # self.assertEqual(channel.socket.getpeername(), "") - self.assertRaises(OSError, channel.socket.getpeername) + self.assertRaises(Exception, channel.socket.getpeername) self.assertFalse(channel.connected, "race condition means our socket is marked not connected") inst.task_dispatcher = DummyTaskDispatcher() From a71f5c7d7e7c1d9dc2792b8c77c3ffb4568d0bd4 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Fri, 15 Sep 2023 10:17:09 +0700 Subject: [PATCH 07/20] simplify test --- tests/test_server.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index 9d2254d3..dfa162c1 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -330,11 +330,17 @@ def test_quick_shutdown(self): from waitress.channel import HTTPChannel class ShutdownChannel(HTTPChannel): def __init__(self, server, sock, addr, adj, map=None): + self.count_writes = 0 client.shutdown(socket.SHUT_WR) - # client.close() + client.close() sleep(3) return HTTPChannel.__init__(self, server, sock, addr, adj, map) + def handle_write(self): + self.count_writes += 1 + return HTTPChannel.handle_write(self) + inst.channel_class = ShutdownChannel + inst.task_dispatcher = DummyTaskDispatcher() client.connect(("127.0.0.1", 8000)) inst.handle_accept() @@ -345,18 +351,7 @@ def __init__(self, server, sock, addr, adj, map=None): self.assertRaises(Exception, channel.socket.getpeername) self.assertFalse(channel.connected, "race condition means our socket is marked not connected") - inst.task_dispatcher = DummyTaskDispatcher() - selects = 0 - orig_select = select.select - - def counting_select(r, w, e, timeout): - nonlocal selects - rr, wr, er = orig_select(r, w, e, timeout) - if rr or wr or er: - selects += 1 - return rr, wr, er - - select.select = counting_select + # Modified server run inst.asyncore.loop( @@ -365,9 +360,8 @@ def counting_select(r, w, e, timeout): use_poll=inst.adj.asyncore_use_poll, count=2 ) - select.select = orig_select sockets[0].close() - self.assertEqual(selects, 0, "ensure we aren't in a loop trying to write but can't") + self.assertEqual(channel.count_writes, 0, "ensure we aren't in a loop trying to write but can't") if hasattr(socket, "AF_UNIX"): From deea18a8d424b37e581af98fea920b5ecdd5864a Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Sun, 17 Sep 2023 13:03:12 +0700 Subject: [PATCH 08/20] reproduce but with 0 sleep. --- tests/test_server.py | 53 +++++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index dfa162c1..7068f099 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,7 +1,10 @@ import errno import select import socket +import struct +from threading import Event from time import sleep +import time import unittest dummy_app = object() @@ -314,54 +317,74 @@ def test_create_with_one_socket_handle_accept_noerror(self): self.assertEqual(L, [(inst, innersock, None, inst.adj)]) def test_quick_shutdown(self): + """ Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup. + """ sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + # # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 1) + # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) + # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) sockets[0].bind(("127.0.0.1", 8000)) - sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 1) - sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) - sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) sockets[0].listen() client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - client.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - client.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) - client.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) + # client.settimeout(.2) + client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0)) + # + # client.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + # client.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) + # client.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 1) inst = self._makeWithSockets(_start=False, sockets=sockets) from waitress.channel import HTTPChannel + class ShutdownChannel(HTTPChannel): def __init__(self, server, sock, addr, adj, map=None): - self.count_writes = 0 - client.shutdown(socket.SHUT_WR) - client.close() - sleep(3) + self.count_writes = self.count_close = 0 + # client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work + client.close() # has to be here to reproduce + # sleep(1) # has to be at least 65s to reproduce + # start = time.time() + # with open("/dev/tty", "w") as out: + # while True: + # try: sock.getpeername() + # except OSError: + # print("broken", int(time.time() - start), file=out) + # break + # else: print("not yet broken", int(time.time() - start), file=out); sleep(1) return HTTPChannel.__init__(self, server, sock, addr, adj, map) + def handle_write(self): self.count_writes += 1 return HTTPChannel.handle_write(self) + def handle_close(self): + self.count_close += 1 + return HTTPChannel.handle_close(self) + inst.channel_class = ShutdownChannel inst.task_dispatcher = DummyTaskDispatcher() client.connect(("127.0.0.1", 8000)) inst.handle_accept() + self.assertEqual(len(inst._map.values()), 3) channel = list(iter(inst._map.values()))[-1] self.assertEqual(channel.__class__, ShutdownChannel) # self.assertEqual(channel.socket.getpeername(), "") self.assertRaises(Exception, channel.socket.getpeername) self.assertFalse(channel.connected, "race condition means our socket is marked not connected") - - - # Modified server run + # Modified server run to prevent infinite loop inst.asyncore.loop( timeout=inst.adj.asyncore_loop_timeout, map=inst._map, use_poll=inst.adj.asyncore_use_poll, - count=2 + count=5 ) sockets[0].close() self.assertEqual(channel.count_writes, 0, "ensure we aren't in a loop trying to write but can't") + self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") if hasattr(socket, "AF_UNIX"): From f95d396353f1e266d51abced0a0abd6fdd1281f2 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 18 Sep 2023 09:31:30 +0700 Subject: [PATCH 09/20] reproduces loop but doesn't explain how data got written --- tests/test_server.py | 56 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index 7068f099..b59ecd42 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,6 +7,8 @@ import time import unittest +from tests.test_channel import DummyParser + dummy_app = object() @@ -320,7 +322,8 @@ def test_quick_shutdown(self): """ Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup. """ sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0) + # sockets[0].settimeout(.2) # sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) # # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 1) # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) @@ -337,27 +340,59 @@ def test_quick_shutdown(self): inst = self._makeWithSockets(_start=False, sockets=sockets) from waitress.channel import HTTPChannel + class DummyParser: + version = 1 + data = None + completed = True + empty = False + headers_finished = False + expect_continue = False + retval = None + error = False + connection_close = False + def __init__(self, adj): + pass + + def received(self, data): + self.data = data + if self.retval is not None: + return self.retval + return len(data) class ShutdownChannel(HTTPChannel): + # parser_class = DummyParser def __init__(self, server, sock, addr, adj, map=None): self.count_writes = self.count_close = 0 - # client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work + # sleep(5) + client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work + #client.recv(1) client.close() # has to be here to reproduce # sleep(1) # has to be at least 65s to reproduce - # start = time.time() - # with open("/dev/tty", "w") as out: - # while True: - # try: sock.getpeername() - # except OSError: - # print("broken", int(time.time() - start), file=out) - # break - # else: print("not yet broken", int(time.time() - start), file=out); sleep(1) + start = time.time() + with open("/dev/tty", "w") as out: + while True: + try: sock.getpeername() + except OSError: + print("broken", int(time.time() - start), file=out) + break + else: print("not yet broken", int(time.time() - start), file=out); sleep(1) return HTTPChannel.__init__(self, server, sock, addr, adj, map) def handle_write(self): self.count_writes += 1 return HTTPChannel.handle_write(self) + def received(self, data): + res = HTTPChannel.received(self, data) + if data: + # Fake app returning data fast + self.total_outbufs_len = 1 + # import pdb; pdb.set_trace() + # self.request.completed = True + # self.requests.append(DummyParser()) + pass + return res + def handle_close(self): self.count_close += 1 return HTTPChannel.handle_close(self) @@ -366,6 +401,7 @@ def handle_close(self): inst.task_dispatcher = DummyTaskDispatcher() client.connect(("127.0.0.1", 8000)) + client.send(b"1") inst.handle_accept() self.assertEqual(len(inst._map.values()), 3) From fb5f7e998999c19d2abc8e737813a75a5ed796f0 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 18 Sep 2023 11:19:00 +0700 Subject: [PATCH 10/20] reproduce loop with send continue but fake EWOULDBLOCK --- tests/test_server.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index b59ecd42..a01deffc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -345,8 +345,8 @@ class DummyParser: data = None completed = True empty = False - headers_finished = False - expect_continue = False + headers_finished = True + expect_continue = True retval = None error = False connection_close = False @@ -360,7 +360,7 @@ def received(self, data): return len(data) class ShutdownChannel(HTTPChannel): - # parser_class = DummyParser + parser_class = DummyParser def __init__(self, server, sock, addr, adj, map=None): self.count_writes = self.count_close = 0 # sleep(5) @@ -383,15 +383,21 @@ def handle_write(self): return HTTPChannel.handle_write(self) def received(self, data): + #import pdb; pdb.set_trace() res = HTTPChannel.received(self, data) if data: # Fake app returning data fast - self.total_outbufs_len = 1 - # import pdb; pdb.set_trace() + # self.total_outbufs_len = 1 + # Happens if send can't send all the data + #import pdb; pdb.set_trace() + #self.write_soon(b"1"*11025) + #assert self.total_outbufs_len # self.request.completed = True # self.requests.append(DummyParser()) pass return res + def send(self, data, do_close=True): + return 0 def handle_close(self): self.count_close += 1 @@ -402,6 +408,7 @@ def handle_close(self): client.connect(("127.0.0.1", 8000)) client.send(b"1") + client.send(b"1") inst.handle_accept() self.assertEqual(len(inst._map.values()), 3) From 23e7b05748d1909afd73efb83290274c90b3051d Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 18 Sep 2023 17:30:22 +0700 Subject: [PATCH 11/20] reproduce loop with request error quick close. --- tests/test_server.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index a01deffc..d8bcaed3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -346,9 +346,9 @@ class DummyParser: completed = True empty = False headers_finished = True - expect_continue = True + expect_continue = False retval = None - error = False + error = True connection_close = False def __init__(self, adj): pass @@ -357,12 +357,16 @@ def received(self, data): self.data = data if self.retval is not None: return self.retval - return len(data) + #self.expect_continue = not self.expect_continue + #self.completed = not self.completed + return 1 + def close(self): + pass class ShutdownChannel(HTTPChannel): parser_class = DummyParser def __init__(self, server, sock, addr, adj, map=None): - self.count_writes = self.count_close = 0 + self.count_writes = self.count_close = self.count_wouldblock = 0 # sleep(5) client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work #client.recv(1) @@ -383,7 +387,7 @@ def handle_write(self): return HTTPChannel.handle_write(self) def received(self, data): - #import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() res = HTTPChannel.received(self, data) if data: # Fake app returning data fast @@ -397,9 +401,16 @@ def received(self, data): pass return res def send(self, data, do_close=True): - return 0 + # fake EWOULDBLOCK where socket buffers are filled up. but how? + # return 0 + res = HTTPChannel.send(self, data, do_close) + if res < len(data) and not self.count_close: + self.count_wouldblock += 1 + # import pdb; pdb.set_trace() + return res def handle_close(self): + # import pdb; pdb.set_trace() self.count_close += 1 return HTTPChannel.handle_close(self) @@ -407,8 +418,8 @@ def handle_close(self): inst.task_dispatcher = DummyTaskDispatcher() client.connect(("127.0.0.1", 8000)) - client.send(b"1") - client.send(b"1") + for i in range(0, 1): + client.send(b"1") inst.handle_accept() self.assertEqual(len(inst._map.values()), 3) @@ -425,7 +436,15 @@ def handle_close(self): use_poll=inst.adj.asyncore_use_poll, count=5 ) - sockets[0].close() + channel.service() + inst.asyncore.loop( + timeout=inst.adj.asyncore_loop_timeout, + map=inst._map, + use_poll=inst.adj.asyncore_use_poll, + count=5 + ) + #sockets[0].close() + # self.assertEqual(channel.count_wouldblock, 1, "we need data left to send to be in a loop") self.assertEqual(channel.count_writes, 0, "ensure we aren't in a loop trying to write but can't") self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") From 9c564721fa2031935810003c56da4c0959783edc Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 18 Sep 2023 17:55:49 +0700 Subject: [PATCH 12/20] comment out parts of test not used --- tests/test_server.py | 92 ++++++++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 47 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index d8bcaed3..d110da94 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -368,52 +368,61 @@ class ShutdownChannel(HTTPChannel): def __init__(self, server, sock, addr, adj, map=None): self.count_writes = self.count_close = self.count_wouldblock = 0 # sleep(5) - client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work + #client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work #client.recv(1) - client.close() # has to be here to reproduce + client.close() # simulate race condition where close happens between accept adn getpeername # sleep(1) # has to be at least 65s to reproduce - start = time.time() - with open("/dev/tty", "w") as out: - while True: - try: sock.getpeername() - except OSError: - print("broken", int(time.time() - start), file=out) - break - else: print("not yet broken", int(time.time() - start), file=out); sleep(1) + # start = time.time() + # with open("/dev/tty", "w") as out: + # while True: + # try: sock.getpeername() + # except OSError: + # print("broken", int(time.time() - start), file=out) + # break + # else: print("not yet broken", int(time.time() - start), file=out); sleep(1) return HTTPChannel.__init__(self, server, sock, addr, adj, map) def handle_write(self): self.count_writes += 1 return HTTPChannel.handle_write(self) - def received(self, data): - # import pdb; pdb.set_trace() - res = HTTPChannel.received(self, data) - if data: - # Fake app returning data fast - # self.total_outbufs_len = 1 - # Happens if send can't send all the data - #import pdb; pdb.set_trace() - #self.write_soon(b"1"*11025) - #assert self.total_outbufs_len - # self.request.completed = True - # self.requests.append(DummyParser()) - pass - return res - def send(self, data, do_close=True): - # fake EWOULDBLOCK where socket buffers are filled up. but how? - # return 0 - res = HTTPChannel.send(self, data, do_close) - if res < len(data) and not self.count_close: - self.count_wouldblock += 1 - # import pdb; pdb.set_trace() - return res + # def received(self, data): + # # import pdb; pdb.set_trace() + # res = HTTPChannel.received(self, data) + # if data: + # # Fake app returning data fast + # # self.total_outbufs_len = 1 + # # Happens if send can't send all the data + # #import pdb; pdb.set_trace() + # #self.write_soon(b"1"*11025) + # #assert self.total_outbufs_len + # # self.request.completed = True + # # self.requests.append(DummyParser()) + # pass + # return res + # def send(self, data, do_close=True): + # # fake EWOULDBLOCK where socket buffers are filled up. but how? + # # return 0 + # res = HTTPChannel.send(self, data, do_close) + # if res < len(data) and not self.count_close: + # self.count_wouldblock += 1 + # # import pdb; pdb.set_trace() + # return res def handle_close(self): # import pdb; pdb.set_trace() self.count_close += 1 return HTTPChannel.handle_close(self) + def server_run(count=1): + # Modified server run to prevent infinite loop + inst.asyncore.loop( + timeout=inst.adj.asyncore_loop_timeout, + map=inst._map, + use_poll=inst.adj.asyncore_use_poll, + count=count + ) + inst.channel_class = ShutdownChannel inst.task_dispatcher = DummyTaskDispatcher() @@ -429,21 +438,10 @@ def handle_close(self): self.assertRaises(Exception, channel.socket.getpeername) self.assertFalse(channel.connected, "race condition means our socket is marked not connected") - # Modified server run to prevent infinite loop - inst.asyncore.loop( - timeout=inst.adj.asyncore_loop_timeout, - map=inst._map, - use_poll=inst.adj.asyncore_use_poll, - count=5 - ) - channel.service() - inst.asyncore.loop( - timeout=inst.adj.asyncore_loop_timeout, - map=inst._map, - use_poll=inst.adj.asyncore_use_poll, - count=5 - ) - #sockets[0].close() + server_run(1) + channel.service() # Our error request sets close_after_flushed + server_run(5) + # self.assertEqual(channel.count_wouldblock, 1, "we need data left to send to be in a loop") self.assertEqual(channel.count_writes, 0, "ensure we aren't in a loop trying to write but can't") self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") From 97620b34cfe384e6cd621fce6e71b830bc93cac0 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 18 Sep 2023 22:47:19 +0700 Subject: [PATCH 13/20] fix looping --- src/waitress/channel.py | 11 +++++------ tests/test_server.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/waitress/channel.py b/src/waitress/channel.py index f28d83fb..1cdf342f 100644 --- a/src/waitress/channel.py +++ b/src/waitress/channel.py @@ -67,6 +67,9 @@ def __init__(self, server, sock, addr, adj, map=None): self.outbuf_lock = threading.Condition() wasyncore.dispatcher.__init__(self, sock, map=map) + if not self.connected: + # Sometimes can be closed quickly and getpeername fails. + self.handle_close() # Don't let wasyncore.dispatcher throttle self.addr on us. self.addr = addr @@ -86,17 +89,14 @@ def writable(self): # the channel (possibly by our server maintenance logic), run # handle_write - return self.connected and (self.total_outbufs_len or self.will_close or self.close_when_flushed) + return (self.total_outbufs_len or self.will_close or self.close_when_flushed) def handle_write(self): # Precondition: there's data in the out buffer to be sent, or # there's a pending will_close request - if not self.connected: + if not self.connected and not self.close_when_flushed: # we dont want to close the channel twice - # But we shouldn't be written to if we really are closed so unregister from loop - # self.del_channel() - #self.close_when_flushed = True return # try to flush any pending output @@ -152,7 +152,6 @@ def readable(self): # 3. There are not too many tasks already queued # 4. There's no data in the output buffer that needs to be sent # before we potentially create a new task. - return not ( self.will_close or self.close_when_flushed diff --git a/tests/test_server.py b/tests/test_server.py index d110da94..2766c194 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -368,7 +368,7 @@ class ShutdownChannel(HTTPChannel): def __init__(self, server, sock, addr, adj, map=None): self.count_writes = self.count_close = self.count_wouldblock = 0 # sleep(5) - #client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work + # client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work #client.recv(1) client.close() # simulate race condition where close happens between accept adn getpeername # sleep(1) # has to be at least 65s to reproduce @@ -438,7 +438,7 @@ def server_run(count=1): self.assertRaises(Exception, channel.socket.getpeername) self.assertFalse(channel.connected, "race condition means our socket is marked not connected") - server_run(1) + server_run(5) channel.service() # Our error request sets close_after_flushed server_run(5) From a3fda8cbaae9bb20f48e5062b1792c3439eff43c Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 18 Sep 2023 23:16:07 +0700 Subject: [PATCH 14/20] remove test that didn't work --- tests/test_channel.py | 10 ---------- tests/test_server.py | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/test_channel.py b/tests/test_channel.py index 95ae6433..a501123b 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -772,16 +772,6 @@ def test_cancel_with_requests(self): inst.cancel() self.assertEqual(inst.requests, []) - def test_shutdown_quick_loop(self): - inst, sock, map = self._makeOneWithMap(sock_shutdown=True) - # if sock.shutdown(socket.SHUT_RD) creating the dispatcher we will get a connected == False - self.assertRaises(OSError, sock.getpeername) - self.assertFalse(inst.connected) - self.assertTrue(inst._map) # still processing - # inst.handle_write() # but still half connected so select will say it can write - # self.assertFalse(inst._map, "channel should be removed so we don't loop and select socket again") - self.assertTrue(all(not c.writable() for c in inst._map.values()), "if our channel is writable we can get into a loop") - # self.assertTrue(sock.closed, "Should be close the channel instead?") class TestHTTPChannelLookahead(TestHTTPChannel): def app_check_disconnect(self, environ, start_response): diff --git a/tests/test_server.py b/tests/test_server.py index 2766c194..710af63b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -443,7 +443,7 @@ def server_run(count=1): server_run(5) # self.assertEqual(channel.count_wouldblock, 1, "we need data left to send to be in a loop") - self.assertEqual(channel.count_writes, 0, "ensure we aren't in a loop trying to write but can't") + self.assertEqual(channel.count_writes > 1, "ensure we aren't in a loop trying to write but can't") self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") From 73a25c9da90fbf00c47a771ae0a5f93d1b0b3177 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Tue, 19 Sep 2023 10:52:40 +0700 Subject: [PATCH 15/20] clean up test and make pass --- tests/test_server.py | 109 ++++++++++++++----------------------------- 1 file changed, 34 insertions(+), 75 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index 710af63b..765de1b1 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -321,26 +321,10 @@ def test_create_with_one_socket_handle_accept_noerror(self): def test_quick_shutdown(self): """ Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup. """ - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0) - # sockets[0].settimeout(.2) - # sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - # # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 1) - # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) - # sockets[0].setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) - sockets[0].bind(("127.0.0.1", 8000)) - sockets[0].listen() - client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # client.settimeout(.2) - client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0)) - # - # client.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - # client.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 1) - # client.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 1) - inst = self._makeWithSockets(_start=False, sockets=sockets) - from waitress.channel import HTTPChannel class DummyParser: + error = True # We are simulating a header parsing error + version = 1 data = None completed = True @@ -348,8 +332,8 @@ class DummyParser: headers_finished = True expect_continue = False retval = None - error = True connection_close = False + def __init__(self, adj): pass @@ -357,58 +341,26 @@ def received(self, data): self.data = data if self.retval is not None: return self.retval - #self.expect_continue = not self.expect_continue - #self.completed = not self.completed - return 1 + return len(data) + def close(self): pass + from waitress.channel import HTTPChannel + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + class ShutdownChannel(HTTPChannel): parser_class = DummyParser + def __init__(self, server, sock, addr, adj, map=None): self.count_writes = self.count_close = self.count_wouldblock = 0 - # sleep(5) - # client.shutdown(socket.SHUT_RDWR) # has to be here to reproduce. just RD or WR won't work - #client.recv(1) client.close() # simulate race condition where close happens between accept adn getpeername - # sleep(1) # has to be at least 65s to reproduce - # start = time.time() - # with open("/dev/tty", "w") as out: - # while True: - # try: sock.getpeername() - # except OSError: - # print("broken", int(time.time() - start), file=out) - # break - # else: print("not yet broken", int(time.time() - start), file=out); sleep(1) return HTTPChannel.__init__(self, server, sock, addr, adj, map) def handle_write(self): self.count_writes += 1 return HTTPChannel.handle_write(self) - # def received(self, data): - # # import pdb; pdb.set_trace() - # res = HTTPChannel.received(self, data) - # if data: - # # Fake app returning data fast - # # self.total_outbufs_len = 1 - # # Happens if send can't send all the data - # #import pdb; pdb.set_trace() - # #self.write_soon(b"1"*11025) - # #assert self.total_outbufs_len - # # self.request.completed = True - # # self.requests.append(DummyParser()) - # pass - # return res - # def send(self, data, do_close=True): - # # fake EWOULDBLOCK where socket buffers are filled up. but how? - # # return 0 - # res = HTTPChannel.send(self, data, do_close) - # if res < len(data) and not self.count_close: - # self.count_wouldblock += 1 - # # import pdb; pdb.set_trace() - # return res - def handle_close(self): # import pdb; pdb.set_trace() self.count_close += 1 @@ -423,28 +375,35 @@ def server_run(count=1): count=count ) + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0) + sockets[0].bind(("127.0.0.1", 8000)) + sockets[0].listen() + inst = self._makeWithSockets(_start=False, sockets=sockets) inst.channel_class = ShutdownChannel inst.task_dispatcher = DummyTaskDispatcher() + # This will make getpeername fail fast with EINVAL OSError + client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0)) client.connect(("127.0.0.1", 8000)) - for i in range(0, 1): - client.send(b"1") - inst.handle_accept() - - self.assertEqual(len(inst._map.values()), 3) - channel = list(iter(inst._map.values()))[-1] - self.assertEqual(channel.__class__, ShutdownChannel) - # self.assertEqual(channel.socket.getpeername(), "") - self.assertRaises(Exception, channel.socket.getpeername) - self.assertFalse(channel.connected, "race condition means our socket is marked not connected") - - server_run(5) - channel.service() # Our error request sets close_after_flushed - server_run(5) - - # self.assertEqual(channel.count_wouldblock, 1, "we need data left to send to be in a loop") - self.assertEqual(channel.count_writes > 1, "ensure we aren't in a loop trying to write but can't") - self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") + client.send(b"1") # Send our fake request before we accept and close the connection + inst.handle_accept() # ShutdownServer will close the connection after access but before getpeername + self.assertRaises(OSError, sockets[0].getpeername) + self.assertEqual(len(inst._map.values()), 2, "3 means we didn't get an automatic close") + + # To reproduce previous looping behaviour uncomment + # channel = list(iter(inst._map.values()))[3] + # self.assertEqual(channel.__class__, ShutdownChannel) + # self.assertFalse(channel.connected, "race condition means our socket is marked not connected") + + # server_run(5) # Read the request + # self.assertTrue(channel.request.error, "Error will cause a close") + # channel.service() + # self.assertTrue(channel.close_after_flushed, "This prevents reads and loops trying to write but can't") + # server_run(5) # Our loop + # # self.assertEqual(channel.count_wouldblock, 1, "we need data left to send to be in a loop") + # self.assertEqual(channel.count_writes > 1, "ensure we aren't in a loop trying to write but can't") + # self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") if hasattr(socket, "AF_UNIX"): From f078e459270aea172d5aeb92bb769065be04711d Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Tue, 19 Sep 2023 13:58:44 +0700 Subject: [PATCH 16/20] fix comments --- tests/test_server.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index 765de1b1..1158ffb8 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -387,7 +387,7 @@ def server_run(count=1): client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0)) client.connect(("127.0.0.1", 8000)) client.send(b"1") # Send our fake request before we accept and close the connection - inst.handle_accept() # ShutdownServer will close the connection after access but before getpeername + inst.handle_accept() # ShutdownServer will close the connection after acceot but before getpeername self.assertRaises(OSError, sockets[0].getpeername) self.assertEqual(len(inst._map.values()), 2, "3 means we didn't get an automatic close") @@ -398,11 +398,13 @@ def server_run(count=1): # server_run(5) # Read the request # self.assertTrue(channel.request.error, "Error will cause a close") + # # channel_request_lookahead > 0 would avoid this bug + # self.assertTrue(len(channel.requests) <= channel.adj.channel_request_lookahead, "channel_request_lookahead == 0 means we don't read the disconnect") + # # simulate thread processing the request # channel.service() - # self.assertTrue(channel.close_after_flushed, "This prevents reads and loops trying to write but can't") + # self.assertTrue(channel.close_after_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)") # server_run(5) # Our loop - # # self.assertEqual(channel.count_wouldblock, 1, "we need data left to send to be in a loop") - # self.assertEqual(channel.count_writes > 1, "ensure we aren't in a loop trying to write but can't") + # self.assertEqual(channel.count_writes > 1, "We're supposed to be in a loop trying to write but can't") # self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") From ae1bf12bfbd4b5df8d0f5ff38e1facb9da32f3e3 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Wed, 20 Sep 2023 11:12:37 +0700 Subject: [PATCH 17/20] also fix maintenance not cleaning up broken channel --- src/waitress/channel.py | 5 +++-- tests/test_server.py | 34 ++++++++++++++++++++-------------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/waitress/channel.py b/src/waitress/channel.py index 1cdf342f..a37c3a7f 100644 --- a/src/waitress/channel.py +++ b/src/waitress/channel.py @@ -95,8 +95,9 @@ def handle_write(self): # Precondition: there's data in the out buffer to be sent, or # there's a pending will_close request - if not self.connected and not self.close_when_flushed: - # we dont want to close the channel twice + if not self.connected and not (self.will_close or self.close_when_flushed): + # we dont want to close the channel twice. + # but we need let the channel close if it's marked to close return # try to flush any pending output diff --git a/tests/test_server.py b/tests/test_server.py index 1158ffb8..115c605e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,8 +7,6 @@ import time import unittest -from tests.test_channel import DummyParser - dummy_app = object() @@ -321,10 +319,8 @@ def test_create_with_one_socket_handle_accept_noerror(self): def test_quick_shutdown(self): """ Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup. """ - class DummyParser: error = True # We are simulating a header parsing error - version = 1 data = None completed = True @@ -348,12 +344,15 @@ def close(self): from waitress.channel import HTTPChannel client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + channel = None class ShutdownChannel(HTTPChannel): parser_class = DummyParser def __init__(self, server, sock, addr, adj, map=None): self.count_writes = self.count_close = self.count_wouldblock = 0 + nonlocal channel + channel = self client.close() # simulate race condition where close happens between accept adn getpeername return HTTPChannel.__init__(self, server, sock, addr, adj, map) @@ -389,23 +388,30 @@ def server_run(count=1): client.send(b"1") # Send our fake request before we accept and close the connection inst.handle_accept() # ShutdownServer will close the connection after acceot but before getpeername self.assertRaises(OSError, sockets[0].getpeername) - self.assertEqual(len(inst._map.values()), 2, "3 means we didn't get an automatic close") + self.assertFalse(channel.connected, "race condition means our socket is marked not connected") + self.assertNotIn(channel, inst._map.values(), "we should get an automatic close") - # To reproduce previous looping behaviour uncomment - # channel = list(iter(inst._map.values()))[3] - # self.assertEqual(channel.__class__, ShutdownChannel) - # self.assertFalse(channel.connected, "race condition means our socket is marked not connected") + # UNCOMMENT: To reproduce previous 100% CPU looping behaviour + # self.assertIn(channel, inst._map.values(), "broken request still active to get this bug") - # server_run(5) # Read the request - # self.assertTrue(channel.request.error, "Error will cause a close") + # server_run(1) # Read the request + # self.assertTrue(channel.requests[0].error, "for this bug we need the request to have a parsing error") + # server_run(5) + # self.assertIn(channel, inst._map.values(), "our rchannel doesn't get read and closed") # # channel_request_lookahead > 0 would avoid this bug - # self.assertTrue(len(channel.requests) <= channel.adj.channel_request_lookahead, "channel_request_lookahead == 0 means we don't read the disconnect") + # self.assertTrue(len(channel.requests) > channel.adj.channel_request_lookahead, "channel_request_lookahead == 0 means we don't read the disconnect") # # simulate thread processing the request # channel.service() - # self.assertTrue(channel.close_after_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)") + # self.assertTrue(channel.close_when_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)") # server_run(5) # Our loop - # self.assertEqual(channel.count_writes > 1, "We're supposed to be in a loop trying to write but can't") + # self.assertEqual(channel.count_writes, 5, "We're supposed to be in a loop trying to write but can't") # self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") + # # But shouldn't maintenance clear this up? + # channel.last_activity = 0 + # inst.maintenance(1000) + # self.assertEqual(channel.will_close, 1, "maintenance will try to close it") + # server_run(5) # Our loop + # self.assertEqual(channel.count_writes, 10, "But we still get our loop") if hasattr(socket, "AF_UNIX"): From fcb35a40c825b0913045c345ee356e41e755bd8a Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 5 Feb 2024 10:51:34 +0700 Subject: [PATCH 18/20] change fix to mark when closed and always end loop --- src/waitress/channel.py | 4 +++- tests/test_channel.py | 10 ++++++++++ tests/test_server.py | 12 ++++++------ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/waitress/channel.py b/src/waitress/channel.py index a37c3a7f..976bff8d 100644 --- a/src/waitress/channel.py +++ b/src/waitress/channel.py @@ -45,6 +45,7 @@ class HTTPChannel(wasyncore.dispatcher): last_activity = 0 # Time of last activity will_close = False # set to True to close the socket. close_when_flushed = False # set to True to close the socket when flushed + closed = False # set to True when closed not just due to being disconnected at the start sent_continue = False # used as a latch after sending 100 continue total_outbufs_len = 0 # total bytes ready to send current_outbuf_count = 0 # total bytes written to current outbuf @@ -95,7 +96,7 @@ def handle_write(self): # Precondition: there's data in the out buffer to be sent, or # there's a pending will_close request - if not self.connected and not (self.will_close or self.close_when_flushed): + if self.closed: # we dont want to close the channel twice. # but we need let the channel close if it's marked to close return @@ -316,6 +317,7 @@ def handle_close(self): self.total_outbufs_len = 0 self.connected = False self.outbuf_lock.notify() + self.closed = True wasyncore.dispatcher.close(self) def add_channel(self, map=None): diff --git a/tests/test_channel.py b/tests/test_channel.py index a501123b..86dacfbc 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -67,8 +67,18 @@ def test_writable_nothing_in_outbuf_will_close(self): def test_handle_write_not_connected(self): inst, sock, map = self._makeOneWithMap() inst.connected = False + # TODO: handle_write never returns anything anyway self.assertFalse(inst.handle_write()) + def test_handle_write_not_connected_but_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.connected = False + inst.will_close = True + # https://github.com/Pylons/waitress/issues/418 + # Ensure we actually handle_close even if not connected + self.assertFalse(inst.handle_write()) + self.assertEqual(len(map), 0) + def test_handle_write_with_requests(self): inst, sock, map = self._makeOneWithMap() inst.requests = True diff --git a/tests/test_server.py b/tests/test_server.py index 115c605e..d0f3b0bb 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -319,7 +319,7 @@ def test_create_with_one_socket_handle_accept_noerror(self): def test_quick_shutdown(self): """ Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup. """ - class DummyParser: + class ErrorRequest: error = True # We are simulating a header parsing error version = 1 data = None @@ -347,7 +347,7 @@ def close(self): channel = None class ShutdownChannel(HTTPChannel): - parser_class = DummyParser + parser_class = ErrorRequest def __init__(self, server, sock, addr, adj, map=None): self.count_writes = self.count_close = self.count_wouldblock = 0 @@ -365,7 +365,7 @@ def handle_close(self): self.count_close += 1 return HTTPChannel.handle_close(self) - def server_run(count=1): + def server_run_for_count(count=1): # Modified server run to prevent infinite loop inst.asyncore.loop( timeout=inst.adj.asyncore_loop_timeout, @@ -396,21 +396,21 @@ def server_run(count=1): # server_run(1) # Read the request # self.assertTrue(channel.requests[0].error, "for this bug we need the request to have a parsing error") - # server_run(5) + # server_run_for_count(5) # self.assertIn(channel, inst._map.values(), "our rchannel doesn't get read and closed") # # channel_request_lookahead > 0 would avoid this bug # self.assertTrue(len(channel.requests) > channel.adj.channel_request_lookahead, "channel_request_lookahead == 0 means we don't read the disconnect") # # simulate thread processing the request # channel.service() # self.assertTrue(channel.close_when_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)") - # server_run(5) # Our loop + # server_run_for_count(5) # Our loop # self.assertEqual(channel.count_writes, 5, "We're supposed to be in a loop trying to write but can't") # self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") # # But shouldn't maintenance clear this up? # channel.last_activity = 0 # inst.maintenance(1000) # self.assertEqual(channel.will_close, 1, "maintenance will try to close it") - # server_run(5) # Our loop + # server_run_for_count(5) # Our loop # self.assertEqual(channel.count_writes, 10, "But we still get our loop") From cb1b196d818c78ce0aa007fa94290523282da9d3 Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 5 Feb 2024 10:53:19 +0700 Subject: [PATCH 19/20] add git ignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 146736ff..116bc786 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ build/ coverage.xml docs/_themes docs/_build +.venv/ From b33848b5c0e96e286d2df1554b03d3172e622a8c Mon Sep 17 00:00:00 2001 From: Dylan Jay Date: Mon, 5 Feb 2024 13:32:13 +0700 Subject: [PATCH 20/20] add extra tests to show the loop behaviour --- tests/test_server.py | 212 +++++++++++++++++++++++++++---------------- 1 file changed, 133 insertions(+), 79 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index d0f3b0bb..fac986f0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -6,6 +6,7 @@ from time import sleep import time import unittest +from waitress.channel import HTTPChannel dummy_app = object() @@ -316,70 +317,74 @@ def test_create_with_one_socket_handle_accept_noerror(self): self.assertEqual(innersock.opts, [("level", "optname", "value")]) self.assertEqual(L, [(inst, innersock, None, inst.adj)]) - def test_quick_shutdown(self): + def test_error_request_quick_shutdown(self): + """ Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup. + """ + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0) + sockets[0].bind(("127.0.0.1", 8000)) + sockets[0].listen() + inst = self._makeWithSockets(_start=False, sockets=sockets) + channels = [] + inst.channel_class = make_quick_shutdown_channel(client, channels) + inst.task_dispatcher = DummyTaskDispatcher() + + # This will make getpeername fail fast with EINVAL OSError + client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0)) + client.connect(("127.0.0.1", 8000)) + client.send(b"1") # Send our fake request before we accept and close the connection + inst.handle_accept() # ShutdownServer will close the connection after acceot but before getpeername + self.assertRaises(OSError, sockets[0].getpeername) + self.assertFalse(channels[0].connected, "race condition means our socket is marked not connected") + self.assertNotIn(channels[0], inst._map.values(), "we should get an automatic close") + + def test_error_request_no_loop(self): + """ Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup. + """ + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0) + sockets[0].bind(("127.0.0.1", 8000)) + sockets[0].listen() + inst = self._makeWithSockets(_start=False, sockets=sockets) + channels = [] + inst.channel_class = make_quick_shutdown_channel(client, channels, shutdown=False) + inst.task_dispatcher = DummyTaskDispatcher() + + # This will make getpeername fail fast with EINVAL OSError + client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0)) + client.connect(("127.0.0.1", 8000)) + client.send(b"1") # Send our fake request before we accept and close the connection + inst.handle_accept() # ShutdownServer will close the connection after acceot but before getpeername + self.assertRaises(OSError, sockets[0].getpeername) + self.assertEquals(len(channels), 1) + channels[0].connected = False # This used to create a 100% CPU loop + + server_run_for_count(inst, 1) # Read the request + self.assertTrue(channels[0].requests[0].error, "for this bug we need the request to have a parsing error") + server_run_for_count(inst, 5) + # simulate thread processing the request + channels[0].service() + self.assertTrue(channels[0].close_when_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)") + server_run_for_count(inst, 5) # Our loop + self.assertNotIn(channels[0], inst._map.values(), "broken request didn't close the channel") + self.assertEqual(channels[0].count_close, 1, "but also this connection never gets closed") + self.assertLess(channels[0].count_writes, 5, "We're supposed to be in a loop trying to write but can't") + + def test_error_request_maintainace_cleanup(self): """ Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup. """ - class ErrorRequest: - error = True # We are simulating a header parsing error - version = 1 - data = None - completed = True - empty = False - headers_finished = True - expect_continue = False - retval = None - connection_close = False - - def __init__(self, adj): - pass - - def received(self, data): - self.data = data - if self.retval is not None: - return self.retval - return len(data) - - def close(self): - pass - - from waitress.channel import HTTPChannel client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - channel = None - - class ShutdownChannel(HTTPChannel): - parser_class = ErrorRequest - - def __init__(self, server, sock, addr, adj, map=None): - self.count_writes = self.count_close = self.count_wouldblock = 0 - nonlocal channel - channel = self - client.close() # simulate race condition where close happens between accept adn getpeername - return HTTPChannel.__init__(self, server, sock, addr, adj, map) - - def handle_write(self): - self.count_writes += 1 - return HTTPChannel.handle_write(self) - - def handle_close(self): - # import pdb; pdb.set_trace() - self.count_close += 1 - return HTTPChannel.handle_close(self) - - def server_run_for_count(count=1): - # Modified server run to prevent infinite loop - inst.asyncore.loop( - timeout=inst.adj.asyncore_loop_timeout, - map=inst._map, - use_poll=inst.adj.asyncore_use_poll, - count=count - ) sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0) sockets[0].bind(("127.0.0.1", 8000)) sockets[0].listen() inst = self._makeWithSockets(_start=False, sockets=sockets) - inst.channel_class = ShutdownChannel + channels = [] + inst.channel_class = make_quick_shutdown_channel(client, channels, shutdown=False) inst.task_dispatcher = DummyTaskDispatcher() # This will make getpeername fail fast with EINVAL OSError @@ -388,30 +393,21 @@ def server_run_for_count(count=1): client.send(b"1") # Send our fake request before we accept and close the connection inst.handle_accept() # ShutdownServer will close the connection after acceot but before getpeername self.assertRaises(OSError, sockets[0].getpeername) - self.assertFalse(channel.connected, "race condition means our socket is marked not connected") - self.assertNotIn(channel, inst._map.values(), "we should get an automatic close") - - # UNCOMMENT: To reproduce previous 100% CPU looping behaviour - # self.assertIn(channel, inst._map.values(), "broken request still active to get this bug") - - # server_run(1) # Read the request - # self.assertTrue(channel.requests[0].error, "for this bug we need the request to have a parsing error") - # server_run_for_count(5) - # self.assertIn(channel, inst._map.values(), "our rchannel doesn't get read and closed") - # # channel_request_lookahead > 0 would avoid this bug - # self.assertTrue(len(channel.requests) > channel.adj.channel_request_lookahead, "channel_request_lookahead == 0 means we don't read the disconnect") - # # simulate thread processing the request - # channel.service() - # self.assertTrue(channel.close_when_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)") - # server_run_for_count(5) # Our loop - # self.assertEqual(channel.count_writes, 5, "We're supposed to be in a loop trying to write but can't") - # self.assertEqual(channel.count_close, 0, "but also this connection never gets closed") - # # But shouldn't maintenance clear this up? - # channel.last_activity = 0 - # inst.maintenance(1000) - # self.assertEqual(channel.will_close, 1, "maintenance will try to close it") - # server_run_for_count(5) # Our loop - # self.assertEqual(channel.count_writes, 10, "But we still get our loop") + self.assertNotEqual(channels, []) + channels[0].connected = False ## race condition means our socket is marked not connected + + server_run_for_count(inst, 1) # Read the request + # self.assertTrue(channels[0].requests[0].error, "for this bug we need the request to have a parsing error") + server_run_for_count(inst, 5) + channels[0].service() + self.assertTrue(channels[0].close_when_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)") + server_run_for_count(inst, 5) # Our loop + channels[0].last_activity = 0 + inst.maintenance(1000) + self.assertEqual(channels[0].will_close, 1, "maintenance will try to close it") + self.assertNotIn(channels[0], inst._map.values(), "broken request didn't close the channel") + server_run_for_count(inst, 5) # Our loop + self.assertNotEqual(channels[0].count_writes, 10, "But we still get our loop") if hasattr(socket, "AF_UNIX"): @@ -618,3 +614,61 @@ def __init__(self): def warning(self, msg, **kw): self.logged.append(msg) + + +class ErrorRequest: + error = True # We are simulating a header parsing error + version = 1 + data = None + completed = True + empty = False + headers_finished = True + expect_continue = False + retval = None + connection_close = False + + def __init__(self, adj): + pass + + def received(self, data): + self.data = data + if self.retval is not None: + return self.retval + return len(data) + + def close(self): + pass + + +def make_quick_shutdown_channel(client, channels, shutdown=True): + class ShutdownChannel(HTTPChannel): + parser_class = ErrorRequest + + def __init__(self, server, sock, addr, adj, map=None): + self.count_writes = self.count_close = self.count_wouldblock = 0 + if shutdown: + client.close() + channels.append(self) + return HTTPChannel.__init__(self, server, sock, addr, adj, map) + + def handle_write(self): + self.count_writes += 1 + return HTTPChannel.handle_write(self) + + def handle_close(self): + # import pdb; pdb.set_trace() + self.count_close += 1 + return HTTPChannel.handle_close(self) + + return ShutdownChannel + + +def server_run_for_count(inst, count=1): + # Modified server run to prevent infinite loop + inst.asyncore.loop( + timeout=inst.adj.asyncore_loop_timeout, + map=inst._map, + use_poll=inst.adj.asyncore_use_poll, + count=count + ) +