Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Security fix: defend against zip bombs. #407

Merged
merged 1 commit into from
May 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ Changelog

*In development*

.. note::

**Version 5.0 fixes a security issue introduced in version 4.0.**

websockets 4.0 was vulnerable to denial of service by memory exhaustion
because it didn't enforce ``max_size`` when decompressing compressed
messages.

.. warning::

**Version 5.0 adds a** ``user_info`` **field to the return value of**
Expand Down
2 changes: 1 addition & 1 deletion websockets/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Extension:
"""
name = ...

def decode(self, frame):
def decode(self, frame, *, max_size=None):
"""
Decode an incoming frame.

Expand Down
15 changes: 11 additions & 4 deletions websockets/extensions/permessage_deflate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ..exceptions import (
DuplicateParameter, InvalidParameterName, InvalidParameterValue,
NegotiationError
NegotiationError, PayloadTooBig
)
from ..framing import CTRL_OPCODES, OP_CONT

Expand Down Expand Up @@ -463,7 +463,7 @@ def __repr__(self):
self.local_max_window_bits),
]))

def decode(self, frame):
def decode(self, frame, *, max_size=None):
"""
Decode an incoming frame.

Expand Down Expand Up @@ -495,11 +495,18 @@ def decode(self, frame):
self.decoder = zlib.decompressobj(
wbits=-self.remote_max_window_bits)

# Uncompress compressed frames.
# Uncompress compressed frames. Protect against zip bombs by
# preventing zlib from decompressing more than max_length bytes
# (except when the limit is disabled with max_size = None).
data = frame.data
if frame.fin:
data += _EMPTY_UNCOMPRESSED_BLOCK
data = self.decoder.decompress(data)
max_length = 0 if max_size is None else max_size
data = self.decoder.decompress(data, max_length)
if self.decoder.unconsumed_tail:
raise PayloadTooBig(
"Uncompressed payload length exceeds size limit (? > {} bytes)"
.format(max_size))

# Allow garbage collection of the decoder if it won't be reused.
if frame.fin and self.remote_no_context_takeover:
Expand Down
14 changes: 13 additions & 1 deletion websockets/extensions/test_permessage_deflate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from ..exceptions import (
DuplicateParameter, InvalidParameterName, InvalidParameterValue,
NegotiationError
NegotiationError, PayloadTooBig
)
from ..framing import (
OP_BINARY, OP_CLOSE, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame,
Expand Down Expand Up @@ -835,3 +835,15 @@ def test_compress_settings(self):
rsv1=True,
data=b'\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00', # not compressed
))

# Frames aren't decoded beyond max_length.

def test_decompress_max_size(self):
frame = Frame(True, OP_TEXT, ('a' * 20).encode('utf-8'))

enc_frame = self.extension.encode(frame)

self.assertEqual(enc_frame.data, b'JL\xc4\x04\x00\x00')

with self.assertRaises(PayloadTooBig):
self.extension.decode(enc_frame, max_size=10)
4 changes: 2 additions & 2 deletions websockets/framing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def read(cls, reader, *, mask, max_size=None, extensions=None):
length, = struct.unpack('!Q', data)
if max_size is not None and length > max_size:
raise PayloadTooBig(
"Payload length exceeds limit: {} > {} bytes"
"Payload length exceeds size limit ({} > {} bytes)"
.format(length, max_size))
if mask:
mask_bits = yield from reader(4)
Expand All @@ -134,7 +134,7 @@ def read(cls, reader, *, mask, max_size=None, extensions=None):
if extensions is None:
extensions = []
for extension in reversed(extensions):
frame = extension.decode(frame)
frame = extension.decode(frame, max_size=max_size)

frame.check()

Expand Down
2 changes: 1 addition & 1 deletion websockets/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class NoOpExtension:
def __repr__(self):
return 'NoOpExtension()'

def decode(self, frame):
def decode(self, frame, *, max_size=None):
return frame

def encode(self, frame):
Expand Down
4 changes: 3 additions & 1 deletion websockets/test_framing.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def encode(frame):
return frame._replace(data=data)

# This extensions is symmetrical.
decode = encode
@staticmethod
def decode(frame, *, max_size=None):
return Rot13.encode(frame)

self.round_trip(
b'\x81\x05uryyb',
Expand Down