From 03984c24e8652c2ded36bb4170ef8a29c2cf4a39 Mon Sep 17 00:00:00 2001 From: Jungkook Park Date: Wed, 7 Jun 2017 02:52:43 +0900 Subject: [PATCH] fix issues caused by websocket fragmentation --- CHANGES.rst | 2 +- CONTRIBUTORS.txt | 1 + aiohttp/http_websocket.py | 33 +++++++++-------------- tests/test_websocket_parser.py | 48 ++++++++++++++++++++++------------ 4 files changed, 45 insertions(+), 39 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index af28e641ecf..131bd02999c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,7 +11,7 @@ Changes - -- +- Fix websocket issues caused by frame fragmentation. #1962 - diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 29eb599723b..bd0928fcb41 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -91,6 +91,7 @@ Joongi Kim Josep Cugat Julia Tsemusheva Julien Duponchelle +Jungkook Park Junjie Tao Justas Trimailovas Justin Turner Arthur diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index 9752b38e568..f853c0d1a86 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -224,17 +224,17 @@ def _feed_data(self, data): WSMessage(WSMsgType.PONG, payload, ''), len(payload)) elif opcode not in ( - WSMsgType.TEXT, WSMsgType.BINARY) and not self._opcode: + WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, "Unexpected opcode={!r}".format(opcode)) else: # load text/binary - if not fin: # got partial frame payload if opcode != WSMsgType.CONTINUATION: self._opcode = opcode + self._partial.append(payload) else: # previous frame was non finished @@ -248,12 +248,14 @@ def _feed_data(self, data): if opcode == WSMsgType.CONTINUATION: opcode = self._opcode + self._opcode = None - self._partial.append(payload) + payload_merged = b''.join(self._partial) + payload + self._partial.clear() if opcode == WSMsgType.TEXT: try: - text = b''.join(self._partial).decode('utf-8') + text = payload_merged.decode('utf-8') self.queue.feed_data( WSMessage(WSMsgType.TEXT, text, ''), len(text)) except UnicodeDecodeError as exc: @@ -261,26 +263,22 @@ def _feed_data(self, data): WSCloseCode.INVALID_TEXT, 'Invalid UTF-8 text message') from exc else: - data = b''.join(self._partial) self.queue.feed_data( - WSMessage(WSMsgType.BINARY, data, ''), len(data)) - - self._start_opcode = None - self._partial.clear() + WSMessage(WSMsgType.BINARY, payload_merged, ''), + len(payload_merged)) return False, b'' - def parse_frame(self, buf, continuation=False, EMPTY=b''): + def parse_frame(self, buf): """Return the next frame from the socket.""" frames = [] if self._tail: - buf, self._tail = self._tail + buf, EMPTY + buf, self._tail = self._tail + buf, b'' start_pos = 0 buf_length = len(buf) while True: - # read header if self._state == WSParserState.READ_HEADER: if buf_length - start_pos >= 2: @@ -312,15 +310,6 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''): WSCloseCode.PROTOCOL_ERROR, 'Received fragmented control frame') - continuation = not self._frame_fin - if (fin == 0 and - opcode == WSMsgType.CONTINUATION and - not continuation): - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Received new fragment frame with non-zero ' - 'opcode {!r}'.format(opcode)) - has_mask = (second_byte >> 7) & 1 length = (second_byte) & 0x7f @@ -409,6 +398,8 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''): else: break + self._tail = buf[start_pos:] + return frames diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 9de1ea61998..89b7a7c6ef8 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -12,7 +12,7 @@ _websocket_mask) -def build_frame(message, opcode, use_mask=False, noheader=False): +def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True): """Send a frame over the websocket with message as its payload.""" msg_length = len(message) if use_mask: # pragma: no cover @@ -20,15 +20,20 @@ def build_frame(message, opcode, use_mask=False, noheader=False): else: mask_bit = 0 + if is_fin: + header_first_byte = 0x80 | opcode + else: + header_first_byte = opcode + if msg_length < 126: header = PACK_LEN1( - 0x80 | opcode, msg_length | mask_bit) + header_first_byte, msg_length | mask_bit) elif msg_length < (1 << 16): # pragma: no cover header = PACK_LEN2( - 0x80 | opcode, 126 | mask_bit, msg_length) + header_first_byte, 126 | mask_bit, msg_length) else: header = PACK_LEN3( - 0x80 | opcode, 127 | mask_bit, msg_length) + header_first_byte, 127 | mask_bit, msg_length) if use_mask: # pragma: no cover mask = random.randrange(0, 0xffffffff) @@ -117,13 +122,6 @@ def test_parse_frame_header_control_frame(out, parser): raise out.exception() -def test_parse_frame_header_continuation(out, parser): - with pytest.raises(WebSocketError): - parser._frame_fin = True - parser.parse_frame(struct.pack('!BB', 0b00000000, 0b00000000)) - raise out.exception() - - def _test_parse_frame_header_new_data_err(out, parser): with pytest.raises(WebSocketError): parser.parse_frame(struct.pack('!BB', 0b000000000, 0b00000000)) @@ -234,13 +232,21 @@ def test_simple_binary(out, parser): assert res == ((WSMsgType.BINARY, b'binary', ''), 6) +def test_fragmentation_header(out, parser): + data = build_frame(b'a', WSMsgType.TEXT) + parser._feed_data(data[:1]) + parser._feed_data(data[1:]) + + res = out._buffer[0] + assert res == (WSMessage(WSMsgType.TEXT, 'a', ''), 1) + + def test_continuation(out, parser): - parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (1, WSMsgType.CONTINUATION, b'line2')] + data1 = build_frame(b'line1', WSMsgType.TEXT, is_fin=False) + parser._feed_data(data1) - parser._feed_data(b'') + data2 = build_frame(b'line2', WSMsgType.CONTINUATION) + parser._feed_data(data2) res = out._buffer[0] assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10) @@ -254,7 +260,15 @@ def test_continuation_with_ping(out, parser): (1, WSMsgType.CONTINUATION, b'line2'), ] - parser.feed_data(b'') + data1 = build_frame(b'line1', WSMsgType.TEXT, is_fin=False) + parser._feed_data(data1) + + data2 = build_frame(b'', WSMsgType.PING) + parser._feed_data(data2) + + data3 = build_frame(b'line2', WSMsgType.CONTINUATION) + parser._feed_data(data3) + res = out._buffer[0] assert res == (WSMessage(WSMsgType.PING, b'', ''), 0) res = out._buffer[1]