From 9bf30fec025b33b65f9356416cf39911d802d5ba Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 15:32:20 +0200 Subject: [PATCH 01/21] Implement aiohttp.web websockets --- aiohttp/web.py | 184 ++++++++++++++++++----- tests/test_urldispatch.py | 3 +- tests/test_web_exceptions.py | 3 +- tests/test_web_request.py | 4 +- tests/test_web_response.py | 6 +- tests/test_web_websocket_functional.py | 194 +++++++++++++++++++++++++ 6 files changed, 353 insertions(+), 41 deletions(-) create mode 100644 tests/test_web_websocket_functional.py diff --git a/aiohttp/web.py b/aiohttp/web.py index 06988806ab7..2b0e304ecfc 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -22,9 +22,11 @@ from .protocol import Response as ResponseImpl, HttpVersion, HttpVersion11 from .server import ServerHttpProtocol from .streams import EOF_MARKER +from .websocket import do_handshake, MSG_BINARY, MSG_CLOSE, MSG_PING, MSG_TEXT __all__ = [ + 'WebSocketClosed', 'Application', 'HttpVersion', 'RequestHandler', @@ -32,6 +34,7 @@ 'Request', 'StreamResponse', 'Response', + 'WebSocketResponse', 'UrlDispatcher', 'UrlMappingMatchInfo', 'HTTPException', @@ -81,6 +84,21 @@ ] +class WebSocketClosed(GeneratorExit): + """Raised on closing websocket by peer.""" + + def __init__(self, code=None, message=None): + super().__init__(code, message) + + @property + def code(self): + return self.args[0] + + @property + def message(self): + return self.args[1] + + sentinel = object() @@ -135,11 +153,12 @@ def content_length(self): class Request(HeadersMixin): - def __init__(self, app, message, payload, transport, writer, + def __init__(self, app, message, payload, transport, reader, writer, keep_alive_timeout): self._app = app self._version = message.version self._transport = transport + self._reader = reader self._writer = writer self._method = message.method self._host = message.headers.get('HOST') @@ -374,17 +393,6 @@ def post(self): self._post = MultiDict(out.items(getall=True)) return self._post - # @asyncio.coroutine - # def start_websocket(self): - # """Upgrade connection to websocket. - - # Returns (reader, writer) pair. - # """ - - # upgrade = 'websocket' in message.headers.get('UPGRADE', '').lower() - # if not upgrade: - # pass - ############################################################ # HTTP Response classes @@ -532,15 +540,22 @@ def _generate_content_type_header(self): ctype = self._content_type self.headers['Content-Type'] = ctype - def start(self, request): + def _start_pre_check(self, request): if self._resp_impl is not None: if self._req is not request: raise RuntimeError( 'Response has been started with different request.') - return self._resp_impl + else: + return self._resp_impl + else: + return None - self._req = request + def start(self, request): + resp_impl = self._start_pre_check(request) + if resp_impl is not None: + return resp_impl + self._req = request keep_alive = self._keep_alive if keep_alive is None: keep_alive = request.keep_alive @@ -647,6 +662,93 @@ def write_eof(self): yield from super().write_eof() +class WebSocketResponse(StreamResponse): + + def __init__(self, *protocols): + super().__init__(status=101) + self._protocols = protocols + self._protocol = None + self._writer = None + self._reader = None + + def start(self, request): + # make pre-check to don't hide it by do_handshake() exceptions + resp_impl = self._start_pre_check(request) + if resp_impl is not None: + return resp_impl + + status, headers, parser, writer, protocol = do_handshake( + request.method, request.headers, request.transport) + + if self.status != status: + self.set_status(status) + for k, v in headers: + self.headers[k] = v + + resp_impl = super().start(request) + + self._reader = request._reader.set_parser(parser) + self._writer = writer + self._protocol = protocol + + return resp_impl + + @property + def protocol(self): + return self._protocol + + def ping(self): + if self._writer is None: + raise RuntimeError('Call .start() first') + self._writer.ping() + + def send_str(self, data): + if self._writer is None: + raise RuntimeError('Call .start() first') + if not isinstance(data, str): + raise TypeError('data argument must be str (%r)', type(data)) + self._writer.send(data, binary=False) + + def send_bytes(self, data): + if self._writer is None: + raise RuntimeError('Call .start() first') + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + self._writer.send(data, binary=True) + + def close(self, *, code=1000, message=b''): + if self._writer is None: + raise RuntimeError('Call .start() first') + self._writer.close(code, message) + + @asyncio.coroutine + def receive(self): + if self._reader is None: + raise RuntimeError('Call .start() first') + while True: + try: + msg = yield from self._reader.read() + except Exception as exc: + # client dropped connection + raise WebSocketClosed(code=None, message=str(exc)) from exc + + if msg.tp == MSG_PING: + self._writer.pong() + elif msg.tp == MSG_CLOSE: + raise WebSocketClosed(msg.data, msg.extra) + elif msg.tp == MSG_TEXT: + return msg.data + elif msg.tp == MSG_BINARY: + return msg.data + else: + # ignore MSG_PONG + pass + + def write(self, data): + raise RuntimeError("Cannot call .write() for websocket") + + ############################################################ # HTTP Exceptions ############################################################ @@ -1195,35 +1297,45 @@ def handle_request(self, message, payload): app = self._app request = Request(app, message, payload, - self.transport, self.writer, self.keep_alive_timeout) + self.transport, self.reader, self.writer, + self.keep_alive_timeout) try: - match_info = yield from self._router.resolve(request) + try: + match_info = yield from self._router.resolve(request) - assert isinstance(match_info, AbstractMatchInfo), match_info + assert isinstance(match_info, AbstractMatchInfo), match_info - request._match_info = match_info - handler = match_info.handler + request._match_info = match_info + handler = match_info.handler - for factory in reversed(self._middlewares): - handler = yield from factory(app, handler) - resp = yield from handler(request) + for factory in reversed(self._middlewares): + handler = yield from factory(app, handler) + resp = yield from handler(request) - if not isinstance(resp, StreamResponse): - raise RuntimeError( - ("Handler {!r} should return response instance, got {!r} " - "[middlewares {!r}]") - .format(match_info.handler, type(resp), self._middlewares)) - except HTTPException as exc: - resp = exc + if not isinstance(resp, StreamResponse): + raise RuntimeError( + ("Handler {!r} should return response instance, " + "got {!r} [middlewares {!r}]") + .format(match_info.handler, type(resp), + self._middlewares)) + except HTTPException as exc: + resp = exc - resp_msg = resp.start(request) - yield from resp.write_eof() + resp_msg = resp.start(request) + yield from resp.write_eof() + + # notify server about keep-alive + self.keep_alive(resp_msg.keep_alive()) + + # log access + self.log_access(message, None, resp_msg, self._loop.time() - now) + + except GeneratorExit: + self.close() - # notify server about keep-alive - self.keep_alive(resp_msg.keep_alive()) + # log access + self.log_access(message, None, None, self._loop.time() - now) - # log access - self.log_access(message, None, resp_msg, self._loop.time() - now) class RequestHandlerFactory: diff --git a/tests/test_urldispatch.py b/tests/test_urldispatch.py index 98e9c570f5a..0564c4b51ae 100644 --- a/tests/test_urldispatch.py +++ b/tests/test_urldispatch.py @@ -25,9 +25,10 @@ def make_request(self, method, path): MultiDict(), False, False) self.payload = mock.Mock() self.transport = mock.Mock() + self.reader = mock.Mock() self.writer = mock.Mock() req = Request(self.app, message, self.payload, - self.transport, self.writer, 15) + self.transport, self.reader, self.writer, 15) return req def test_add_route_root(self): diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index fc9d190d7ac..3fa566ba0f5 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -17,6 +17,7 @@ def setUp(self): self.payload = mock.Mock() self.transport = mock.Mock() + self.reader = mock.Mock() self.writer = mock.Mock() self.writer.drain.return_value = () self.buf = b'' @@ -34,7 +35,7 @@ def make_request(self, method='GET', path='/', headers=MultiDict()): message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) req = Request(self.app, message, self.payload, - self.transport, self.writer, 15) + self.transport, self.reader, self.writer, 15) return req def test_all_http_exceptions_exported(self): diff --git a/tests/test_web_request.py b/tests/test_web_request.py index c69b5ebb022..d6a938f3f24 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -25,8 +25,10 @@ def make_request(self, method, path, headers=MultiDict(), *, self.payload = mock.Mock() self.transport = mock.Mock() self.writer = mock.Mock() + self.reader = mock.Mock() req = Request(self.app, message, self.payload, - self.transport, self.writer, keep_alive_timeout) + self.transport, self.reader, self.writer, + keep_alive_timeout) return req def test_ctor(self): diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 865d1086304..4763012021f 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -21,9 +21,10 @@ def make_request(self, method, path, headers=MultiDict()): False, False) self.payload = mock.Mock() self.transport = mock.Mock() + self.reader = mock.Mock() self.writer = mock.Mock() req = Request(self.app, message, self.payload, - self.transport, self.writer, 15) + self.transport, self.reader, self.writer, 15) return req def test_ctor(self): @@ -276,9 +277,10 @@ def make_request(self, method, path, headers=MultiDict()): False, False) self.payload = mock.Mock() self.transport = mock.Mock() + self.reader = mock.Mock() self.writer = mock.Mock() req = Request(self.app, message, self.payload, - self.transport, self.writer, 15) + self.transport, self.reader, self.writer, 15) return req def test_ctor(self): diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py new file mode 100644 index 00000000000..e1c35e6b7b4 --- /dev/null +++ b/tests/test_web_websocket_functional.py @@ -0,0 +1,194 @@ +import asyncio +import base64 +import hashlib +import os +import socket +import unittest + +import aiohttp +from aiohttp import web, websocket + + +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +class TestWebWebSocketFunctional(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def find_unused_port(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + s.close() + return port + + @asyncio.coroutine + def create_server(self, method, path, handler=None): + app = web.Application(loop=self.loop, debug=True) + if handler: + app.router.add_route(method, path, handler) + + port = self.find_unused_port() + srv = yield from self.loop.create_server( + app.make_handler(), '127.0.0.1', port) + url = "http://127.0.0.1:{}".format(port) + path + self.addCleanup(srv.close) + return app, srv, url + + @asyncio.coroutine + def connect_ws(self, url, protocol='chat'): + sec_key = base64.b64encode(os.urandom(16)) + + conn = aiohttp.TCPConnector(loop=self.loop) + self.addCleanup(conn.close) + # send request + response = yield from aiohttp.request( + 'get', url, + headers={ + 'UPGRADE': 'WebSocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-VERSION': '13', + 'SEC-WEBSOCKET-PROTOCOL': protocol, + 'SEC-WEBSOCKET-KEY': sec_key.decode(), + }, + connector=conn, + loop=self.loop) + self.addCleanup(response.close, True) + + self.assertEqual(101, response.status) + self.assertEqual(response.headers.get('upgrade', '').lower(), + 'websocket') + self.assertEqual(response.headers.get('connection', '').lower(), + 'upgrade') + + key = response.headers.get('sec-websocket-accept', '').encode() + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) + self.assertEqual(key, match) + + # switch to websocket protocol + connection = response.connection + reader = connection.reader.set_parser(websocket.WebSocketParser) + writer = websocket.WebSocketWriter(connection.writer) + + return reader, writer + + def test_send_recv_text(self): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + ws.start(request) + + msg = yield from ws.receive() + ws.send_str(msg+'/answer') + ws.close() + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + reader, writer = yield from self.connect_ws(url) + writer.send('ask') + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_TEXT) + self.assertEqual('ask/answer', msg.data) + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + self.assertEqual(msg.data, 1000) + self.assertEqual(msg.extra, b'') + + self.loop.run_until_complete(go()) + + def test_send_recv_bytes(self): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + ws.start(request) + + msg = yield from ws.receive() + ws.send_bytes(msg+b'/answer') + ws.close() + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + reader, writer = yield from self.connect_ws(url) + writer.send(b'ask', binary=True) + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_BINARY) + self.assertEqual(b'ask/answer', msg.data) + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + self.assertEqual(msg.data, 1000) + self.assertEqual(msg.extra, b'') + + self.loop.run_until_complete(go()) + + def test_auto_pong_with_closing_by_peer(self): + + closed = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + ws.start(request) + + try: + yield from ws.receive() + except web.WebSocketClosed as exc: + self.assertEqual(1, exc.code) + self.assertEqual(b'exit message', exc.message) + closed.set_result(None) + raise + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + reader, writer = yield from self.connect_ws(url) + writer.ping() + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_PONG) + writer.close(1, 'exit message') + yield from closed + + self.loop.run_until_complete(go()) + + def test_ping(self): + + closed = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + ws.start(request) + + ws.ping() + try: + yield from ws.receive() + except web.WebSocketClosed as exc: + self.assertEqual(2, exc.code) + self.assertEqual(b'exit message', exc.message) + closed.set_result(None) + raise + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + reader, writer = yield from self.connect_ws(url) + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_PING) + writer.pong() + writer.close(2, 'exit message') + yield from closed + + self.loop.run_until_complete(go()) From d5a1649cb741e10e67feb4a547aa4f901191c25f Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 15:38:20 +0200 Subject: [PATCH 02/21] Fix flake8 errors --- aiohttp/web.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 2b0e304ecfc..3e676292694 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1337,7 +1337,6 @@ def handle_request(self, message, payload): self.log_access(message, None, None, self._loop.time() - now) - class RequestHandlerFactory: def __init__(self, app, router, *, From 0129974ca638877a878806bb25900303de8355ea Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 16:46:27 +0200 Subject: [PATCH 03/21] More tests --- aiohttp/web.py | 3 --- tests/test_web_websocket_functional.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 3e676292694..14c7f1359dd 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -741,9 +741,6 @@ def receive(self): return msg.data elif msg.tp == MSG_BINARY: return msg.data - else: - # ignore MSG_PONG - pass def write(self, data): raise RuntimeError("Cannot call .write() for websocket") diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index e1c35e6b7b4..76b343d95ae 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -192,3 +192,22 @@ def go(): yield from closed self.loop.run_until_complete(go()) + + def test_change_status(self): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + ws.set_status(200) + self.assertEqual(200, ws.status) + ws.start(request) + self.assertEqual(101, ws.status) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + _, writer = yield from self.connect_ws(url) + writer.close() + + self.loop.run_until_complete(go()) From 40dd19db915401f73d92a2b7440845327941b57e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 16:54:27 +0200 Subject: [PATCH 04/21] Add test for raising exception from WebSocketResponse.receive() --- aiohttp/web.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 14c7f1359dd..f7e255869f5 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -730,7 +730,6 @@ def receive(self): try: msg = yield from self._reader.read() except Exception as exc: - # client dropped connection raise WebSocketClosed(code=None, message=str(exc)) from exc if msg.tp == MSG_PING: From b7330dddc6893bd19e50de85a85fc8e1c0fb26c5 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 17:01:18 +0200 Subject: [PATCH 05/21] Add nonfunctional tests for aiohttp.web.websockets --- aiohttp/web.py | 7 +- tests/test_web_websocket.py | 126 ++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 3 deletions(-) create mode 100644 tests/test_web_websocket.py diff --git a/aiohttp/web.py b/aiohttp/web.py index f7e255869f5..51862ccf679 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1311,9 +1311,10 @@ def handle_request(self, message, payload): if not isinstance(resp, StreamResponse): raise RuntimeError( ("Handler {!r} should return response instance, " - "got {!r} [middlewares {!r}]") - .format(match_info.handler, type(resp), - self._middlewares)) + "got {!r} [middlewares {!r}]").format( + match_info.handler, + type(resp), + self._middlewares)) except HTTPException as exc: resp = exc diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py new file mode 100644 index 00000000000..f2213ae6611 --- /dev/null +++ b/tests/test_web_websocket.py @@ -0,0 +1,126 @@ +import asyncio +import unittest +from unittest import mock +from aiohttp.multidict import MultiDict +from aiohttp.web import Request, WebSocketResponse, WebSocketClosed +from aiohttp.protocol import RawRequestMessage, HttpVersion11 + + +class TestWebWebSocket(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def make_request(self, method, path): + self.app = mock.Mock() + headers = MultiDict({'HOST': 'server.example.com', + 'UPGRADE': 'websocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-KEY': 'dGhlIHNhbXBsZSBub25jZQ==', + 'ORIGIN': 'http://example.com', + 'SEC-WEBSOCKET-PROTOCOL': 'chat, superchat', + 'SEC-WEBSOCKET-VERSION': '13'}) + message = RawRequestMessage(method, path, HttpVersion11, headers, + False, False) + self.payload = mock.Mock() + self.transport = mock.Mock() + self.reader = mock.Mock() + self.writer = mock.Mock() + req = Request(self.app, message, self.payload, + self.transport, self.reader, self.writer, 15) + return req + + def test_nonstarted_ping(self): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + ws.ping() + + def test_nonstarted_send_str(self): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + ws.send_str('string') + + def test_nonstarted_send_bytes(self): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + ws.send_bytes(b'bytes') + + def test_nonstarted_close(self): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + ws.close() + + def test_nonstarted_receive(self): + + @asyncio.coroutine + def go(): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + yield from ws.receive() + + self.loop.run_until_complete(go()) + + def test_send_str_nonstring(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + ws.start(req) + with self.assertRaises(TypeError): + ws.send_str(b'bytes') + + def test_send_bytes_nonbytes(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + ws.start(req) + with self.assertRaises(TypeError): + ws.send_bytes('string') + + def test_write(self): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + ws.write(b'data') + + def test_nested_exception(self): + + @asyncio.coroutine + def a(): + raise WebSocketClosed() + + @asyncio.coroutine + def b(): + yield from a() + + @asyncio.coroutine + def c(): + yield from b() + + with self.assertRaises(WebSocketClosed): + self.loop.run_until_complete(c()) + + def test_exception_in_receive(self): + + @asyncio.coroutine + def go(): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + ws.start(req) + + err = RuntimeError("error") + + @asyncio.coroutine + def throw(): + raise err + + ws._reader.read = throw + + with self.assertRaises(WebSocketClosed) as exc: + yield from ws.receive() + + self.assertEqual("error", exc.exception.message) + self.assertIsNone(exc.exception.code) + self.assertIs(err, exc.exception.__cause__) + + self.loop.run_until_complete(go()) From e333680d6f1f7d9da90107f5313e4082ca35055b Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 17:08:52 +0200 Subject: [PATCH 06/21] Fix closing web RequestHandler on GeneratorExit --- aiohttp/web.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 51862ccf679..6e62019aff1 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1328,8 +1328,7 @@ def handle_request(self, message, payload): self.log_access(message, None, resp_msg, self._loop.time() - now) except GeneratorExit: - self.close() - + self.transport.close() # log access self.log_access(message, None, None, self._loop.time() - now) From 4d2071f02956f5059f707ac8450ba7187bcdbaef Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 17:11:39 +0200 Subject: [PATCH 07/21] Add comment on reason for closing transport on GeneratorExit --- aiohttp/web.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiohttp/web.py b/aiohttp/web.py index 6e62019aff1..eff146cd4e5 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1328,6 +1328,7 @@ def handle_request(self, message, payload): self.log_access(message, None, resp_msg, self._loop.time() - now) except GeneratorExit: + # the HTTP protocol probably in invalid state, close connection self.transport.close() # log access self.log_access(message, None, None, self._loop.time() - now) From b69e82c7ccf1ee80afc91a5df0ca02ce675d26e8 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 17:12:54 +0200 Subject: [PATCH 08/21] Fix grammar --- aiohttp/web.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index eff146cd4e5..644905bfa42 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1328,7 +1328,7 @@ def handle_request(self, message, payload): self.log_access(message, None, resp_msg, self._loop.time() - now) except GeneratorExit: - # the HTTP protocol probably in invalid state, close connection + # the HTTP protocol is probably in invalid state, close connection self.transport.close() # log access self.log_access(message, None, None, self._loop.time() - now) From cb87c4a641925ac2a687670c3a27422457d73a44 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 17:16:53 +0200 Subject: [PATCH 09/21] Add test for handling websocket protocols --- tests/test_web_websocket_functional.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 76b343d95ae..89ee24e426f 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -211,3 +211,20 @@ def go(): writer.close() self.loop.run_until_complete(go()) + + def test_handle_protocol(self): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse('foo', 'bar') + ws.start(request) + self.assertEqual('bar', ws.protocol) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + _, writer = yield from self.connect_ws(url, 'bar, foo') + writer.close() + + self.loop.run_until_complete(go()) From fce528c81952b6a569630c841e673d2a11ea9ec5 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 2 Jan 2015 17:47:23 +0200 Subject: [PATCH 10/21] Fix test for websocket protocols --- aiohttp/web.py | 3 ++- tests/test_web_websocket_functional.py | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 644905bfa42..6f72dc7d8f7 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -678,7 +678,8 @@ def start(self, request): return resp_impl status, headers, parser, writer, protocol = do_handshake( - request.method, request.headers, request.transport) + request.method, request.headers, request.transport, + self._protocols) if self.status != status: self.set_status(status) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 89ee24e426f..726cd864f97 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -195,6 +195,8 @@ def go(): def test_change_status(self): + closed = asyncio.Future(loop=self.loop) + @asyncio.coroutine def handler(request): ws = web.WebSocketResponse() @@ -202,6 +204,7 @@ def handler(request): self.assertEqual(200, ws.status) ws.start(request) self.assertEqual(101, ws.status) + closed.set_result(None) return ws @asyncio.coroutine @@ -209,22 +212,27 @@ def go(): _, _, url = yield from self.create_server('GET', '/', handler) _, writer = yield from self.connect_ws(url) writer.close() + yield from closed self.loop.run_until_complete(go()) def test_handle_protocol(self): + closed = asyncio.Future(loop=self.loop) + @asyncio.coroutine def handler(request): ws = web.WebSocketResponse('foo', 'bar') ws.start(request) self.assertEqual('bar', ws.protocol) + closed.set_result(None) return ws @asyncio.coroutine def go(): _, _, url = yield from self.create_server('GET', '/', handler) - _, writer = yield from self.connect_ws(url, 'bar, foo') + _, writer = yield from self.connect_ws(url, 'eggs, bar') writer.close() + yield from closed self.loop.run_until_complete(go()) From 717d33804f4d0de8de5619baad9c59988ff2ef62 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 11:21:48 +0200 Subject: [PATCH 11/21] Make WebSocketResponse protocols parameter keyword-only (think about adding origins also). --- aiohttp/web.py | 2 +- tests/test_web_websocket_functional.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 6f72dc7d8f7..118b1f78692 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -664,7 +664,7 @@ def write_eof(self): class WebSocketResponse(StreamResponse): - def __init__(self, *protocols): + def __init__(self, *, protocols=()): super().__init__(status=101) self._protocols = protocols self._protocol = None diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 726cd864f97..0ce0911aa33 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -222,7 +222,7 @@ def test_handle_protocol(self): @asyncio.coroutine def handler(request): - ws = web.WebSocketResponse('foo', 'bar') + ws = web.WebSocketResponse(protocols=('foo', 'bar')) ws.start(request) self.assertEqual('bar', ws.protocol) closed.set_result(None) From 6791d9c273869456ff08d38701dc86e8c378c78a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 17:49:30 +0200 Subject: [PATCH 12/21] Add closing, can_close and wait_closed to WebSocketResponse --- aiohttp/web.py | 74 ++++++++++++++++++++++++++----- tests/test_web_websocket.py | 86 +++++++++++++++++++++++++++++++++---- 2 files changed, 142 insertions(+), 18 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 118b1f78692..eb0aae09225 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -22,7 +22,9 @@ from .protocol import Response as ResponseImpl, HttpVersion, HttpVersion11 from .server import ServerHttpProtocol from .streams import EOF_MARKER -from .websocket import do_handshake, MSG_BINARY, MSG_CLOSE, MSG_PING, MSG_TEXT +from .websocket import (do_handshake, WebSocketError, + MSG_BINARY, MSG_CLOSE, MSG_PING, MSG_TEXT) +from .errors import HttpProcessingError __all__ = [ @@ -670,6 +672,9 @@ def __init__(self, *, protocols=()): self._protocol = None self._writer = None self._reader = None + self._closing = False + self._loop = None + self._closing_fut = None def start(self, request): # make pre-check to don't hide it by do_handshake() exceptions @@ -691,9 +696,27 @@ def start(self, request): self._reader = request._reader.set_parser(parser) self._writer = writer self._protocol = protocol + self._loop = request.app.loop + self._closing_fut = asyncio.Future(loop=self._loop) return resp_impl + def can_start(self, request): + if self._writer is not None: + raise RuntimeError('Already started') + try: + _, _, _, _, protocol = do_handshake( + request.method, request.headers, request.transport, + self._protocols) + except (WebSocketError, HttpProcessingError): + return False, None + else: + return True, protocol + + @property + def closing(self): + return self._closing + @property def protocol(self): return self._protocol @@ -701,11 +724,15 @@ def protocol(self): def ping(self): if self._writer is None: raise RuntimeError('Call .start() first') + if self._closing: + raise RuntimeError('websocket connection is closing') self._writer.ping() def send_str(self, data): if self._writer is None: raise RuntimeError('Call .start() first') + if self._closing: + raise RuntimeError('websocket connection is closing') if not isinstance(data, str): raise TypeError('data argument must be str (%r)', type(data)) self._writer.send(data, binary=False) @@ -713,6 +740,8 @@ def send_str(self, data): def send_bytes(self, data): if self._writer is None: raise RuntimeError('Call .start() first') + if self._closing: + raise RuntimeError('websocket connection is closing') if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError('data argument must be byte-ish (%r)', type(data)) @@ -721,7 +750,17 @@ def send_bytes(self, data): def close(self, *, code=1000, message=b''): if self._writer is None: raise RuntimeError('Call .start() first') - self._writer.close(code, message) + if not self._closing: + self._closing = True + self._writer.close(code, message) + else: + raise RuntimeError('Already closing') + + @asyncio.coroutine + def wait_closed(self): + if self._closing_fut is None: + raise RuntimeError('Call .start() first') + yield from self._closing_fut @asyncio.coroutine def receive(self): @@ -733,14 +772,29 @@ def receive(self): except Exception as exc: raise WebSocketClosed(code=None, message=str(exc)) from exc - if msg.tp == MSG_PING: - self._writer.pong() - elif msg.tp == MSG_CLOSE: - raise WebSocketClosed(msg.data, msg.extra) - elif msg.tp == MSG_TEXT: - return msg.data - elif msg.tp == MSG_BINARY: - return msg.data + if msg.tp == MSG_CLOSE: + if self._closing: + exc = WebSocketClosed(msg.data, msg.extra) + self._closing_fut.set_exception(exc) + raise exc + else: + self._closing = True + self._writer.close(msg.data, msg.extra) + yield from self.drain() + exc = WebSocketClosed(msg.data, msg.extra) + self._closing_fut.set_exception(exc) + raise exc + elif not self._closing: + if msg.tp == MSG_PING: + self._writer.pong() + elif msg.tp == MSG_TEXT: + return msg.data + elif msg.tp == MSG_BINARY: + return msg.data + + @asyncio.coroutine + def drain(self): + yield from self._resp_impl.transport.drain() def write(self, data): raise RuntimeError("Cannot call .write() for websocket") diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index f2213ae6611..a6cc44b4a72 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -15,15 +15,17 @@ def setUp(self): def tearDown(self): self.loop.close() - def make_request(self, method, path): + def make_request(self, method, path, headers=None): self.app = mock.Mock() - headers = MultiDict({'HOST': 'server.example.com', - 'UPGRADE': 'websocket', - 'CONNECTION': 'Upgrade', - 'SEC-WEBSOCKET-KEY': 'dGhlIHNhbXBsZSBub25jZQ==', - 'ORIGIN': 'http://example.com', - 'SEC-WEBSOCKET-PROTOCOL': 'chat, superchat', - 'SEC-WEBSOCKET-VERSION': '13'}) + if headers is None: + headers = MultiDict( + {'HOST': 'server.example.com', + 'UPGRADE': 'websocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-KEY': 'dGhlIHNhbXBsZSBub25jZQ==', + 'ORIGIN': 'http://example.com', + 'SEC-WEBSOCKET-PROTOCOL': 'chat, superchat', + 'SEC-WEBSOCKET-VERSION': '13'}) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) self.payload = mock.Mock() @@ -124,3 +126,71 @@ def throw(): self.assertIs(err, exc.exception.__cause__) self.loop.run_until_complete(go()) + + def test_can_start_ok(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse(protocols=('chat',)) + self.assertEqual((True, 'chat'), ws.can_start(req)) + + def test_can_start_unknown_protocol(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + self.assertEqual((True, None), ws.can_start(req)) + + def test_can_start_invalid_method(self): + req = self.make_request('POST', '/') + ws = WebSocketResponse() + self.assertEqual((False, None), ws.can_start(req)) + + def test_can_start_without_upgrade(self): + req = self.make_request('GET', '/', headers=MultiDict()) + ws = WebSocketResponse() + self.assertEqual((False, None), ws.can_start(req)) + + def test_can_start_started(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + ws.start(req) + with self.assertRaisesRegex(RuntimeError, 'Already started'): + ws.can_start(req) + + def test_closing_after_ctor(self): + ws = WebSocketResponse() + self.assertFalse(ws.closing) + + def test_send_str_closing(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + ws.start(req) + ws.close() + with self.assertRaises(RuntimeError): + ws.send_str('string') + + def test_send_bytes_closing(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + ws.start(req) + ws.close() + with self.assertRaises(RuntimeError): + ws.send_bytes(b'bytes') + + def test_ping_closing(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + ws.start(req) + ws.close() + with self.assertRaises(RuntimeError): + ws.ping() + + def test_double_close(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + ws.start(req) + writer = mock.Mock() + ws._writer = writer + ws.close(code=1, message='message1') + self.assertTrue(ws.closing) + with self.assertRaisesRegex(RuntimeError, 'Already closing'): + ws.close(code=2, message='message2') + self.assertTrue(ws.closing) + writer.close.assert_called_once_with(1, 'message1') From 94d79e3f6ff1f26dad5c99bb3ef67446cb2a2af4 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 18:29:19 +0200 Subject: [PATCH 13/21] Convert internal HttpProcessingError to HTTPException-derived classes, inherit WebClosedError from DisconnectedError instead of GeneratorExit --- aiohttp/web.py | 25 ++++++++++++++++--------- tests/test_web_websocket.py | 15 ++++++++++++++- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index eb0aae09225..6a406cee15d 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -22,9 +22,8 @@ from .protocol import Response as ResponseImpl, HttpVersion, HttpVersion11 from .server import ServerHttpProtocol from .streams import EOF_MARKER -from .websocket import (do_handshake, WebSocketError, - MSG_BINARY, MSG_CLOSE, MSG_PING, MSG_TEXT) -from .errors import HttpProcessingError +from .websocket import do_handshake, MSG_BINARY, MSG_CLOSE, MSG_PING, MSG_TEXT +from .errors import HttpProcessingError, DisconnectedError __all__ = [ @@ -86,7 +85,7 @@ ] -class WebSocketClosed(GeneratorExit): +class WebSocketClosed(DisconnectedError): """Raised on closing websocket by peer.""" def __init__(self, code=None, message=None): @@ -682,9 +681,17 @@ def start(self, request): if resp_impl is not None: return resp_impl - status, headers, parser, writer, protocol = do_handshake( - request.method, request.headers, request.transport, - self._protocols) + try: + status, headers, parser, writer, protocol = do_handshake( + request.method, request.headers, request.transport, + self._protocols) + except HttpProcessingError as err: + if err.code == 405: + raise HTTPMethodNotAllowed(request.method, ['GET']) + elif err.code == 400: + raise HTTPBadRequest(text=err.message) + else: # pragma: no cover + raise HTTPInternalServerError() from err if self.status != status: self.set_status(status) @@ -708,7 +715,7 @@ def can_start(self, request): _, _, _, _, protocol = do_handshake( request.method, request.headers, request.transport, self._protocols) - except (WebSocketError, HttpProcessingError): + except HttpProcessingError: return False, None else: return True, protocol @@ -1382,7 +1389,7 @@ def handle_request(self, message, payload): # log access self.log_access(message, None, resp_msg, self._loop.time() - now) - except GeneratorExit: + except DisconnectedError: # the HTTP protocol is probably in invalid state, close connection self.transport.close() # log access diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index a6cc44b4a72..9a802c02388 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -2,7 +2,8 @@ import unittest from unittest import mock from aiohttp.multidict import MultiDict -from aiohttp.web import Request, WebSocketResponse, WebSocketClosed +from aiohttp.web import (Request, WebSocketResponse, WebSocketClosed, + HTTPMethodNotAllowed, HTTPBadRequest) from aiohttp.protocol import RawRequestMessage, HttpVersion11 @@ -194,3 +195,15 @@ def test_double_close(self): ws.close(code=2, message='message2') self.assertTrue(ws.closing) writer.close.assert_called_once_with(1, 'message1') + + def test_start_invalid_method(self): + req = self.make_request('POST', '/') + ws = WebSocketResponse() + with self.assertRaises(HTTPMethodNotAllowed): + ws.start(req) + + def test_start_without_upgrade(self): + req = self.make_request('GET', '/', headers=MultiDict()) + ws = WebSocketResponse() + with self.assertRaises(HTTPBadRequest): + ws.start(req) From 7bbb7e68d0b51f539c8cef9d4172bfd96d26237a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 19:48:17 +0200 Subject: [PATCH 14/21] Test closing handshake --- aiohttp/web.py | 10 +-- tests/test_web_websocket_functional.py | 89 ++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 4 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 6a406cee15d..2b545600c07 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -593,6 +593,12 @@ def write(self, data): else: return () + @asyncio.coroutine + def drain(self): + if self._resp_impl is None: + raise RuntimeError("Response has not been started") + yield from self._resp_impl.transport.drain() + @asyncio.coroutine def write_eof(self): if self._eof_sent: @@ -799,10 +805,6 @@ def receive(self): elif msg.tp == MSG_BINARY: return msg.data - @asyncio.coroutine - def drain(self): - yield from self._resp_impl.transport.drain() - def write(self, data): raise RuntimeError("Cannot call .write() for websocket") diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 0ce0911aa33..4ca1b82f81c 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -236,3 +236,92 @@ def go(): yield from closed self.loop.run_until_complete(go()) + + def test_server_close_handshake(self): + + closed = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(protocols=('foo', 'bar')) + ws.start(request) + ws.close() + try: + yield from ws.receive() + except web.WebSocketClosed: + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + reader, writer = yield from self.connect_ws(url, 'eggs, bar') + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + writer.close() + yield from closed + + self.loop.run_until_complete(go()) + + def test_client_close_handshake(self): + + closed = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(protocols=('foo', 'bar')) + ws.start(request) + try: + yield from ws.receive() + except web.WebSocketClosed: + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + reader, writer = yield from self.connect_ws(url, 'eggs, bar') + + writer.close() + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + yield from closed + + self.loop.run_until_complete(go()) + + def test_server_close_handshake_by_onother_task(self): + + closed = asyncio.Future(loop=self.loop) + closed2 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def closer(ws): + ws.close() + try: + yield from ws.wait_closed() + except web.WebSocketClosed: + closed2.set_result(None) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(protocols=('foo', 'bar')) + ws.start(request) + asyncio.async(closer(ws), loop=request.app.loop) + try: + yield from ws.receive() + except web.WebSocketClosed: + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + reader, writer = yield from self.connect_ws(url, 'eggs, bar') + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + writer.close() + yield from asyncio.gather(closed, closed2, loop=self.loop) + + self.loop.run_until_complete(go()) From 5169738568a48632e1393434f1e5974ef6c8161a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 19:55:03 +0200 Subject: [PATCH 15/21] Force non-keep-alive for websocket connections --- aiohttp/web.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiohttp/web.py b/aiohttp/web.py index 2b545600c07..7c82958faf8 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -703,6 +703,7 @@ def start(self, request): self.set_status(status) for k, v in headers: self.headers[k] = v + self.force_close() resp_impl = super().start(request) From 96f39becc6cff6603c6c57984ce6e70a162fa07e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 19:55:37 +0200 Subject: [PATCH 16/21] Dont catch internal errors from socket stream reader --- aiohttp/web.py | 5 +---- tests/test_web_websocket.py | 25 ------------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 7c82958faf8..c786efb475e 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -781,10 +781,7 @@ def receive(self): if self._reader is None: raise RuntimeError('Call .start() first') while True: - try: - msg = yield from self._reader.read() - except Exception as exc: - raise WebSocketClosed(code=None, message=str(exc)) from exc + msg = yield from self._reader.read() if msg.tp == MSG_CLOSE: if self._closing: diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 9a802c02388..644bed4cab2 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -103,31 +103,6 @@ def c(): with self.assertRaises(WebSocketClosed): self.loop.run_until_complete(c()) - def test_exception_in_receive(self): - - @asyncio.coroutine - def go(): - req = self.make_request('GET', '/') - ws = WebSocketResponse() - ws.start(req) - - err = RuntimeError("error") - - @asyncio.coroutine - def throw(): - raise err - - ws._reader.read = throw - - with self.assertRaises(WebSocketClosed) as exc: - yield from ws.receive() - - self.assertEqual("error", exc.exception.message) - self.assertIsNone(exc.exception.code) - self.assertIs(err, exc.exception.__cause__) - - self.loop.run_until_complete(go()) - def test_can_start_ok(self): req = self.make_request('GET', '/') ws = WebSocketResponse(protocols=('chat',)) From 806ef8fcc2f4fc408cf71ee14a0678edf243b60e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 20:18:46 +0200 Subject: [PATCH 17/21] Add test for wait_closing() before start() --- tests/test_web_websocket.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 644bed4cab2..66153f51340 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -182,3 +182,13 @@ def test_start_without_upgrade(self): ws = WebSocketResponse() with self.assertRaises(HTTPBadRequest): ws.start(req) + + def test_wait_closed_before_start(self): + + @asyncio.coroutine + def go(): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + yield from ws.wait_closed() + + self.loop.run_until_complete(go()) From e60437d7418c9a2006466bcf610d64bb9eb97d87 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 20:20:59 +0200 Subject: [PATCH 18/21] Add test for drain() before start() --- tests/test_web_response.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 4763012021f..110230bb479 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -436,3 +436,13 @@ def test_started_when_started(self): resp = StreamResponse() resp.start(self.make_request('GET', '/')) self.assertTrue(resp.started) + + def test_drain_before_start(self): + + @asyncio.coroutine + def go(): + resp = StreamResponse() + with self.assertRaises(RuntimeError): + yield from resp.drain() + + self.loop.run_until_complete(go()) From c1168ad7ce08d32917c9ea063019f1e96f35350a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 20:31:39 +0200 Subject: [PATCH 19/21] Fix typo --- tests/test_web_websocket_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 4ca1b82f81c..877b14ee8bb 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -290,7 +290,7 @@ def go(): self.loop.run_until_complete(go()) - def test_server_close_handshake_by_onother_task(self): + def test_server_close_handshake_by_another_task(self): closed = asyncio.Future(loop=self.loop) closed2 = asyncio.Future(loop=self.loop) From 1ea6bb84aef09d81ad9b1c1d11b40ad0b3af5d17 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 20:37:53 +0200 Subject: [PATCH 20/21] Add check for eating messages after close --- tests/test_web_websocket_functional.py | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 877b14ee8bb..0ed822dc964 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -325,3 +325,35 @@ def go(): yield from asyncio.gather(closed, closed2, loop=self.loop) self.loop.run_until_complete(go()) + + def test_server_close_handshake_server_eats_client_messages(self): + + closed = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(protocols=('foo', 'bar')) + ws.start(request) + ws.close() + try: + yield from ws.receive() + except web.WebSocketClosed: + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + reader, writer = yield from self.connect_ws(url, 'eggs, bar') + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + + writer.send('text') + writer.send(b'bytes', binary=True) + writer.ping() + + writer.close() + yield from closed + + self.loop.run_until_complete(go()) From d0d6f6ad02f563e6f14511c03e42b09782359997 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 3 Jan 2015 20:54:39 +0200 Subject: [PATCH 21/21] Move WebSocketClosed to errors.py, rename it to WebSocketDisconnected, get rid of checking for disconnection in web.RequestHandler --- aiohttp/errors.py | 17 +++++- aiohttp/web.py | 76 +++++++++----------------- tests/test_web_websocket.py | 6 +- tests/test_web_websocket_functional.py | 14 ++--- 4 files changed, 53 insertions(+), 60 deletions(-) diff --git a/aiohttp/errors.py b/aiohttp/errors.py index 2974f9049a6..570b23dadc2 100644 --- a/aiohttp/errors.py +++ b/aiohttp/errors.py @@ -9,7 +9,7 @@ 'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError', 'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError', - 'ClientRequestError', 'ClientResponseError'] + 'ClientRequestError', 'ClientResponseError', 'WebSocketDisconnected'] from asyncio import TimeoutError @@ -26,6 +26,21 @@ class ServerDisconnectedError(DisconnectedError): """Server disconnected.""" +class WebSocketDisconnected(ClientDisconnectedError): + """Raised on closing websocket by peer.""" + + def __init__(self, code=None, message=None): + super().__init__(code, message) + + @property + def code(self): + return self.args[0] + + @property + def message(self): + return self.args[1] + + class ClientError(Exception): """Base class for client connection errors.""" diff --git a/aiohttp/web.py b/aiohttp/web.py index c786efb475e..0b1813ed150 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -23,11 +23,11 @@ from .server import ServerHttpProtocol from .streams import EOF_MARKER from .websocket import do_handshake, MSG_BINARY, MSG_CLOSE, MSG_PING, MSG_TEXT -from .errors import HttpProcessingError, DisconnectedError +from .errors import HttpProcessingError, WebSocketDisconnected __all__ = [ - 'WebSocketClosed', + 'WebSocketDisconnected', 'Application', 'HttpVersion', 'RequestHandler', @@ -85,21 +85,6 @@ ] -class WebSocketClosed(DisconnectedError): - """Raised on closing websocket by peer.""" - - def __init__(self, code=None, message=None): - super().__init__(code, message) - - @property - def code(self): - return self.args[0] - - @property - def message(self): - return self.args[1] - - sentinel = object() @@ -785,14 +770,14 @@ def receive(self): if msg.tp == MSG_CLOSE: if self._closing: - exc = WebSocketClosed(msg.data, msg.extra) + exc = WebSocketDisconnected(msg.data, msg.extra) self._closing_fut.set_exception(exc) raise exc else: self._closing = True self._writer.close(msg.data, msg.extra) yield from self.drain() - exc = WebSocketClosed(msg.data, msg.extra) + exc = WebSocketDisconnected(msg.data, msg.extra) self._closing_fut.set_exception(exc) raise exc elif not self._closing: @@ -1358,42 +1343,35 @@ def handle_request(self, message, payload): self.transport, self.reader, self.writer, self.keep_alive_timeout) try: - try: - match_info = yield from self._router.resolve(request) - - assert isinstance(match_info, AbstractMatchInfo), match_info + match_info = yield from self._router.resolve(request) - request._match_info = match_info - handler = match_info.handler + assert isinstance(match_info, AbstractMatchInfo), match_info - for factory in reversed(self._middlewares): - handler = yield from factory(app, handler) - resp = yield from handler(request) + request._match_info = match_info + handler = match_info.handler - if not isinstance(resp, StreamResponse): - raise RuntimeError( - ("Handler {!r} should return response instance, " - "got {!r} [middlewares {!r}]").format( - match_info.handler, - type(resp), - self._middlewares)) - except HTTPException as exc: - resp = exc + for factory in reversed(self._middlewares): + handler = yield from factory(app, handler) + resp = yield from handler(request) - resp_msg = resp.start(request) - yield from resp.write_eof() - - # notify server about keep-alive - self.keep_alive(resp_msg.keep_alive()) + if not isinstance(resp, StreamResponse): + raise RuntimeError( + ("Handler {!r} should return response instance, " + "got {!r} [middlewares {!r}]").format( + match_info.handler, + type(resp), + self._middlewares)) + except HTTPException as exc: + resp = exc + + resp_msg = resp.start(request) + yield from resp.write_eof() - # log access - self.log_access(message, None, resp_msg, self._loop.time() - now) + # notify server about keep-alive + self.keep_alive(resp_msg.keep_alive()) - except DisconnectedError: - # the HTTP protocol is probably in invalid state, close connection - self.transport.close() - # log access - self.log_access(message, None, None, self._loop.time() - now) + # log access + self.log_access(message, None, resp_msg, self._loop.time() - now) class RequestHandlerFactory: diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 66153f51340..fad3131cd7e 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -2,7 +2,7 @@ import unittest from unittest import mock from aiohttp.multidict import MultiDict -from aiohttp.web import (Request, WebSocketResponse, WebSocketClosed, +from aiohttp.web import (Request, WebSocketResponse, WebSocketDisconnected, HTTPMethodNotAllowed, HTTPBadRequest) from aiohttp.protocol import RawRequestMessage, HttpVersion11 @@ -90,7 +90,7 @@ def test_nested_exception(self): @asyncio.coroutine def a(): - raise WebSocketClosed() + raise WebSocketDisconnected() @asyncio.coroutine def b(): @@ -100,7 +100,7 @@ def b(): def c(): yield from b() - with self.assertRaises(WebSocketClosed): + with self.assertRaises(WebSocketDisconnected): self.loop.run_until_complete(c()) def test_can_start_ok(self): diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 0ed822dc964..e94d3a31904 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -145,7 +145,7 @@ def handler(request): try: yield from ws.receive() - except web.WebSocketClosed as exc: + except web.WebSocketDisconnected as exc: self.assertEqual(1, exc.code) self.assertEqual(b'exit message', exc.message) closed.set_result(None) @@ -175,7 +175,7 @@ def handler(request): ws.ping() try: yield from ws.receive() - except web.WebSocketClosed as exc: + except web.WebSocketDisconnected as exc: self.assertEqual(2, exc.code) self.assertEqual(b'exit message', exc.message) closed.set_result(None) @@ -248,7 +248,7 @@ def handler(request): ws.close() try: yield from ws.receive() - except web.WebSocketClosed: + except web.WebSocketDisconnected: closed.set_result(None) return ws @@ -274,7 +274,7 @@ def handler(request): ws.start(request) try: yield from ws.receive() - except web.WebSocketClosed: + except web.WebSocketDisconnected: closed.set_result(None) return ws @@ -300,7 +300,7 @@ def closer(ws): ws.close() try: yield from ws.wait_closed() - except web.WebSocketClosed: + except web.WebSocketDisconnected: closed2.set_result(None) @asyncio.coroutine @@ -310,7 +310,7 @@ def handler(request): asyncio.async(closer(ws), loop=request.app.loop) try: yield from ws.receive() - except web.WebSocketClosed: + except web.WebSocketDisconnected: closed.set_result(None) return ws @@ -337,7 +337,7 @@ def handler(request): ws.close() try: yield from ws.receive() - except web.WebSocketClosed: + except web.WebSocketDisconnected: closed.set_result(None) return ws