diff --git a/hass_nabucasa/client.py b/hass_nabucasa/client.py index 9852cdad9..b2125bbd2 100644 --- a/hass_nabucasa/client.py +++ b/hass_nabucasa/client.py @@ -48,9 +48,17 @@ def cloudhooks(self) -> dict[str, dict[str, str | bool]]: def remote_autostart(self) -> bool: """Return true if we want start a remote connection.""" + @abstractmethod + async def cloud_connected(self) -> None: + """Called when cloud connected.""" + + @abstractmethod + async def cloud_disconnected(self) -> None: + """Called when cloud disconnected.""" + @abstractmethod async def cloud_started(self) -> None: - """Called when cloud started with active subscription .""" + """Called when cloud started with active subscription.""" @abstractmethod async def cloud_stopped(self) -> None: diff --git a/hass_nabucasa/iot_base.py b/hass_nabucasa/iot_base.py index 9cb2783ff..acfec53f0 100644 --- a/hass_nabucasa/iot_base.py +++ b/hass_nabucasa/iot_base.py @@ -139,10 +139,12 @@ async def connect(self) -> None: # Still adding it here to make sure we can always reconnect self._logger.exception("Unexpected error") - if self.state == STATE_CONNECTED and self._on_disconnect: - await gather_callbacks( - self._logger, "on_disconnect", self._on_disconnect - ) + if self.state == STATE_CONNECTED: + await self.cloud.client.cloud_disconnected() + if self._on_disconnect: + await gather_callbacks( + self._logger, "on_disconnect", self._on_disconnect + ) if self.close_requested: break @@ -306,5 +308,6 @@ async def _connected(self) -> None: self.state = STATE_CONNECTED self._logger.info("Connected") + await self.cloud.client.cloud_connected() if self._on_connect: await gather_callbacks(self._logger, "on_connect", self._on_connect) diff --git a/tests/common.py b/tests/common.py index 08ea78b68..b29a604ef 100644 --- a/tests/common.py +++ b/tests/common.py @@ -10,11 +10,11 @@ from hass_nabucasa.client import CloudClient -class TestClient(CloudClient): +class MockClient(CloudClient): """Interface class for Home Assistant.""" def __init__(self, loop, websession): - """Initialize TestClient.""" + """Initialize MockClient.""" self._loop = loop self._websession = websession self._cloudhooks = {} @@ -65,6 +65,12 @@ def remote_autostart(self) -> bool: """Return true if we want start a remote connection.""" return self.prop_remote_autostart + async def cloud_connected(self): + """Handle cloud connected.""" + + async def cloud_disconnected(self): + """Handle cloud disconnected.""" + async def cloud_started(self): """Handle cloud started.""" diff --git a/tests/conftest.py b/tests/conftest.py index 9f1d93f12..9d8d01547 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ import pytest from .utils.aiohttp import mock_aiohttp_client -from .common import TestClient +from .common import MockClient logging.basicConfig(level=logging.DEBUG) @@ -34,7 +34,7 @@ def _executor(call, *args): cloud.run_executor = _executor cloud.websession = aioclient_mock.create_session(loop) - cloud.client = TestClient(loop, cloud.websession) + cloud.client = MockClient(loop, cloud.websession) async def update_token(id_token, access_token, refresh_token=None): cloud.id_token = id_token diff --git a/tests/test_google_report_state.py b/tests/test_google_report_state.py index a1047b036..ea3f126a6 100644 --- a/tests/test_google_report_state.py +++ b/tests/test_google_report_state.py @@ -5,6 +5,8 @@ from hass_nabucasa import iot_base from hass_nabucasa.google_report_state import GoogleReportState, ErrorResponse +from .common import MockClient + async def create_grs(loop, ws_server, server_msg_handler) -> GoogleReportState: """Create a grs instance.""" @@ -15,6 +17,7 @@ async def create_grs(loop, ws_server, server_msg_handler) -> GoogleReportState: remotestate_server="mock-report-state-url.com", auth=Mock(async_check_token=AsyncMock()), websession=Mock(ws_connect=AsyncMock(return_value=client)), + client=Mock(spec_set=MockClient), ) return GoogleReportState(mock_cloud)