Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit the size of requests received from HTTP clients #220

Merged
merged 3 commits into from
Apr 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/220.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Limit the size of requests received from HTTP clients.
23 changes: 22 additions & 1 deletion sygnal/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,24 @@ def render_GET(self, request):
return b""


class SizeLimitingRequest(server.Request):
    """
    A Request which drops the connection once the body received so far
    would grow beyond MAX_REQUEST_SIZE bytes.
    """

    # Upper bound on the request body, in bytes. The value (512 KiB) is
    # an arbitrary choice.
    MAX_REQUEST_SIZE = 512 * 1024

    def handleContentChunk(self, data):
        """
        Pass a chunk of body data on to the parent class, unless accepting it
        would take the body over MAX_REQUEST_SIZE, in which case the
        connection is aborted and the chunk is discarded.
        """
        # The content buffer is expected to exist by the time body data
        # arrives (presumably it is created by gotLength() - see the message).
        assert self.content, "handleContentChunk() called before gotLength()"

        size_after_chunk = self.content.tell() + len(data)
        if size_after_chunk <= self.MAX_REQUEST_SIZE:
            # Still within the limit: normal processing.
            return super().handleContentChunk(data)

        logger.info(
            "Aborting connection from %s because the request exceeds maximum size",
            self.client.host,
        )
        self.transport.abortConnection()
        return None


class SygnalLoggedSite(server.Site):
"""
A subclass of Site to perform access logging in a way that makes sense for
Expand Down Expand Up @@ -354,5 +372,8 @@ def __init__(self, sygnal):
)

self.site = SygnalLoggedSite(
root, reactor=sygnal.reactor, log_formatter=log_formatter
root,
reactor=sygnal.reactor,
log_formatter=log_formatter,
requestFactory=SizeLimitingRequest,
)
15 changes: 5 additions & 10 deletions sygnal/sygnal.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ async def _make_pushkin(self, app_name, app_config):
clarse = getattr(pushkin_module, to_construct)
return await clarse.create(app_name, self, app_config)

async def _make_pushkins_then_start(self, port, bind_addresses, pushgateway_api):
async def make_pushkins_then_start(self):
for app_id, app_cfg in self.config["apps"].items():
try:
self.pushkins[app_id] = await self._make_pushkin(app_id, app_cfg)
Expand All @@ -215,26 +215,21 @@ async def _make_pushkins_then_start(self, port, bind_addresses, pushgateway_api)

logger.info("Configured with app IDs: %r", self.pushkins.keys())

for interface in bind_addresses:
pushgateway_api = PushGatewayApiServer(self)
port = int(self.config["http"]["port"])
for interface in self.config["http"]["bind_addresses"]:
logger.info("Starting listening on %s port %d", interface, port)
self.reactor.listenTCP(port, pushgateway_api.site, interface=interface)

def run(self):
"""
Attempt to run Sygnal and then exit the application.
"""
port = int(self.config["http"]["port"])
bind_addresses = self.config["http"]["bind_addresses"]
pushgateway_api = PushGatewayApiServer(self)

@defer.inlineCallbacks
def start():
try:
yield ensureDeferred(
self._make_pushkins_then_start(
port, bind_addresses, pushgateway_api
)
)
yield ensureDeferred(self.make_pushkins_then_start())
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_pushgateway_api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet.address import IPv6Address
from twisted.internet.testing import StringTransport

from sygnal.exceptions import (
NotificationDispatchException,
TemporaryNotificationDispatchException,
Expand Down Expand Up @@ -183,3 +186,49 @@ def test_remote_errors_give_502(self):
),
502,
)

def test_overlong_requests_are_rejected(self):
    """
    Check that a request whose body exceeds the size limit has its
    connection dropped, while an ordinary request is answered normally.
    """

    def connect():
        # Connect the site to a fresh fake transport and return both ends.
        # The port is given as an int, which is what the address type expects.
        fake_transport = StringTransport()
        proto = self.site.buildProtocol(IPv6Address("TCP", "::1", 2345))
        proto.makeConnection(fake_transport)
        return fake_transport, proto

    # as a control case, first send a regular request.
    transport, protocol = connect()

    protocol.dataReceived(
        b"POST / HTTP/1.1\r\n"
        b"Connection: close\r\n"
        b"Transfer-Encoding: chunked\r\n"
        b"\r\n"
        b"0\r\n"
        b"\r\n"
    )

    # we should get a 404
    self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")

    # now send an oversized request
    transport, protocol = connect()

    protocol.dataReceived(
        b"POST / HTTP/1.1\r\n"
        b"Connection: close\r\n"
        b"Transfer-Encoding: chunked\r\n"
        b"\r\n"
    )

    # we deliberately send all the data in one big chunk, to ensure that
    # twisted isn't buffering the data in the chunked transfer decoder.
    # we start with the chunk size, in hex. (We won't actually send this much)
    protocol.dataReceived(b"10000000\r\n")
    sent = 0
    while not transport.disconnected:
        # bail out rather than loop forever if the limit is never enforced
        self.assertLess(sent, 0x10000000, "connection did not drop")
        protocol.dataReceived(b"\0" * 1024)
        sent += 1024

    # default max upload size is 512K, so it should drop on the next buffer after
    # that.
    self.assertEqual(sent, 513 * 1024)
16 changes: 9 additions & 7 deletions tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from twisted.web.server import Request
from zope.interface.declarations import implementer

from sygnal.http import PushGatewayApiServer
from sygnal.sygnal import CONFIG_DEFAULTS, Sygnal, merge_left_with_defaults

REQ_PATH = b"/_matrix/push/v1/notify"
Expand Down Expand Up @@ -130,17 +129,20 @@ def setUp(self):
self.sygnal = Sygnal(config, reactor)
self.reactor = reactor
self.sygnal.database.start()
self.v1api = PushGatewayApiServer(self.sygnal)

start_deferred = ensureDeferred(
self.sygnal._make_pushkins_then_start(0, [], None)
)
start_deferred = ensureDeferred(self.sygnal.make_pushkins_then_start())

while not start_deferred.called:
# we need to advance until the pushkins have started up
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(lambda: start_deferred.called)

# sygnal should have started a single (fake) tcp listener
listeners = self.reactor.tcpServers
self.assertEqual(len(listeners), 1)
(port, site, _backlog, interface) = listeners[0]
self.site = site

def tearDown(self):
super().tearDown()
self.sygnal.database.close()
Expand Down Expand Up @@ -205,7 +207,7 @@ def _request(self, payload: Union[str, dict]) -> Union[dict, int]:
payload = json.dumps(payload)
content = BytesIO(payload.encode())

channel = FakeChannel(self.v1api.site, self.sygnal.reactor)
channel = FakeChannel(self.site, self.sygnal.reactor)
channel.process_request(b"POST", REQ_PATH, content)

while not channel.done:
Expand Down Expand Up @@ -245,7 +247,7 @@ def dump_if_needed(payload):

contents = [BytesIO(dump_if_needed(payload).encode()) for payload in payloads]

channels = [FakeChannel(self.v1api.site, self.sygnal.reactor) for _ in contents]
channels = [FakeChannel(self.site, self.sygnal.reactor) for _ in contents]

for channel, content in zip(channels, contents):
channel.process_request(b"POST", REQ_PATH, content)
Expand Down