Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure we don't loop trying to write to a channel thats not connected (fix 100% CPU) #419

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/waitress/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def __init__(self, server, sock, addr, adj, map=None):
self.outbuf_lock = threading.Condition()

wasyncore.dispatcher.__init__(self, sock, map=map)
if not self.connected:
# Sometimes can be closed quickly and getpeername fails.
self.handle_close()
Copy link
Author

@djay djay Sep 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@digitalresistor @d-maurer I'm still not sure about this fix. I think I read somewhere that Windows can sometimes fail on getpeername?
The other fix will still prevent the looping bug by letting it write and error out. This one will close it before it wastes the app's time if the connection really is closed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If getpeername() fails on Windows then self.connected would be set to False anyway, which would cause the bug. So trying to keep going after getpeername() has failed is not sustainable.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connected = False doesn't cause the bug alone. It also needs the request to be malformed, as that prevents reading, writing and maintenance from cleaning it up. So in some ways the real bug is in handle_write:

if not self.connected:
    return

The test I put in shows that, in the most likely scenario that makes this occur, it's trying to close the channel but is prevented from doing so by the above line.

The more I think about it, @d-maurer is correct that this should be changed to self.closed or something that explicitly prevents a close from happening twice. That's the safest minimal change.


# Don't let wasyncore.dispatcher throttle self.addr on us.
self.addr = addr
Expand All @@ -86,15 +89,15 @@ def writable(self):
# the channel (possibly by our server maintenance logic), run
# handle_write

return self.total_outbufs_len or self.will_close or self.close_when_flushed
return (self.total_outbufs_len or self.will_close or self.close_when_flushed)

def handle_write(self):
# Precondition: there's data in the out buffer to be sent, or
# there's a pending will_close request

if not self.connected:
# we dont want to close the channel twice

if not self.connected and not (self.will_close or self.close_when_flushed):
# we dont want to close the channel twice.
# but we need let the channel close if it's marked to close
return

# try to flush any pending output
Expand Down Expand Up @@ -150,7 +153,6 @@ def readable(self):
# 3. There are not too many tasks already queued
# 4. There's no data in the output buffer that needs to be sent
# before we potentially create a new task.

return not (
self.will_close
or self.close_when_flushed
Expand Down
11 changes: 8 additions & 3 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from errno import EINVAL
import io
import socket
import unittest

import pytest
Expand All @@ -11,10 +13,10 @@ def _makeOne(self, sock, addr, adj, map=None):
server = DummyServer()
return HTTPChannel(server, sock, addr, adj=adj, map=map)

def _makeOneWithMap(self, adj=None):
def _makeOneWithMap(self, adj=None, sock_shutdown=False):
if adj is None:
adj = DummyAdjustments()
sock = DummySock()
sock = DummySock(shutdown=sock_shutdown)
map = {}
inst = self._makeOne(sock, "127.0.0.1", adj, map=map)
inst.outbuf_lock = DummyLock()
Expand Down Expand Up @@ -906,8 +908,9 @@ class DummySock:
blocking = False
closed = False

def __init__(self):
def __init__(self, shutdown=False):
self.sent = b""
self.shutdown = shutdown

def setblocking(self, *arg):
self.blocking = True
Expand All @@ -916,6 +919,8 @@ def fileno(self):
return 100

def getpeername(self):
if self.shutdown:
raise OSError(EINVAL)
return "127.0.0.1"

def getsockopt(self, level, option):
Expand Down
102 changes: 102 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import errno
import select
import socket
import struct
from threading import Event
from time import sleep
import time
import unittest

dummy_app = object()
Expand Down Expand Up @@ -311,6 +316,103 @@ def test_create_with_one_socket_handle_accept_noerror(self):
self.assertEqual(innersock.opts, [("level", "optname", "value")])
self.assertEqual(L, [(inst, innersock, None, inst.adj)])

def test_quick_shutdown(self):
""" Issue found in production that led to 100% useage because getpeername failed after accept but before channel setup.
"""
class DummyParser:
error = True # We are simulating a header parsing error
version = 1
data = None
completed = True
empty = False
headers_finished = True
expect_continue = False
retval = None
connection_close = False

def __init__(self, adj):
pass

def received(self, data):
self.data = data
if self.retval is not None:
return self.retval
return len(data)

def close(self):
pass

from waitress.channel import HTTPChannel
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
channel = None

class ShutdownChannel(HTTPChannel):
parser_class = DummyParser

def __init__(self, server, sock, addr, adj, map=None):
self.count_writes = self.count_close = self.count_wouldblock = 0
nonlocal channel
channel = self
client.close() # simulate race condition where close happens between accept adn getpeername
return HTTPChannel.__init__(self, server, sock, addr, adj, map)

def handle_write(self):
self.count_writes += 1
return HTTPChannel.handle_write(self)

def handle_close(self):
# import pdb; pdb.set_trace()
self.count_close += 1
return HTTPChannel.handle_close(self)

def server_run(count=1):
# Modified server run to prevent infinite loop
inst.asyncore.loop(
timeout=inst.adj.asyncore_loop_timeout,
map=inst._map,
use_poll=inst.adj.asyncore_use_poll,
count=count
)

sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
sockets[0].setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)
sockets[0].bind(("127.0.0.1", 8000))
sockets[0].listen()
inst = self._makeWithSockets(_start=False, sockets=sockets)
inst.channel_class = ShutdownChannel
inst.task_dispatcher = DummyTaskDispatcher()

# This will make getpeername fail fast with EINVAL OSError
client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
client.connect(("127.0.0.1", 8000))
client.send(b"1") # Send our fake request before we accept and close the connection
inst.handle_accept() # ShutdownServer will close the connection after acceot but before getpeername
self.assertRaises(OSError, sockets[0].getpeername)
self.assertFalse(channel.connected, "race condition means our socket is marked not connected")
self.assertNotIn(channel, inst._map.values(), "we should get an automatic close")

# UNCOMMENT: To reproduce previous 100% CPU looping behaviour
# self.assertIn(channel, inst._map.values(), "broken request still active to get this bug")

# server_run(1) # Read the request
# self.assertTrue(channel.requests[0].error, "for this bug we need the request to have a parsing error")
# server_run(5)
# self.assertIn(channel, inst._map.values(), "our rchannel doesn't get read and closed")
# # channel_request_lookahead > 0 would avoid this bug
# self.assertTrue(len(channel.requests) > channel.adj.channel_request_lookahead, "channel_request_lookahead == 0 means we don't read the disconnect")
# # simulate thread processing the request
# channel.service()
# self.assertTrue(channel.close_when_flushed, "This prevents reads (which lead to close) and loops on handle_write (with 100% CPU)")
# server_run(5) # Our loop
# self.assertEqual(channel.count_writes, 5, "We're supposed to be in a loop trying to write but can't")
# self.assertEqual(channel.count_close, 0, "but also this connection never gets closed")
# # But shouldn't maintenance clear this up?
# channel.last_activity = 0
# inst.maintenance(1000)
# self.assertEqual(channel.will_close, 1, "maintenance will try to close it")
# server_run(5) # Our loop
# self.assertEqual(channel.count_writes, 10, "But we still get our loop")


if hasattr(socket, "AF_UNIX"):

Expand Down