Skip to content

Commit

Permalink
Move WebSocketClosed to errors.py, rename it to WebSocketDisconnected…
Browse files Browse the repository at this point in the history
…, get rid of checking for disconnection in web.RequestHandler
  • Loading branch information
asvetlov committed Jan 3, 2015
1 parent 1ea6bb8 commit d0d6f6a
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 60 deletions.
17 changes: 16 additions & 1 deletion aiohttp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError',
'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError',
'ClientRequestError', 'ClientResponseError']
'ClientRequestError', 'ClientResponseError', 'WebSocketDisconnected']

from asyncio import TimeoutError

Expand All @@ -26,6 +26,21 @@ class ServerDisconnectedError(DisconnectedError):
"""Server disconnected."""


class WebSocketDisconnected(ClientDisconnectedError):
"""Raised on closing websocket by peer."""

def __init__(self, code=None, message=None):
super().__init__(code, message)

@property
def code(self):
return self.args[0]

@property
def message(self):
return self.args[1]


class ClientError(Exception):
"""Base class for client connection errors."""

Expand Down
76 changes: 27 additions & 49 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from .server import ServerHttpProtocol
from .streams import EOF_MARKER
from .websocket import do_handshake, MSG_BINARY, MSG_CLOSE, MSG_PING, MSG_TEXT
from .errors import HttpProcessingError, DisconnectedError
from .errors import HttpProcessingError, WebSocketDisconnected


__all__ = [
'WebSocketClosed',
'WebSocketDisconnected',
'Application',
'HttpVersion',
'RequestHandler',
Expand Down Expand Up @@ -85,21 +85,6 @@
]


class WebSocketClosed(DisconnectedError):
"""Raised on closing websocket by peer."""

def __init__(self, code=None, message=None):
super().__init__(code, message)

@property
def code(self):
return self.args[0]

@property
def message(self):
return self.args[1]


sentinel = object()


Expand Down Expand Up @@ -785,14 +770,14 @@ def receive(self):

if msg.tp == MSG_CLOSE:
if self._closing:
exc = WebSocketClosed(msg.data, msg.extra)
exc = WebSocketDisconnected(msg.data, msg.extra)
self._closing_fut.set_exception(exc)
raise exc
else:
self._closing = True
self._writer.close(msg.data, msg.extra)
yield from self.drain()
exc = WebSocketClosed(msg.data, msg.extra)
exc = WebSocketDisconnected(msg.data, msg.extra)
self._closing_fut.set_exception(exc)
raise exc
elif not self._closing:
Expand Down Expand Up @@ -1358,42 +1343,35 @@ def handle_request(self, message, payload):
self.transport, self.reader, self.writer,
self.keep_alive_timeout)
try:
try:
match_info = yield from self._router.resolve(request)

assert isinstance(match_info, AbstractMatchInfo), match_info
match_info = yield from self._router.resolve(request)

request._match_info = match_info
handler = match_info.handler
assert isinstance(match_info, AbstractMatchInfo), match_info

for factory in reversed(self._middlewares):
handler = yield from factory(app, handler)
resp = yield from handler(request)
request._match_info = match_info
handler = match_info.handler

if not isinstance(resp, StreamResponse):
raise RuntimeError(
("Handler {!r} should return response instance, "
"got {!r} [middlewares {!r}]").format(
match_info.handler,
type(resp),
self._middlewares))
except HTTPException as exc:
resp = exc
for factory in reversed(self._middlewares):
handler = yield from factory(app, handler)
resp = yield from handler(request)

resp_msg = resp.start(request)
yield from resp.write_eof()

# notify server about keep-alive
self.keep_alive(resp_msg.keep_alive())
if not isinstance(resp, StreamResponse):
raise RuntimeError(
("Handler {!r} should return response instance, "
"got {!r} [middlewares {!r}]").format(
match_info.handler,
type(resp),
self._middlewares))
except HTTPException as exc:
resp = exc

resp_msg = resp.start(request)
yield from resp.write_eof()

# log access
self.log_access(message, None, resp_msg, self._loop.time() - now)
# notify server about keep-alive
self.keep_alive(resp_msg.keep_alive())

except DisconnectedError:
# the HTTP protocol is probably in invalid state, close connection
self.transport.close()
# log access
self.log_access(message, None, None, self._loop.time() - now)
# log access
self.log_access(message, None, resp_msg, self._loop.time() - now)


class RequestHandlerFactory:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest
from unittest import mock
from aiohttp.multidict import MultiDict
from aiohttp.web import (Request, WebSocketResponse, WebSocketClosed,
from aiohttp.web import (Request, WebSocketResponse, WebSocketDisconnected,
HTTPMethodNotAllowed, HTTPBadRequest)
from aiohttp.protocol import RawRequestMessage, HttpVersion11

Expand Down Expand Up @@ -90,7 +90,7 @@ def test_nested_exception(self):

@asyncio.coroutine
def a():
raise WebSocketClosed()
raise WebSocketDisconnected()

@asyncio.coroutine
def b():
Expand All @@ -100,7 +100,7 @@ def b():
def c():
yield from b()

with self.assertRaises(WebSocketClosed):
with self.assertRaises(WebSocketDisconnected):
self.loop.run_until_complete(c())

def test_can_start_ok(self):
Expand Down
14 changes: 7 additions & 7 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def handler(request):

try:
yield from ws.receive()
except web.WebSocketClosed as exc:
except web.WebSocketDisconnected as exc:
self.assertEqual(1, exc.code)
self.assertEqual(b'exit message', exc.message)
closed.set_result(None)
Expand Down Expand Up @@ -175,7 +175,7 @@ def handler(request):
ws.ping()
try:
yield from ws.receive()
except web.WebSocketClosed as exc:
except web.WebSocketDisconnected as exc:
self.assertEqual(2, exc.code)
self.assertEqual(b'exit message', exc.message)
closed.set_result(None)
Expand Down Expand Up @@ -248,7 +248,7 @@ def handler(request):
ws.close()
try:
yield from ws.receive()
except web.WebSocketClosed:
except web.WebSocketDisconnected:
closed.set_result(None)
return ws

Expand All @@ -274,7 +274,7 @@ def handler(request):
ws.start(request)
try:
yield from ws.receive()
except web.WebSocketClosed:
except web.WebSocketDisconnected:
closed.set_result(None)
return ws

Expand All @@ -300,7 +300,7 @@ def closer(ws):
ws.close()
try:
yield from ws.wait_closed()
except web.WebSocketClosed:
except web.WebSocketDisconnected:
closed2.set_result(None)

@asyncio.coroutine
Expand All @@ -310,7 +310,7 @@ def handler(request):
asyncio.async(closer(ws), loop=request.app.loop)
try:
yield from ws.receive()
except web.WebSocketClosed:
except web.WebSocketDisconnected:
closed.set_result(None)
return ws

Expand All @@ -337,7 +337,7 @@ def handler(request):
ws.close()
try:
yield from ws.receive()
except web.WebSocketClosed:
except web.WebSocketDisconnected:
closed.set_result(None)
return ws

Expand Down

0 comments on commit d0d6f6a

Please sign in to comment.