Skip to content

Commit

Permalink
Merge pull request #220 from KeepSafe/websockets
Browse files Browse the repository at this point in the history
Implement aiohttp.web websockets
  • Loading branch information
asvetlov committed Jan 3, 2015
2 parents 3b2f9be + d0d6f6a commit 04e9620
Show file tree
Hide file tree
Showing 8 changed files with 757 additions and 25 deletions.
17 changes: 16 additions & 1 deletion aiohttp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError',
'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError',
'ClientRequestError', 'ClientResponseError']
'ClientRequestError', 'ClientResponseError', 'WebSocketDisconnected']

from asyncio import TimeoutError

Expand All @@ -26,6 +26,21 @@ class ServerDisconnectedError(DisconnectedError):
"""Server disconnected."""


class WebSocketDisconnected(ClientDisconnectedError):
"""Raised on closing websocket by peer."""

def __init__(self, code=None, message=None):
super().__init__(code, message)

@property
def code(self):
return self.args[0]

@property
def message(self):
return self.args[1]


class ClientError(Exception):
"""Base class for client connection errors."""

Expand Down
186 changes: 167 additions & 19 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@
from .protocol import Response as ResponseImpl, HttpVersion, HttpVersion11
from .server import ServerHttpProtocol
from .streams import EOF_MARKER
from .websocket import do_handshake, MSG_BINARY, MSG_CLOSE, MSG_PING, MSG_TEXT
from .errors import HttpProcessingError, WebSocketDisconnected


__all__ = [
'WebSocketDisconnected',
'Application',
'HttpVersion',
'RequestHandler',
'RequestHandlerFactory',
'Request',
'StreamResponse',
'Response',
'WebSocketResponse',
'UrlDispatcher',
'UrlMappingMatchInfo',
'HTTPException',
Expand Down Expand Up @@ -135,11 +139,12 @@ def content_length(self):

class Request(HeadersMixin):

def __init__(self, app, message, payload, transport, writer,
def __init__(self, app, message, payload, transport, reader, writer,
keep_alive_timeout):
self._app = app
self._version = message.version
self._transport = transport
self._reader = reader
self._writer = writer
self._method = message.method
self._host = message.headers.get('HOST')
Expand Down Expand Up @@ -374,17 +379,6 @@ def post(self):
self._post = MultiDict(out.items(getall=True))
return self._post

# @asyncio.coroutine
# def start_websocket(self):
# """Upgrade connection to websocket.

# Returns (reader, writer) pair.
# """

# upgrade = 'websocket' in message.headers.get('UPGRADE', '').lower()
# if not upgrade:
# pass


############################################################
# HTTP Response classes
Expand Down Expand Up @@ -532,15 +526,22 @@ def _generate_content_type_header(self):
ctype = self._content_type
self.headers['Content-Type'] = ctype

def start(self, request):
def _start_pre_check(self, request):
if self._resp_impl is not None:
if self._req is not request:
raise RuntimeError(
'Response has been started with different request.')
return self._resp_impl
else:
return self._resp_impl
else:
return None

self._req = request
def start(self, request):
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl

self._req = request
keep_alive = self._keep_alive
if keep_alive is None:
keep_alive = request.keep_alive
Expand Down Expand Up @@ -577,6 +578,12 @@ def write(self, data):
else:
return ()

@asyncio.coroutine
def drain(self):
if self._resp_impl is None:
raise RuntimeError("Response has not been started")
yield from self._resp_impl.transport.drain()

@asyncio.coroutine
def write_eof(self):
if self._eof_sent:
Expand Down Expand Up @@ -647,6 +654,144 @@ def write_eof(self):
yield from super().write_eof()


class WebSocketResponse(StreamResponse):

def __init__(self, *, protocols=()):
super().__init__(status=101)
self._protocols = protocols
self._protocol = None
self._writer = None
self._reader = None
self._closing = False
self._loop = None
self._closing_fut = None

def start(self, request):
# make pre-check to don't hide it by do_handshake() exceptions
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl

try:
status, headers, parser, writer, protocol = do_handshake(
request.method, request.headers, request.transport,
self._protocols)
except HttpProcessingError as err:
if err.code == 405:
raise HTTPMethodNotAllowed(request.method, ['GET'])
elif err.code == 400:
raise HTTPBadRequest(text=err.message)
else: # pragma: no cover
raise HTTPInternalServerError() from err

if self.status != status:
self.set_status(status)
for k, v in headers:
self.headers[k] = v
self.force_close()

resp_impl = super().start(request)

self._reader = request._reader.set_parser(parser)
self._writer = writer
self._protocol = protocol
self._loop = request.app.loop
self._closing_fut = asyncio.Future(loop=self._loop)

return resp_impl

def can_start(self, request):
if self._writer is not None:
raise RuntimeError('Already started')
try:
_, _, _, _, protocol = do_handshake(
request.method, request.headers, request.transport,
self._protocols)
except HttpProcessingError:
return False, None
else:
return True, protocol

@property
def closing(self):
return self._closing

@property
def protocol(self):
return self._protocol

def ping(self):
if self._writer is None:
raise RuntimeError('Call .start() first')
if self._closing:
raise RuntimeError('websocket connection is closing')
self._writer.ping()

def send_str(self, data):
if self._writer is None:
raise RuntimeError('Call .start() first')
if self._closing:
raise RuntimeError('websocket connection is closing')
if not isinstance(data, str):
raise TypeError('data argument must be str (%r)', type(data))
self._writer.send(data, binary=False)

def send_bytes(self, data):
if self._writer is None:
raise RuntimeError('Call .start() first')
if self._closing:
raise RuntimeError('websocket connection is closing')
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)',
type(data))
self._writer.send(data, binary=True)

