Skip to content

Commit

Permalink
on_connection callback for ws site (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
ovv authored Mar 6, 2019
1 parent 67b89e6 commit 8132537
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 54 deletions.
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ For more examples see the `examples folder <https://github.com/eyepea/pillars/tr
Changelog
---------

0.4.1
`````

* Add `on_connection` callback to websocket site

0.4.0
`````

Expand Down
2 changes: 1 addition & 1 deletion pillars/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

import setproctitle
from aiohttp import signals
from aiohttp.web_runner import BaseRunner, BaseSite
import setproctitle

from . import exceptions

Expand Down
12 changes: 11 additions & 1 deletion pillars/sites/websocket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Optional, Union
from typing import Awaitable, Callable, Optional, Union

import aiohttp
import aiohttp.http_websocket
Expand Down Expand Up @@ -88,6 +88,7 @@ def __init__(
*,
shutdown_timeout: float = 60.0,
session: aiohttp.ClientSession = None,
on_connection: Optional[Callable[[], Awaitable[None]]] = None,
) -> None:
super().__init__(runner, shutdown_timeout=shutdown_timeout)
self._url = url
Expand All @@ -98,6 +99,7 @@ def __init__(
self._transport: Optional[WSTransport] = None
self._closing = False
self._protocol_type = ProtocolType.WS
self._on_connection = on_connection

@property
def name(self) -> str:
Expand All @@ -122,6 +124,7 @@ async def _ws_connection(self) -> None:
async with self._session.ws_connect(self._url) as ws:
self._transport._ws = ws
self._protocol.connection_made(self._transport)
asyncio.create_task(self._connected())
async for message in ws:
LOG.log(2, "Data received: %s", message)
self._protocol.message_received(
Expand All @@ -147,3 +150,10 @@ async def status(self) -> bool:
return await self._transport.status()
else:
return False

async def _connected(self) -> None:
try:
if self._on_connection:
await self._on_connection()
except Exception:
LOG.exception(f"Error calling 'on_connection' for: {self}")
2 changes: 1 addition & 1 deletion pillars/transports/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import aiohttp.web
import cerberus
from aiohttp.abc import AbstractMatchInfo
import ujson
from aiohttp.abc import AbstractMatchInfo

from ..exceptions import DataValidationError
from ..request import BaseRequest, Response
Expand Down
Loading

0 comments on commit 8132537

Please sign in to comment.