Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issues caused by websocket frame fragmentation #1962

Merged
merged 1 commit into from
Jun 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Changes

-

-
- Fix websocket issues caused by frame fragmentation. #1962

-

Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Joongi Kim
Josep Cugat
Julia Tsemusheva
Julien Duponchelle
Jungkook Park
Junjie Tao
Justas Trimailovas
Justin Turner Arthur
Expand Down
33 changes: 12 additions & 21 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,17 @@ def _feed_data(self, data):
WSMessage(WSMsgType.PONG, payload, ''), len(payload))

elif opcode not in (
WSMsgType.TEXT, WSMsgType.BINARY) and not self._opcode:
WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Unexpected opcode={!r}".format(opcode))
else:
# load text/binary

if not fin:
# got partial frame payload
if opcode != WSMsgType.CONTINUATION:
self._opcode = opcode

self._partial.append(payload)
else:
# previous frame was non finished
Expand All @@ -248,39 +248,37 @@ def _feed_data(self, data):

if opcode == WSMsgType.CONTINUATION:
opcode = self._opcode
self._opcode = None

self._partial.append(payload)
payload_merged = b''.join(self._partial) + payload
self._partial.clear()

if opcode == WSMsgType.TEXT:
try:
text = b''.join(self._partial).decode('utf-8')
text = payload_merged.decode('utf-8')
self.queue.feed_data(
WSMessage(WSMsgType.TEXT, text, ''), len(text))
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
else:
data = b''.join(self._partial)
self.queue.feed_data(
WSMessage(WSMsgType.BINARY, data, ''), len(data))

self._start_opcode = None
self._partial.clear()
WSMessage(WSMsgType.BINARY, payload_merged, ''),
len(payload_merged))

return False, b''

def parse_frame(self, buf, continuation=False, EMPTY=b''):
def parse_frame(self, buf):
"""Return the next frame from the socket."""
frames = []
if self._tail:
buf, self._tail = self._tail + buf, EMPTY
buf, self._tail = self._tail + buf, b''

start_pos = 0
buf_length = len(buf)

while True:

# read header
if self._state == WSParserState.READ_HEADER:
if buf_length - start_pos >= 2:
Expand Down Expand Up @@ -312,15 +310,6 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''):
WSCloseCode.PROTOCOL_ERROR,
'Received fragmented control frame')

continuation = not self._frame_fin
if (fin == 0 and
opcode == WSMsgType.CONTINUATION and
not continuation):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received new fragment frame with non-zero '
'opcode {!r}'.format(opcode))

has_mask = (second_byte >> 7) & 1
length = (second_byte) & 0x7f

Expand Down Expand Up @@ -409,6 +398,8 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''):
else:
break

self._tail = buf[start_pos:]

return frames


Expand Down
48 changes: 31 additions & 17 deletions tests/test_websocket_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,28 @@
_websocket_mask)


def build_frame(message, opcode, use_mask=False, noheader=False):
def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True):
"""Send a frame over the websocket with message as its payload."""
msg_length = len(message)
if use_mask: # pragma: no cover
mask_bit = 0x80
else:
mask_bit = 0

if is_fin:
header_first_byte = 0x80 | opcode
else:
header_first_byte = opcode

if msg_length < 126:
header = PACK_LEN1(
0x80 | opcode, msg_length | mask_bit)
header_first_byte, msg_length | mask_bit)
elif msg_length < (1 << 16): # pragma: no cover
header = PACK_LEN2(
0x80 | opcode, 126 | mask_bit, msg_length)
header_first_byte, 126 | mask_bit, msg_length)
else:
header = PACK_LEN3(
0x80 | opcode, 127 | mask_bit, msg_length)
header_first_byte, 127 | mask_bit, msg_length)

if use_mask: # pragma: no cover
mask = random.randrange(0, 0xffffffff)
Expand Down Expand Up @@ -117,13 +122,6 @@ def test_parse_frame_header_control_frame(out, parser):
raise out.exception()


def test_parse_frame_header_continuation(out, parser):
with pytest.raises(WebSocketError):
parser._frame_fin = True
parser.parse_frame(struct.pack('!BB', 0b00000000, 0b00000000))
raise out.exception()


def _test_parse_frame_header_new_data_err(out, parser):
with pytest.raises(WebSocketError):
parser.parse_frame(struct.pack('!BB', 0b000000000, 0b00000000))
Expand Down Expand Up @@ -234,13 +232,21 @@ def test_simple_binary(out, parser):
assert res == ((WSMsgType.BINARY, b'binary', ''), 6)


def test_fragmentation_header(out, parser):
data = build_frame(b'a', WSMsgType.TEXT)
parser._feed_data(data[:1])
parser._feed_data(data[1:])

res = out._buffer[0]
assert res == (WSMessage(WSMsgType.TEXT, 'a', ''), 1)


def test_continuation(out, parser):
parser.parse_frame = mock.Mock()
parser.parse_frame.return_value = [
(0, WSMsgType.TEXT, b'line1'),
(1, WSMsgType.CONTINUATION, b'line2')]
data1 = build_frame(b'line1', WSMsgType.TEXT, is_fin=False)
parser._feed_data(data1)

parser._feed_data(b'')
data2 = build_frame(b'line2', WSMsgType.CONTINUATION)
parser._feed_data(data2)

res = out._buffer[0]
assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10)
Expand All @@ -254,7 +260,15 @@ def test_continuation_with_ping(out, parser):
(1, WSMsgType.CONTINUATION, b'line2'),
]

parser.feed_data(b'')
data1 = build_frame(b'line1', WSMsgType.TEXT, is_fin=False)
parser._feed_data(data1)

data2 = build_frame(b'', WSMsgType.PING)
parser._feed_data(data2)

data3 = build_frame(b'line2', WSMsgType.CONTINUATION)
parser._feed_data(data3)

res = out._buffer[0]
assert res == (WSMessage(WSMsgType.PING, b'', ''), 0)
res = out._buffer[1]
Expand Down