def close(self, *, code=1000, message=b''):
if self._writer is None:
raise RuntimeError('Call .start() first')
if not self._closing:
self._closing = True
self._writer.close(code, message)
else:
raise RuntimeError('Already closing')

@asyncio.coroutine
def wait_closed(self):
if self._closing_fut is None:
raise RuntimeError('Call .start() first')
yield from self._closing_fut

@asyncio.coroutine
def receive(self):
if self._reader is None:
raise RuntimeError('Call .start() first')
while True:
msg = yield from self._reader.read()

if msg.tp == MSG_CLOSE:
if self._closing:
exc = WebSocketDisconnected(msg.data, msg.extra)
self._closing_fut.set_exception(exc)
raise exc
else:
self._closing = True
self._writer.close(msg.data, msg.extra)
yield from self.drain()
exc = WebSocketDisconnected(msg.data, msg.extra)
self._closing_fut.set_exception(exc)
raise exc
elif not self._closing:
if msg.tp == MSG_PING:
self._writer.pong()
elif msg.tp == MSG_TEXT:
return msg.data
elif msg.tp == MSG_BINARY:
return msg.data

def write(self, data):
raise RuntimeError("Cannot call .write() for websocket")


############################################################
# HTTP Exceptions
############################################################
Expand Down Expand Up @@ -1195,7 +1340,8 @@ def handle_request(self, message, payload):

app = self._app
request = Request(app, message, payload,
self.transport, self.writer, self.keep_alive_timeout)
self.transport, self.reader, self.writer,
self.keep_alive_timeout)
try:
match_info = yield from self._router.resolve(request)

Expand All @@ -1210,9 +1356,11 @@ def handle_request(self, message, payload):

if not isinstance(resp, StreamResponse):
raise RuntimeError(
("Handler {!r} should return response instance, got {!r} "
"[middlewares {!r}]")
.format(match_info.handler, type(resp), self._middlewares))
("Handler {!r} should return response instance, "
"got {!r} [middlewares {!r}]").format(
match_info.handler,
type(resp),
self._middlewares))
except HTTPException as exc:
resp = exc

Expand Down
3 changes: 2 additions & 1 deletion tests/test_urldispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ def make_request(self, method, path):
MultiDict(), False, False)
self.payload = mock.Mock()
self.transport = mock.Mock()
self.reader = mock.Mock()
self.writer = mock.Mock()
req = Request(self.app, message, self.payload,
self.transport, self.writer, 15)
self.transport, self.reader, self.writer, 15)
return req

def test_add_route_root(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_web_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def setUp(self):

self.payload = mock.Mock()
self.transport = mock.Mock()
self.reader = mock.Mock()
self.writer = mock.Mock()
self.writer.drain.return_value = ()
self.buf = b''
Expand All @@ -34,7 +35,7 @@ def make_request(self, method='GET', path='/', headers=MultiDict()):
message = RawRequestMessage(method, path, HttpVersion11, headers,
False, False)
req = Request(self.app, message, self.payload,
self.transport, self.writer, 15)
self.transport, self.reader, self.writer, 15)
return req

def test_all_http_exceptions_exported(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ def make_request(self, method, path, headers=MultiDict(), *,
self.payload = mock.Mock()
self.transport = mock.Mock()
self.writer = mock.Mock()
self.reader = mock.Mock()
req = Request(self.app, message, self.payload,
self.transport, self.writer, keep_alive_timeout)
self.transport, self.reader, self.writer,
keep_alive_timeout)
return req

def test_ctor(self):
Expand Down
16 changes: 14 additions & 2 deletions tests/test_web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def make_request(self, method, path, headers=MultiDict()):
False, False)
self.payload = mock.Mock()
self.transport = mock.Mock()
self.reader = mock.Mock()
self.writer = mock.Mock()
req = Request(self.app, message, self.payload,
self.transport, self.writer, 15)
self.transport, self.reader, self.writer, 15)
return req

def test_ctor(self):
Expand Down Expand Up @@ -276,9 +277,10 @@ def make_request(self, method, path, headers=MultiDict()):
False, False)
self.payload = mock.Mock()
self.transport = mock.Mock()
self.reader = mock.Mock()
self.writer = mock.Mock()
req = Request(self.app, message, self.payload,
self.transport, self.writer, 15)
self.transport, self.reader, self.writer, 15)
return req

def test_ctor(self):
Expand Down Expand Up @@ -434,3 +436,13 @@ def test_started_when_started(self):
resp = StreamResponse()
resp.start(self.make_request('GET', '/'))
self.assertTrue(resp.started)

def test_drain_before_start(self):

@asyncio.coroutine
def go():
resp = StreamResponse()
with self.assertRaises(RuntimeError):
yield from resp.drain()

self.loop.run_until_complete(go())
Loading

0 comments on commit 04e9620

Please sign in to comment.