Skip to content

Commit

Permalink
Limit the size of requests received from HTTP clients (#220)
Browse files Browse the repository at this point in the history
Applies an arbitrary limit of 512K to incoming request bodies.
  • Loading branch information
richvdh authored Apr 19, 2021
1 parent 7a3d6f8 commit ff1e98e
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 18 deletions.
1 change: 1 addition & 0 deletions changelog.d/220.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Limit the size of requests received from HTTP clients.
23 changes: 22 additions & 1 deletion sygnal/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,24 @@ def render_GET(self, request):
return b""


class SizeLimitingRequest(server.Request):
# Arbitrarily limited to 512 KiB.
MAX_REQUEST_SIZE = 512 * 1024

def handleContentChunk(self, data):
# we should have a content by now
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self.MAX_REQUEST_SIZE:
logger.info(
"Aborting connection from %s because the request exceeds maximum size",
self.client.host,
)
self.transport.abortConnection()
return

return super().handleContentChunk(data)


class SygnalLoggedSite(server.Site):
"""
A subclass of Site to perform access logging in a way that makes sense for
Expand Down Expand Up @@ -354,5 +372,8 @@ def __init__(self, sygnal):
)

self.site = SygnalLoggedSite(
root, reactor=sygnal.reactor, log_formatter=log_formatter
root,
reactor=sygnal.reactor,
log_formatter=log_formatter,
requestFactory=SizeLimitingRequest,
)
15 changes: 5 additions & 10 deletions sygnal/sygnal.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ async def _make_pushkin(self, app_name, app_config):
clarse = getattr(pushkin_module, to_construct)
return await clarse.create(app_name, self, app_config)

async def _make_pushkins_then_start(self, port, bind_addresses, pushgateway_api):
async def make_pushkins_then_start(self):
for app_id, app_cfg in self.config["apps"].items():
try:
self.pushkins[app_id] = await self._make_pushkin(app_id, app_cfg)
Expand All @@ -215,26 +215,21 @@ async def _make_pushkins_then_start(self, port, bind_addresses, pushgateway_api)

logger.info("Configured with app IDs: %r", self.pushkins.keys())

for interface in bind_addresses:
pushgateway_api = PushGatewayApiServer(self)
port = int(self.config["http"]["port"])
for interface in self.config["http"]["bind_addresses"]:
logger.info("Starting listening on %s port %d", interface, port)
self.reactor.listenTCP(port, pushgateway_api.site, interface=interface)

def run(self):
"""
Attempt to run Sygnal and then exit the application.
"""
port = int(self.config["http"]["port"])
bind_addresses = self.config["http"]["bind_addresses"]
pushgateway_api = PushGatewayApiServer(self)

@defer.inlineCallbacks
def start():
try:
yield ensureDeferred(
self._make_pushkins_then_start(
port, bind_addresses, pushgateway_api
)
)
yield ensureDeferred(self.make_pushkins_then_start())
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_pushgateway_api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet.address import IPv6Address
from twisted.internet.testing import StringTransport

from sygnal.exceptions import (
NotificationDispatchException,
TemporaryNotificationDispatchException,
Expand Down Expand Up @@ -183,3 +186,49 @@ def test_remote_errors_give_502(self):
),
502,
)

def test_overlong_requests_are_rejected(self):
# as a control case, first send a regular request.

# connect the site to a fake transport.
transport = StringTransport()
protocol = self.site.buildProtocol(IPv6Address("TCP", "::1", "2345"))
protocol.makeConnection(transport)

protocol.dataReceived(
b"POST / HTTP/1.1\r\n"
b"Connection: close\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"0\r\n"
b"\r\n"
)

# we should get a 404
self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")

# now send an oversized request
transport = StringTransport()
protocol = self.site.buildProtocol(IPv6Address("TCP", "::1", "2345"))
protocol.makeConnection(transport)

protocol.dataReceived(
b"POST / HTTP/1.1\r\n"
b"Connection: close\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)

# we deliberately send all the data in one big chunk, to ensure that
# twisted isn't buffering the data in the chunked transfer decoder.
# we start with the chunk size, in hex. (We won't actually send this much)
protocol.dataReceived(b"10000000\r\n")
sent = 0
while not transport.disconnected:
self.assertLess(sent, 0x10000000, "connection did not drop")
protocol.dataReceived(b"\0" * 1024)
sent += 1024

# default max upload size is 512K, so it should drop on the next buffer after
# that.
self.assertEqual(sent, 513 * 1024)
16 changes: 9 additions & 7 deletions tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from twisted.web.server import Request
from zope.interface.declarations import implementer

from sygnal.http import PushGatewayApiServer
from sygnal.sygnal import CONFIG_DEFAULTS, Sygnal, merge_left_with_defaults

REQ_PATH = b"/_matrix/push/v1/notify"
Expand Down Expand Up @@ -130,17 +129,20 @@ def setUp(self):
self.sygnal = Sygnal(config, reactor)
self.reactor = reactor
self.sygnal.database.start()
self.v1api = PushGatewayApiServer(self.sygnal)

start_deferred = ensureDeferred(
self.sygnal._make_pushkins_then_start(0, [], None)
)
start_deferred = ensureDeferred(self.sygnal.make_pushkins_then_start())

while not start_deferred.called:
# we need to advance until the pushkins have started up
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(lambda: start_deferred.called)

# sygnal should have started a single (fake) tcp listener
listeners = self.reactor.tcpServers
self.assertEqual(len(listeners), 1)
(port, site, _backlog, interface) = listeners[0]
self.site = site

def tearDown(self):
super().tearDown()
self.sygnal.database.close()
Expand Down Expand Up @@ -205,7 +207,7 @@ def _request(self, payload: Union[str, dict]) -> Union[dict, int]:
payload = json.dumps(payload)
content = BytesIO(payload.encode())

channel = FakeChannel(self.v1api.site, self.sygnal.reactor)
channel = FakeChannel(self.site, self.sygnal.reactor)
channel.process_request(b"POST", REQ_PATH, content)

while not channel.done:
Expand Down Expand Up @@ -245,7 +247,7 @@ def dump_if_needed(payload):

contents = [BytesIO(dump_if_needed(payload).encode()) for payload in payloads]

channels = [FakeChannel(self.v1api.site, self.sygnal.reactor) for _ in contents]
channels = [FakeChannel(self.site, self.sygnal.reactor) for _ in contents]

for channel, content in zip(channels, contents):
channel.process_request(b"POST", REQ_PATH, content)
Expand Down

0 comments on commit ff1e98e

Please sign in to comment.