Skip to content

Commit

Permalink
Merge pull request #2391 from bdarnell/websocket-decompress-limit
Browse files Browse the repository at this point in the history
websocket: Limit post-decompression size of received messages
  • Loading branch information
bdarnell committed May 19, 2018
2 parents 697bbe9 + 7bd3ef3 commit 729f1d9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 6 deletions.
28 changes: 28 additions & 0 deletions tornado/test/websocket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,22 @@ class CompressionTestMixin(object):

def get_app(self):
self.close_future = Future()

class LimitedHandler(TestWebSocketHandler):
@property
def max_message_size(self):
return 1024

def on_message(self, message):
self.write_message(str(len(message)))

return Application([
('/echo', EchoHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
('/limited', LimitedHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
])

def get_server_compression_options(self):
Expand All @@ -582,6 +594,22 @@ def test_message_sizes(self):
ws.protocol._wire_bytes_out)
yield self.close(ws)

@gen_test
def test_size_limit(self):
ws = yield self.ws_connect(
'/limited',
compression_options=self.get_client_compression_options())
# Small messages pass through.
ws.write_message('a' * 128)
response = yield ws.read_message()
self.assertEqual(response, '128')
# This message is too big after decompression, but it compresses
# down to a size that will pass the initial checks.
ws.write_message('a' * 2048)
response = yield ws.read_message()
self.assertIsNone(response)
yield self.close(ws)


class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
Expand Down
28 changes: 22 additions & 6 deletions tornado/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
else:
from urlparse import urlparse # py3

_default_max_message_size = 10 * 1024 * 1024


class WebSocketError(Exception):
pass
Expand All @@ -57,6 +59,10 @@ class WebSocketClosedError(WebSocketError):
pass


class _DecompressTooLargeError(Exception):
pass


class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
Expand Down Expand Up @@ -225,7 +231,7 @@ def max_message_size(self):
Default is 10MiB.
"""
return self.settings.get('websocket_max_message_size', None)
return self.settings.get('websocket_max_message_size', _default_max_message_size)

def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
Expand Down Expand Up @@ -596,7 +602,8 @@ def compress(self, data):


class _PerMessageDeflateDecompressor(object):
def __init__(self, persistent, max_wbits, compression_options=None):
def __init__(self, persistent, max_wbits, max_message_size, compression_options=None):
self._max_message_size = max_message_size
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
if not (8 <= max_wbits <= zlib.MAX_WBITS):
Expand All @@ -613,7 +620,10 @@ def _create_decompressor(self):

def decompress(self, data):
decompressor = self._decompressor or self._create_decompressor()
return decompressor.decompress(data + b'\x00\x00\xff\xff')
result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size)
if decompressor.unconsumed_tail:
raise _DecompressTooLargeError()
return result


class WebSocketProtocol13(WebSocketProtocol):
Expand Down Expand Up @@ -801,6 +811,7 @@ def _create_compressors(self, side, agreed_parameters, compression_options=None)
self._compressor = _PerMessageDeflateCompressor(
**self._get_compressor_options(side, agreed_parameters, compression_options))
self._decompressor = _PerMessageDeflateDecompressor(
max_message_size=self.handler.max_message_size,
**self._get_compressor_options(other_side, agreed_parameters, compression_options))

def _write_frame(self, fin, opcode, data, flags=0):
Expand Down Expand Up @@ -920,7 +931,7 @@ def _receive_frame(self):
new_len = payloadlen
if self._fragmented_message_buffer is not None:
new_len += len(self._fragmented_message_buffer)
if new_len > (self.handler.max_message_size or 10 * 1024 * 1024):
if new_len > self.handler.max_message_size:
self.close(1009, "message too big")
self._abort()
return
Expand Down Expand Up @@ -971,7 +982,12 @@ def _handle_message(self, opcode, data):
return

if self._frame_compressed:
data = self._decompressor.decompress(data)
try:
data = self._decompressor.decompress(data)
except _DecompressTooLargeError:
self.close(1009, "message too big after decompression")
self._abort()
return

if opcode == 0x1:
# UTF-8 data
Expand Down Expand Up @@ -1260,7 +1276,7 @@ def selected_subprotocol(self):
def websocket_connect(url, callback=None, connect_timeout=None,
on_message_callback=None, compression_options=None,
ping_interval=None, ping_timeout=None,
max_message_size=None, subprotocols=None):
max_message_size=_default_max_message_size, subprotocols=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
Expand Down

0 comments on commit 729f1d9

Please sign in to comment.