Skip to content

Commit

Permalink
Add cloud_connected and cloud_disconnected methods to CloudClient (#450)
Browse files Browse the repository at this point in the history
* Add cloud_connected method to CloudClient

* Update TestClient

* Fix tests

* Run black

* Add cloud_disconnected callback for symmetry
  • Loading branch information
emontnemery committed Jun 20, 2023
1 parent 5e61d5b commit 618d250
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 9 deletions.
10 changes: 9 additions & 1 deletion hass_nabucasa/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,17 @@ def cloudhooks(self) -> dict[str, dict[str, str | bool]]:
def remote_autostart(self) -> bool:
"""Return true if we want start a remote connection."""

@abstractmethod
async def cloud_connected(self) -> None:
"""Called when cloud connected."""

@abstractmethod
async def cloud_disconnected(self) -> None:
"""Called when cloud disconnected."""

@abstractmethod
async def cloud_started(self) -> None:
"""Called when cloud started with active subscription ."""
"""Called when cloud started with active subscription."""

@abstractmethod
async def cloud_stopped(self) -> None:
Expand Down
11 changes: 7 additions & 4 deletions hass_nabucasa/iot_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,12 @@ async def connect(self) -> None:
# Still adding it here to make sure we can always reconnect
self._logger.exception("Unexpected error")

if self.state == STATE_CONNECTED and self._on_disconnect:
await gather_callbacks(
self._logger, "on_disconnect", self._on_disconnect
)
if self.state == STATE_CONNECTED:
await self.cloud.client.cloud_disconnected()
if self._on_disconnect:
await gather_callbacks(
self._logger, "on_disconnect", self._on_disconnect
)

if self.close_requested:
break
Expand Down Expand Up @@ -306,5 +308,6 @@ async def _connected(self) -> None:
self.state = STATE_CONNECTED
self._logger.info("Connected")

await self.cloud.client.cloud_connected()
if self._on_connect:
await gather_callbacks(self._logger, "on_connect", self._on_connect)
10 changes: 8 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from hass_nabucasa.client import CloudClient


class TestClient(CloudClient):
class MockClient(CloudClient):
"""Interface class for Home Assistant."""

def __init__(self, loop, websession):
"""Initialize TestClient."""
"""Initialize MockClient."""
self._loop = loop
self._websession = websession
self._cloudhooks = {}
Expand Down Expand Up @@ -65,6 +65,12 @@ def remote_autostart(self) -> bool:
"""Return true if we want start a remote connection."""
return self.prop_remote_autostart

async def cloud_connected(self):
"""Handle cloud connected."""

async def cloud_disconnected(self):
"""Handle cloud disconnected."""

async def cloud_started(self):
"""Handle cloud started."""

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from .utils.aiohttp import mock_aiohttp_client
from .common import TestClient
from .common import MockClient

logging.basicConfig(level=logging.DEBUG)

Expand All @@ -34,7 +34,7 @@ def _executor(call, *args):
cloud.run_executor = _executor

cloud.websession = aioclient_mock.create_session(loop)
cloud.client = TestClient(loop, cloud.websession)
cloud.client = MockClient(loop, cloud.websession)

async def update_token(id_token, access_token, refresh_token=None):
cloud.id_token = id_token
Expand Down
3 changes: 3 additions & 0 deletions tests/test_google_report_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from hass_nabucasa import iot_base
from hass_nabucasa.google_report_state import GoogleReportState, ErrorResponse

from .common import MockClient


async def create_grs(loop, ws_server, server_msg_handler) -> GoogleReportState:
"""Create a grs instance."""
Expand All @@ -15,6 +17,7 @@ async def create_grs(loop, ws_server, server_msg_handler) -> GoogleReportState:
remotestate_server="mock-report-state-url.com",
auth=Mock(async_check_token=AsyncMock()),
websession=Mock(ws_connect=AsyncMock(return_value=client)),
client=Mock(spec_set=MockClient),
)
return GoogleReportState(mock_cloud)

Expand Down

0 comments on commit 618d250

Please sign in to comment.