Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

websocket: Limit post-decompression size of received messages #2391

Merged
merged 1 commit into from
May 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tornado/test/websocket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,22 @@ class CompressionTestMixin(object):

def get_app(self):
self.close_future = Future()

class LimitedHandler(TestWebSocketHandler):
@property
def max_message_size(self):
return 1024

def on_message(self, message):
self.write_message(str(len(message)))

return Application([
('/echo', EchoHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
('/limited', LimitedHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
])

def get_server_compression_options(self):
Expand All @@ -582,6 +594,22 @@ def test_message_sizes(self):
ws.protocol._wire_bytes_out)
yield self.close(ws)

@gen_test
def test_size_limit(self):
ws = yield self.ws_connect(
'/limited',
compression_options=self.get_client_compression_options())
# Small messages pass through.
ws.write_message('a' * 128)
response = yield ws.read_message()
self.assertEqual(response, '128')
# This message is too big after decompression, but it compresses
# down to a size that will pass the initial checks.
ws.write_message('a' * 2048)
response = yield ws.read_message()
self.assertIsNone(response)
yield self.close(ws)


class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
Expand Down
28 changes: 22 additions & 6 deletions tornado/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
else:
from urlparse import urlparse # py3

_default_max_message_size = 10 * 1024 * 1024


class WebSocketError(Exception):
pass
Expand All @@ -57,6 +59,10 @@ class WebSocketClosedError(WebSocketError):
pass


class _DecompressTooLargeError(Exception):
pass


class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
Expand Down Expand Up @@ -225,7 +231,7 @@ def max_message_size(self):
Default is 10MiB.
"""
return self.settings.get('websocket_max_message_size', None)
return self.settings.get('websocket_max_message_size', _default_max_message_size)

def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
Expand Down Expand Up @@ -596,7 +602,8 @@ def compress(self, data):


class _PerMessageDeflateDecompressor(object):
def __init__(self, persistent, max_wbits, compression_options=None):
def __init__(self, persistent, max_wbits, max_message_size, compression_options=None):
self._max_message_size = max_message_size
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
if not (8 <= max_wbits <= zlib.MAX_WBITS):
Expand All @@ -613,7 +620,10 @@ def _create_decompressor(self):

def decompress(self, data):
decompressor = self._decompressor or self._create_decompressor()
return decompressor.decompress(data + b'\x00\x00\xff\xff')
result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size)
if decompressor.unconsumed_tail:
raise _DecompressTooLargeError()
return result


class WebSocketProtocol13(WebSocketProtocol):
Expand Down Expand Up @@ -801,6 +811,7 @@ def _create_compressors(self, side, agreed_parameters, compression_options=None)
self._compressor = _PerMessageDeflateCompressor(
**self._get_compressor_options(side, agreed_parameters, compression_options))
self._decompressor = _PerMessageDeflateDecompressor(
max_message_size=self.handler.max_message_size,
**self._get_compressor_options(other_side, agreed_parameters, compression_options))

def _write_frame(self, fin, opcode, data, flags=0):
Expand Down Expand Up @@ -920,7 +931,7 @@ def _receive_frame(self):
new_len = payloadlen
if self._fragmented_message_buffer is not None:
new_len += len(self._fragmented_message_buffer)
if new_len > (self.handler.max_message_size or 10 * 1024 * 1024):
if new_len > self.handler.max_message_size:
self.close(1009, "message too big")
self._abort()
return
Expand Down Expand Up @@ -971,7 +982,12 @@ def _handle_message(self, opcode, data):
return

if self._frame_compressed:
data = self._decompressor.decompress(data)
try:
data = self._decompressor.decompress(data)
except _DecompressTooLargeError:
self.close(1009, "message too big after decompression")
self._abort()
return

if opcode == 0x1:
# UTF-8 data
Expand Down Expand Up @@ -1260,7 +1276,7 @@ def selected_subprotocol(self):
def websocket_connect(url, callback=None, connect_timeout=None,
on_message_callback=None, compression_options=None,
ping_interval=None, ping_timeout=None,
max_message_size=None, subprotocols=None):
max_message_size=_default_max_message_size, subprotocols=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
Expand Down