From c376602666fb35dd37dbef0cf4c425bc9726d9fd Mon Sep 17 00:00:00 2001 From: Robin Battey Date: Wed, 12 Oct 2022 14:49:48 -0700 Subject: [PATCH 1/6] refactor repeated SSL/auth cruft for HASS connection --- appdaemon/plugins/hass/hassplugin.py | 242 +++++++++------------------ 1 file changed, 79 insertions(+), 163 deletions(-) diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index e1b9e5d3e..c021c0aad 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -42,76 +42,41 @@ def __init__(self, ad: AppDaemon, name, args): self.config = args self.name = name - self.stopping = False - self.ws = None - self.reading_messages = False - self.metadata = None - self.hass_booting = False - self.logger.info("HASS Plugin Initializing") - self.name = name - - if "namespace" in args: - self.namespace = args["namespace"] - else: - self.namespace = "default" - + # validate basic config if "ha_key" in args: - self.ha_key = args["ha_key"] self.logger.warning("ha_key is deprecated please use HASS Long Lived Tokens instead") - else: - self.ha_key = None - - if "token" in args: - self.token = args["token"] - else: - self.token = None - - if "ha_url" in args: - self.ha_url = args["ha_url"] - else: + if "ha_url" not in args: self.logger.warning("ha_url not found in HASS configuration - module not initialized") - if "cert_path" in args: - self.cert_path = args["cert_path"] - else: - self.cert_path = None - - if "timeout" in args: - self.timeout = args["timeout"] - else: - self.timeout = None - - if "retry_secs" in args: - self.retry_secs = int(args["retry_secs"]) - else: - self.retry_secs = 5 - - if "cert_verify" in args: - self.cert_verify = args["cert_verify"] - else: - self.cert_verify = True - - if "commtype" in args: - self.commtype = args["commtype"] - else: - self.commtype = "WS" - - if "appdaemon_startup_conditions" in args: - self.appdaemon_startup_conditions = args["appdaemon_startup_conditions"] - else: - self.appdaemon_startup_conditions = {} - - if "plugin_startup_conditions" in args: - self.plugin_startup_conditions = args["plugin_startup_conditions"] - else: - self.plugin_startup_conditions = {} + # Locally store common args and their defaults + self.appdaemon_startup_conditions = args.get("appdaemon_startup_conditions", {}) + self.cert_path = args.get("cert_path") + self.cert_verify = args.get("cert_verify") + self.commtype = args.get("commtype", "WS") + self.ha_key = args.get("ha_key") + self.ha_url = args.get("ha_url", "") + self.namespace = args.get("namespace", "default") + self.plugin_startup_conditions = args.get("plugin_startup_conditions", {}) + self.retry_secs = int(args.get("retry_secs", 5)) + self.timeout = args.get("timeout") + self.token = args.get("token") + + # Connections to HA + self._session = None # http connection pool for general use + self.ws = None # websocket dedicated for event loop + + # Cached state from HA + self.metadata = None + self.services = None - self.session = None - self.first_time = False + # Internal state flags self.already_notified = False - self.services = None + self.first_time = False + self.hass_booting = False + self.reading_messages = False + self.stopping = False self.logger.info("HASS Plugin initialization complete") @@ -131,9 +96,36 @@ def list_constraints(self): return [] # - # Get initial state + # Persistent Session to HASS instance # + @property + def session(self): + if not self._session: + # ssl None means to use default behavior which check certs for https + ssl_context = (self.cert_verify and None) + if self.cert_verify and self.cert_path: + ssl_context = ssl.create_default_context(capath=self.cert_path) + conn = aiohttp.TCPConnector(ssl=ssl_context) + + # configure auth + headers = {} + if self.token is not None: + headers["Authorization"] = "Bearer {}".format(self.token) + elif self.ha_key is not None: + headers["x-ha-access"] = self.ha_key + + self._session = aiohttp.ClientSession( + base_url=self.ha_url, + connector=conn, + headers=headers, + json_serialize=utils.convert_json + ) + return self._session + + # + # Get initial state + # async def get_complete_state(self): hass_state = await self.get_hass_state() @@ -147,14 +139,12 @@ async def get_complete_state(self): # # Get HASS Metadata # - async def get_metadata(self): return self.metadata # # Handle state updates # - async def evaluate_started(self, first_time, plugin_booting, event=None): # noqa: C901 if first_time is True: @@ -448,22 +438,11 @@ async def set_plugin_state(self, namespace, entity_id, **kwargs): self.logger.debug("set_plugin_state() %s %s %s", namespace, entity_id, kwargs) config = (await self.AD.plugins.get_plugin_object(namespace)).config - # TODO cert_path is not used - if "cert_path" in config: - cert_path = config["cert_path"] - else: - cert_path = False # noqa: F841 - if "token" in config: - headers = {"Authorization": "Bearer {}".format(config["token"])} - elif "ha_key" in config: - headers = {"x-ha-access": config["ha_key"]} - else: - headers = {} - api_url = "{}/api/states/{}".format(config["ha_url"], entity_id) + api_url = "/api/states/{}".format(entity_id) try: - r = await self.session.post(api_url, headers=headers, json=kwargs, verify_ssl=self.cert_verify) + r = await self.session.post(api_url, json=kwargs) if r.status == 200 or r.status == 201: state = await r.json() self.logger.debug("return = %s", state) @@ -508,25 +487,19 @@ async def call_plugin_service(self, namespace, domain, service, data): data = {"entity_id": data} config = (await self.AD.plugins.get_plugin_object(namespace)).config - if "token" in config: - headers = {"Authorization": "Bearer {}".format(config["token"])} - elif "ha_key" in config: - headers = {"x-ha-access": config["ha_key"]} - else: - headers = {} if domain == "template" and service == "render": - api_url = "{}/api/template".format(config["ha_url"]) + api_url = "/api/template" elif domain == "database": return await self.get_history(**data) else: - api_url = "{}/api/services/{}/{}".format(config["ha_url"], domain, service) + api_url = "/api/services/{}/{}".format(domain, service) try: - r = await self.session.post(api_url, headers=headers, json=data, verify_ssl=self.cert_verify) + r = await self.session.post(api_url, json=data) if r.status == 200 or r.status == 201: if domain == "template": @@ -567,23 +540,10 @@ async def call_plugin_service(self, namespace, domain, service, data): async def get_history(self, **kwargs): """Used to get HA's History""" - # TODO cert_path is not used - if "cert_path" in self.config: - cert_path = self.config["cert_path"] - else: - cert_path = False # noqa: F841 - - if "token" in self.config: - headers = {"Authorization": "Bearer {}".format(self.config["token"])} - elif "ha_key" in self.config: - headers = {"x-ha-access": self.config["ha_key"]} - else: - headers = {} - try: api_url = await self.get_history_api(**kwargs) - r = await self.session.get(api_url, headers=headers, verify_ssl=self.cert_verify) + r = await self.session.get(api_url) if r.status == 200 or r.status == 201: result = await r.json() @@ -651,7 +611,7 @@ def as_datetime(args, key): # Build the url # /api/history/period/?filter_entity_id=&end_time= - apiurl = f'{self.config["ha_url"]}/api/history/period' + apiurl = "/api/history/period" if start_time: apiurl += "/" + utils.dt_to_str(start_time.replace(microsecond=0), self.AD.tz) @@ -667,19 +627,12 @@ def as_datetime(args, key): async def get_hass_state(self, entity_id=None): - if self.token is not None: - headers = {"Authorization": "Bearer {}".format(self.token)} - elif self.ha_key is not None: - headers = {"x-ha-access": self.ha_key} - else: - headers = {} - if entity_id is None: - api_url = "{}/api/states".format(self.ha_url) + api_url = "/api/states" else: - api_url = "{}/api/states/{}".format(self.ha_url, entity_id) + api_url = "/api/states/{}".format(entity_id) self.logger.debug("get_ha_state: url is %s", api_url) - r = await self.session.get(api_url, headers=headers, verify_ssl=self.cert_verify) + r = await self.session.get(api_url) if r.status == 200 or r.status == 201: state = await r.json() else: @@ -720,24 +673,10 @@ def validate_tz(self, meta): async def get_hass_config(self): try: - if self.session is None: - # - # Set up HTTP Client - # - conn = aiohttp.TCPConnector() - self.session = aiohttp.ClientSession(connector=conn, json_serialize=utils.convert_json) - self.logger.debug("get_ha_config()") - if self.token is not None: - headers = {"Authorization": "Bearer {}".format(self.token)} - elif self.ha_key is not None: - headers = {"x-ha-access": self.ha_key} - else: - headers = {} - - api_url = "{}/api/config".format(self.ha_url) + api_url = "/api/config" self.logger.debug("get_ha_config: url is %s", api_url) - r = await self.session.get(api_url, headers=headers, verify_ssl=self.cert_verify) + r = await self.session.get(api_url) r.raise_for_status() meta = await r.json() # @@ -749,23 +688,17 @@ async def get_hass_config(self): self.validate_tz(meta) return meta - except Exception: - self.logger.warning("Error getting metadata - retrying") + except Exception as ex: + self.logger.warning("Error getting metadata - retrying: %s", str(ex)) raise async def get_hass_services(self) -> dict: try: self.logger.debug("get_hass_services()") - if self.token is not None: - headers = {"Authorization": "Bearer {}".format(self.token)} - elif self.ha_key is not None: - headers = {"x-ha-access": self.ha_key} - else: - headers = {} - api_url = "{}/api/services".format(self.ha_url) + api_url = "/api/services" self.logger.debug("get_hass_services: url is %s", api_url) - r = await self.session.get(api_url, headers=headers, verify_ssl=self.cert_verify) + r = await self.session.get(api_url) r.raise_for_status() services = await r.json() @@ -859,19 +792,13 @@ async def check_register_service(self, domain: str, services: Union[dict, str]) async def fire_plugin_event(self, event, namespace, **kwargs): self.logger.debug("fire_event: %s, %s %s", event, namespace, kwargs) - config = (await self.AD.plugins.get_plugin_object(namespace)).config - - if "token" in config: - headers = {"Authorization": "Bearer {}".format(config["token"])} - elif "ha_key" in config: - headers = {"x-ha-access": config["ha_key"]} - else: - headers = {} + # if we get a request for not our namespace something has gone very wrong + assert namespace == self.namespace event_clean = quote(event, safe="") - api_url = "{}/api/events/{}".format(config["ha_url"], event_clean) + api_url = "/api/events/{}".format(event_clean) try: - r = await self.session.post(api_url, headers=headers, json=kwargs, verify_ssl=self.cert_verify) + r = await self.session.post(api_url, json=kwargs) r.raise_for_status() state = await r.json() return state @@ -890,25 +817,14 @@ async def fire_plugin_event(self, event, namespace, **kwargs): @hass_check async def remove_entity(self, namespace, entity_id): self.logger.debug("remove_entity() %s", entity_id) - config = (await self.AD.plugins.get_plugin_object(namespace)).config - - # TODO cert_path is not used - if "cert_path" in config: - cert_path = config["cert_path"] - else: - cert_path = False # noqa: F841 - if "token" in config: - headers = {"Authorization": "Bearer {}".format(config["token"])} - elif "ha_key" in config: - headers = {"x-ha-access": config["ha_key"]} - else: - headers = {} + # if we get a request for not our namespace something has gone very wrong + assert namespace == self.namespace - api_url = "{}/api/states/{}".format(config["ha_url"], entity_id) + api_url = "/api/states/{}".format(entity_id) try: - r = await self.session.delete(api_url, headers=headers, verify_ssl=self.cert_verify) + r = await self.session.delete(api_url) if r.status == 200 or r.status == 201: state = await r.json() self.logger.debug("return = %s", state) From 87b65e388b07e15e779c051d0d02354c72cd4840 Mon Sep 17 00:00:00 2001 From: Robin Battey Date: Wed, 12 Oct 2022 15:12:18 -0700 Subject: [PATCH 2/6] remove unused variables --- appdaemon/plugins/hass/hassplugin.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index c021c0aad..b3555fb51 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -436,8 +436,9 @@ def utility(self): @hass_check async def set_plugin_state(self, namespace, entity_id, **kwargs): self.logger.debug("set_plugin_state() %s %s %s", namespace, entity_id, kwargs) - config = (await self.AD.plugins.get_plugin_object(namespace)).config + # if we get a request for not our namespace something has gone very wrong + assert namespace == self.namespace api_url = "/api/states/{}".format(entity_id) @@ -480,14 +481,15 @@ async def call_plugin_service(self, namespace, domain, service, data): data, ) + # if we get a request for not our namespace something has gone very wrong + assert namespace == self.namespace + # # If data is a string just assume it's an entity_id # if isinstance(data, str): data = {"entity_id": data} - config = (await self.AD.plugins.get_plugin_object(namespace)).config - if domain == "template" and service == "render": api_url = "/api/template" From ae43f34c7f1882f5f391100f659418bd9ad6de8c Mon Sep 17 00:00:00 2001 From: Robin Battey Date: Wed, 12 Oct 2022 18:25:41 -0700 Subject: [PATCH 3/6] black --- appdaemon/.vscode/settings.json | 3 + appdaemon/plugins/hass/hassplugin.py | 128 ++++++++++++++++++++------- 2 files changed, 99 insertions(+), 32 deletions(-) create mode 100644 appdaemon/.vscode/settings.json diff --git a/appdaemon/.vscode/settings.json b/appdaemon/.vscode/settings.json new file mode 100644 index 000000000..de288e1ea --- /dev/null +++ b/appdaemon/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.formatting.provider": "black" +} \ No newline at end of file diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index b3555fb51..ff49acf6c 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -25,7 +25,9 @@ def hass_check(func): def func_wrapper(*args, **kwargs): self = args[0] if not self.reading_messages: - self.logger.warning("Attempt to call Home Assistant while disconnected: %s", func.__name__) + self.logger.warning( + "Attempt to call Home Assistant while disconnected: %s", func.__name__ + ) return no_func() else: return func(*args, **kwargs) @@ -46,9 +48,13 @@ def __init__(self, ad: AppDaemon, name, args): # validate basic config if "ha_key" in args: - self.logger.warning("ha_key is deprecated please use HASS Long Lived Tokens instead") + self.logger.warning( + "ha_key is deprecated please use HASS Long Lived Tokens instead" + ) if "ha_url" not in args: - self.logger.warning("ha_url not found in HASS configuration - module not initialized") + self.logger.warning( + "ha_url not found in HASS configuration - module not initialized" + ) # Locally store common args and their defaults self.appdaemon_startup_conditions = args.get("appdaemon_startup_conditions", {}) @@ -96,13 +102,13 @@ def list_constraints(self): return [] # - # Persistent Session to HASS instance + # Persistent Session to HASS instance # @property def session(self): if not self._session: # ssl None means to use default behavior which check certs for https - ssl_context = (self.cert_verify and None) + ssl_context = None if self.cert_verify else False if self.cert_verify and self.cert_path: ssl_context = ssl.create_default_context(capath=self.cert_path) conn = aiohttp.TCPConnector(ssl=ssl_context) @@ -118,11 +124,10 @@ def session(self): base_url=self.ha_url, connector=conn, headers=headers, - json_serialize=utils.convert_json + json_serialize=utils.convert_json, ) return self._session - # # Get initial state # @@ -145,7 +150,9 @@ async def get_metadata(self): # # Handle state updates # - async def evaluate_started(self, first_time, plugin_booting, event=None): # noqa: C901 + async def evaluate_started( + self, first_time, plugin_booting, event=None + ): # noqa: C901 if first_time is True: self.hass_ready = False @@ -163,7 +170,9 @@ async def evaluate_started(self, first_time, plugin_booting, event=None): # noq if "delay" in startup_conditions: if first_time is True: - self.logger.info("Delaying startup for %s seconds", startup_conditions["delay"]) + self.logger.info( + "Delaying startup for %s seconds", startup_conditions["delay"] + ) await asyncio.sleep(int(startup_conditions["delay"])) if "hass_state" in startup_conditions: @@ -196,7 +205,9 @@ async def evaluate_started(self, first_time, plugin_booting, event=None): # noq start_ok = False elif entry["entity"] in state: if self.state_matched is False: - self.logger.info("Startup condition met: %s exists", entry["entity"]) + self.logger.info( + "Startup condition met: %s exists", entry["entity"] + ) self.state_matched = True else: start_ok = False @@ -214,7 +225,9 @@ async def evaluate_started(self, first_time, plugin_booting, event=None): # noq start_ok = False else: if entry["event_type"] == event["event_type"]: - if "values_changed" not in DeepDiff(event["data"], entry["data"]): + if "values_changed" not in DeepDiff( + event["data"], entry["data"] + ): self.logger.info( "Startup condition met: event type %s, data = %s fired", event["event_type"], @@ -294,7 +307,9 @@ async def get_updates(self): # noqa: C901 sslopt = {"cert_reqs": ssl.CERT_NONE} if self.cert_path: sslopt["ca_certs"] = self.cert_path - self.ws = websocket.create_connection("{}/api/websocket".format(url), sslopt=sslopt) + self.ws = websocket.create_connection( + "{}/api/websocket".format(url), sslopt=sslopt + ) res = await utils.run_in_executor(self, self.ws.recv) result = json.loads(res) self.logger.info("Connected to Home Assistant %s", result["ha_version"]) @@ -307,7 +322,9 @@ async def get_updates(self): # noqa: C901 elif self.ha_key is not None: auth = json.dumps({"type": "auth", "api_password": self.ha_key}) else: - raise ValueError("HASS requires authentication and none provided in plugin config") + raise ValueError( + "HASS requires authentication and none provided in plugin config" + ) await utils.run_in_executor(self, self.ws.send, auth) result = json.loads(self.ws.recv()) @@ -320,8 +337,14 @@ async def get_updates(self): # noqa: C901 sub = json.dumps({"id": _id, "type": "subscribe_events"}) await utils.run_in_executor(self, self.ws.send, sub) result = json.loads(self.ws.recv()) - if not (result["id"] == _id and result["type"] == "result" and result["success"] is True): - self.logger.warning("Unable to subscribe to HA events, id = %s", _id) + if not ( + result["id"] == _id + and result["type"] == "result" + and result["success"] is True + ): + self.logger.warning( + "Unable to subscribe to HA events, id = %s", _id + ) self.logger.warning(result) raise ValueError("Error subscribing to HA Events") @@ -337,7 +360,11 @@ async def get_updates(self): # noqa: C901 domain = hass_service["domain"] for service in hass_service["services"]: self.AD.services.register_service( - self.get_namespace(), domain, service, self.call_plugin_service, __silent=True + self.get_namespace(), + domain, + service, + self.call_plugin_service, + __silent=True, ) # Decide if we can start yet @@ -365,12 +392,16 @@ async def get_updates(self): # noqa: C901 result = json.loads(ret) if not (result["id"] == _id and result["type"] == "event"): - self.logger.warning("Unexpected result from Home Assistant, id = %s", _id) + self.logger.warning( + "Unexpected result from Home Assistant, id = %s", _id + ) self.logger.warning(result) if self.reading_messages is False: if result["type"] == "event": - await self.evaluate_started(False, self.hass_booting, result["event"]) + await self.evaluate_started( + False, self.hass_booting, result["event"] + ) else: await self.evaluate_started(False, self.hass_booting) else: @@ -380,7 +411,9 @@ async def get_updates(self): # noqa: C901 metadata["context"] = result["event"].pop("context", None) result["event"]["data"]["metadata"] = metadata - await self.AD.events.process_event(self.namespace, result["event"]) + await self.AD.events.process_event( + self.namespace, result["event"] + ) if result["event"].get("event_type") == "service_registered": data = result["event"]["data"] @@ -401,10 +434,15 @@ async def get_updates(self): # noqa: C901 await self.AD.callbacks.clear_callbacks(self.name) if not self.already_notified: - await self.AD.plugins.notify_plugin_stopped(self.name, self.namespace) + await self.AD.plugins.notify_plugin_stopped( + self.name, self.namespace + ) self.already_notified = True if not self.stopping: - self.logger.warning("Disconnected from Home Assistant, retrying in %s seconds", self.retry_secs) + self.logger.warning( + "Disconnected from Home Assistant, retrying in %s seconds", + self.retry_secs, + ) self.logger.debug("-" * 60) self.logger.debug("Unexpected error:") self.logger.debug("-" * 60) @@ -459,7 +497,9 @@ async def set_plugin_state(self, namespace, entity_id, **kwargs): state = None return state except (asyncio.TimeoutError, asyncio.CancelledError): - self.logger.warning("Timeout in set_state(%s, %s, %s)", namespace, entity_id, kwargs) + self.logger.warning( + "Timeout in set_state(%s, %s, %s)", namespace, entity_id, kwargs + ) except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.warning("HASS Disconnected unexpectedly during set_state()") except Exception: @@ -533,7 +573,9 @@ async def call_plugin_service(self, namespace, domain, service, data): except Exception: self.logger.error("-" * 60) self.logger.error("Unexpected error during call_plugin_service()") - self.logger.error("Service: %s.%s.%s Arguments: %s", namespace, domain, service, data) + self.logger.error( + "Service: %s.%s.%s Arguments: %s", namespace, domain, service, data + ) self.logger.error("-" * 60) self.logger.error(traceback.format_exc()) self.logger.error("-" * 60) @@ -616,7 +658,9 @@ def as_datetime(args, key): apiurl = "/api/history/period" if start_time: - apiurl += "/" + utils.dt_to_str(start_time.replace(microsecond=0), self.AD.tz) + apiurl += "/" + utils.dt_to_str( + start_time.replace(microsecond=0), self.AD.tz + ) if entity_id or end_time: if entity_id: @@ -646,7 +690,9 @@ async def get_hass_state(self, entity_id=None): def validate_meta(self, meta, key): if key not in meta: - self.logger.warning("Value for '%s' not found in metadata for plugin %s", key, self.name) + self.logger.warning( + "Value for '%s' not found in metadata for plugin %s", key, self.name + ) raise ValueError try: float(meta[key]) @@ -661,7 +707,9 @@ def validate_meta(self, meta, key): def validate_tz(self, meta): if "time_zone" not in meta: - self.logger.warning("Value for 'time_zone' not found in metadata for plugin %s", self.name) + self.logger.warning( + "Value for 'time_zone' not found in metadata for plugin %s", self.name + ) raise ValueError try: pytz.timezone(meta["time_zone"]) @@ -751,7 +799,9 @@ async def run_hass_service_check(self) -> None: await self.check_register_service(domain, services) - async def check_register_service(self, domain: str, services: Union[dict, str]) -> bool: + async def check_register_service( + self, domain: str, services: Union[dict, str] + ) -> bool: """Used to check and register a service if need be""" domain_exists = False @@ -775,7 +825,11 @@ async def check_register_service(self, domain: str, services: Union[dict, str]) self.services[service_index]["services"][services] = {} self.AD.services.register_service( - self.get_namespace(), domain, services, self.call_plugin_service, __silent=True + self.get_namespace(), + domain, + services, + self.call_plugin_service, + __silent=True, ) else: @@ -785,7 +839,11 @@ async def check_register_service(self, domain: str, services: Union[dict, str]) self.services[service_index]["services"][service] = service_data self.AD.services.register_service( - self.get_namespace(), domain, service, self.call_plugin_service, __silent=True + self.get_namespace(), + domain, + service, + self.call_plugin_service, + __silent=True, ) return domain_exists @@ -805,7 +863,9 @@ async def fire_plugin_event(self, event, namespace, **kwargs): state = await r.json() return state except (asyncio.TimeoutError, asyncio.CancelledError): - self.logger.warning("Timeout in fire_event(%s, %s, %s)", event, namespace, kwargs) + self.logger.warning( + "Timeout in fire_event(%s, %s, %s)", event, namespace, kwargs + ) except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.warning("HASS Disconnected unexpectedly during fire_event()") except Exception: @@ -831,13 +891,17 @@ async def remove_entity(self, namespace, entity_id): state = await r.json() self.logger.debug("return = %s", state) else: - self.logger.warning("Error Removing Home Assistant entity %s", entity_id) + self.logger.warning( + "Error Removing Home Assistant entity %s", entity_id + ) txt = await r.text() self.logger.warning("Code: %s, error: %s", r.status, txt) state = None return state except (asyncio.TimeoutError, asyncio.CancelledError): - self.logger.warning("Timeout in remove_entity(%s, %s)", namespace, entity_id) + self.logger.warning( + "Timeout in remove_entity(%s, %s)", namespace, entity_id + ) except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.warning("HASS Disconnected unexpectedly during remove_entity()") except Exception: From a34bf4ce8115359e9c1a68d55569412b97f8e152 Mon Sep 17 00:00:00 2001 From: Robin Battey Date: Wed, 12 Oct 2022 18:29:08 -0700 Subject: [PATCH 4/6] remove accidentally added vscode files --- appdaemon/.vscode/settings.json | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 appdaemon/.vscode/settings.json diff --git a/appdaemon/.vscode/settings.json b/appdaemon/.vscode/settings.json deleted file mode 100644 index de288e1ea..000000000 --- a/appdaemon/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.formatting.provider": "black" -} \ No newline at end of file From 49babd4248107fbf77ab4991bb1568e897dc7912 Mon Sep 17 00:00:00 2001 From: Robin Battey Date: Wed, 12 Oct 2022 18:38:36 -0700 Subject: [PATCH 5/6] black with proper settings --- appdaemon/plugins/hass/hassplugin.py | 98 +++++++--------------------- 1 file changed, 24 insertions(+), 74 deletions(-) diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index ff49acf6c..9ad6209c3 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -25,9 +25,7 @@ def hass_check(func): def func_wrapper(*args, **kwargs): self = args[0] if not self.reading_messages: - self.logger.warning( - "Attempt to call Home Assistant while disconnected: %s", func.__name__ - ) + self.logger.warning("Attempt to call Home Assistant while disconnected: %s", func.__name__) return no_func() else: return func(*args, **kwargs) @@ -48,13 +46,9 @@ def __init__(self, ad: AppDaemon, name, args): # validate basic config if "ha_key" in args: - self.logger.warning( - "ha_key is deprecated please use HASS Long Lived Tokens instead" - ) + self.logger.warning("ha_key is deprecated please use HASS Long Lived Tokens instead") if "ha_url" not in args: - self.logger.warning( - "ha_url not found in HASS configuration - module not initialized" - ) + self.logger.warning("ha_url not found in HASS configuration - module not initialized") # Locally store common args and their defaults self.appdaemon_startup_conditions = args.get("appdaemon_startup_conditions", {}) @@ -150,9 +144,7 @@ async def get_metadata(self): # # Handle state updates # - async def evaluate_started( - self, first_time, plugin_booting, event=None - ): # noqa: C901 + async def evaluate_started(self, first_time, plugin_booting, event=None): # noqa: C901 if first_time is True: self.hass_ready = False @@ -170,9 +162,7 @@ async def evaluate_started( if "delay" in startup_conditions: if first_time is True: - self.logger.info( - "Delaying startup for %s seconds", startup_conditions["delay"] - ) + self.logger.info("Delaying startup for %s seconds", startup_conditions["delay"]) await asyncio.sleep(int(startup_conditions["delay"])) if "hass_state" in startup_conditions: @@ -205,9 +195,7 @@ async def evaluate_started( start_ok = False elif entry["entity"] in state: if self.state_matched is False: - self.logger.info( - "Startup condition met: %s exists", entry["entity"] - ) + self.logger.info("Startup condition met: %s exists", entry["entity"]) self.state_matched = True else: start_ok = False @@ -225,9 +213,7 @@ async def evaluate_started( start_ok = False else: if entry["event_type"] == event["event_type"]: - if "values_changed" not in DeepDiff( - event["data"], entry["data"] - ): + if "values_changed" not in DeepDiff(event["data"], entry["data"]): self.logger.info( "Startup condition met: event type %s, data = %s fired", event["event_type"], @@ -307,9 +293,7 @@ async def get_updates(self): # noqa: C901 sslopt = {"cert_reqs": ssl.CERT_NONE} if self.cert_path: sslopt["ca_certs"] = self.cert_path - self.ws = websocket.create_connection( - "{}/api/websocket".format(url), sslopt=sslopt - ) + self.ws = websocket.create_connection("{}/api/websocket".format(url), sslopt=sslopt) res = await utils.run_in_executor(self, self.ws.recv) result = json.loads(res) self.logger.info("Connected to Home Assistant %s", result["ha_version"]) @@ -322,9 +306,7 @@ async def get_updates(self): # noqa: C901 elif self.ha_key is not None: auth = json.dumps({"type": "auth", "api_password": self.ha_key}) else: - raise ValueError( - "HASS requires authentication and none provided in plugin config" - ) + raise ValueError("HASS requires authentication and none provided in plugin config") await utils.run_in_executor(self, self.ws.send, auth) result = json.loads(self.ws.recv()) @@ -337,14 +319,8 @@ async def get_updates(self): # noqa: C901 sub = json.dumps({"id": _id, "type": "subscribe_events"}) await utils.run_in_executor(self, self.ws.send, sub) result = json.loads(self.ws.recv()) - if not ( - result["id"] == _id - and result["type"] == "result" - and result["success"] is True - ): - self.logger.warning( - "Unable to subscribe to HA events, id = %s", _id - ) + if not (result["id"] == _id and result["type"] == "result" and result["success"] is True): + self.logger.warning("Unable to subscribe to HA events, id = %s", _id) self.logger.warning(result) raise ValueError("Error subscribing to HA Events") @@ -392,16 +368,12 @@ async def get_updates(self): # noqa: C901 result = json.loads(ret) if not (result["id"] == _id and result["type"] == "event"): - self.logger.warning( - "Unexpected result from Home Assistant, id = %s", _id - ) + self.logger.warning("Unexpected result from Home Assistant, id = %s", _id) self.logger.warning(result) if self.reading_messages is False: if result["type"] == "event": - await self.evaluate_started( - False, self.hass_booting, result["event"] - ) + await self.evaluate_started(False, self.hass_booting, result["event"]) else: await self.evaluate_started(False, self.hass_booting) else: @@ -411,9 +383,7 @@ async def get_updates(self): # noqa: C901 metadata["context"] = result["event"].pop("context", None) result["event"]["data"]["metadata"] = metadata - await self.AD.events.process_event( - self.namespace, result["event"] - ) + await self.AD.events.process_event(self.namespace, result["event"]) if result["event"].get("event_type") == "service_registered": data = result["event"]["data"] @@ -434,9 +404,7 @@ async def get_updates(self): # noqa: C901 await self.AD.callbacks.clear_callbacks(self.name) if not self.already_notified: - await self.AD.plugins.notify_plugin_stopped( - self.name, self.namespace - ) + await self.AD.plugins.notify_plugin_stopped(self.name, self.namespace) self.already_notified = True if not self.stopping: self.logger.warning( @@ -497,9 +465,7 @@ async def set_plugin_state(self, namespace, entity_id, **kwargs): state = None return state except (asyncio.TimeoutError, asyncio.CancelledError): - self.logger.warning( - "Timeout in set_state(%s, %s, %s)", namespace, entity_id, kwargs - ) + self.logger.warning("Timeout in set_state(%s, %s, %s)", namespace, entity_id, kwargs) except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.warning("HASS Disconnected unexpectedly during set_state()") except Exception: @@ -573,9 +539,7 @@ async def call_plugin_service(self, namespace, domain, service, data): except Exception: self.logger.error("-" * 60) self.logger.error("Unexpected error during call_plugin_service()") - self.logger.error( - "Service: %s.%s.%s Arguments: %s", namespace, domain, service, data - ) + self.logger.error("Service: %s.%s.%s Arguments: %s", namespace, domain, service, data) self.logger.error("-" * 60) self.logger.error(traceback.format_exc()) self.logger.error("-" * 60) @@ -658,9 +622,7 @@ def as_datetime(args, key): apiurl = "/api/history/period" if start_time: - apiurl += "/" + utils.dt_to_str( - start_time.replace(microsecond=0), self.AD.tz - ) + apiurl += "/" + utils.dt_to_str(start_time.replace(microsecond=0), self.AD.tz) if entity_id or end_time: if entity_id: @@ -690,9 +652,7 @@ async def get_hass_state(self, entity_id=None): def validate_meta(self, meta, key): if key not in meta: - self.logger.warning( - "Value for '%s' not found in metadata for plugin %s", key, self.name - ) + self.logger.warning("Value for '%s' not found in metadata for plugin %s", key, self.name) raise ValueError try: float(meta[key]) @@ -707,9 +667,7 @@ def validate_meta(self, meta, key): def validate_tz(self, meta): if "time_zone" not in meta: - self.logger.warning( - "Value for 'time_zone' not found in metadata for plugin %s", self.name - ) + self.logger.warning("Value for 'time_zone' not found in metadata for plugin %s", self.name) raise ValueError try: pytz.timezone(meta["time_zone"]) @@ -799,9 +757,7 @@ async def run_hass_service_check(self) -> None: await self.check_register_service(domain, services) - async def check_register_service( - self, domain: str, services: Union[dict, str] - ) -> bool: + async def check_register_service(self, domain: str, services: Union[dict, str]) -> bool: """Used to check and register a service if need be""" domain_exists = False @@ -863,9 +819,7 @@ async def fire_plugin_event(self, event, namespace, **kwargs): state = await r.json() return state except (asyncio.TimeoutError, asyncio.CancelledError): - self.logger.warning( - "Timeout in fire_event(%s, %s, %s)", event, namespace, kwargs - ) + self.logger.warning("Timeout in fire_event(%s, %s, %s)", event, namespace, kwargs) except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.warning("HASS Disconnected unexpectedly during fire_event()") except Exception: @@ -891,17 +845,13 @@ async def remove_entity(self, namespace, entity_id): state = await r.json() self.logger.debug("return = %s", state) else: - self.logger.warning( - "Error Removing Home Assistant entity %s", entity_id - ) + self.logger.warning("Error Removing Home Assistant entity %s", entity_id) txt = await r.text() self.logger.warning("Code: %s, error: %s", r.status, txt) state = None return state except (asyncio.TimeoutError, asyncio.CancelledError): - self.logger.warning( - "Timeout in remove_entity(%s, %s)", namespace, entity_id - ) + self.logger.warning("Timeout in remove_entity(%s, %s)", namespace, entity_id) except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.warning("HASS Disconnected unexpectedly during remove_entity()") except Exception: From a78759bfc80741a5ad445575412ecf3e7a4ca170 Mon Sep 17 00:00:00 2001 From: Robin Battey Date: Wed, 12 Oct 2022 22:19:38 -0700 Subject: [PATCH 6/6] also refactor websocket for re-use --- appdaemon/plugins/hass/hassplugin.py | 74 ++++++++++++++++------------ 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index 9ad6209c3..1f8b6139f 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -96,7 +96,7 @@ def list_constraints(self): return [] # - # Persistent Session to HASS instance + # Persistent HTTP Session to HASS instance # @property def session(self): @@ -122,6 +122,47 @@ def session(self): ) return self._session + # + # Connect and return a new WebSocket to HASS instance + # + async def create_websocket(self): + # change to websocket protocol + url = self.ha_url + if url.startswith("https://"): + url = url.replace("https", "wss", 1) + elif url.startswith("http://"): + url = url.replace("http", "ws", 1) + + # ssl options + sslopt = {} + if self.cert_verify is False: + sslopt = {"cert_reqs": ssl.CERT_NONE} + if self.cert_path: + sslopt["ca_certs"] = self.cert_path + ws = websocket.create_connection("{}/api/websocket".format(url), sslopt=sslopt) + + # wait for successful connection + res = await utils.run_in_executor(self, ws.recv) + result = json.loads(res) + self.logger.info("Connected to Home Assistant %s", result["ha_version"]) + + # Check if auth required, if so send password + if result["type"] == "auth_required": + if self.token is not None: + auth = json.dumps({"type": "auth", "access_token": self.token}) + elif self.ha_key is not None: + auth = json.dumps({"type": "auth", "api_password": self.ha_key}) + else: + raise ValueError("HASS requires authentication and none provided in plugin config") + + await utils.run_in_executor(self, ws.send, auth) + result = json.loads(ws.recv()) + if result["type"] != "auth_ok": + self.logger.warning("Error in authentication") + raise ValueError("Error in authentication") + + return ws + # # Get initial state # @@ -282,37 +323,8 @@ async def get_updates(self): # noqa: C901 # # Connect to websocket interface # - url = self.ha_url - if url.startswith("https://"): - url = url.replace("https", "wss", 1) - elif url.startswith("http://"): - url = url.replace("http", "ws", 1) - - sslopt = {} - if self.cert_verify is False: - sslopt = {"cert_reqs": ssl.CERT_NONE} - if self.cert_path: - sslopt["ca_certs"] = self.cert_path - self.ws = websocket.create_connection("{}/api/websocket".format(url), sslopt=sslopt) - res = await utils.run_in_executor(self, self.ws.recv) - result = json.loads(res) - self.logger.info("Connected to Home Assistant %s", result["ha_version"]) - # - # Check if auth required, if so send password - # - if result["type"] == "auth_required": - if self.token is not None: - auth = json.dumps({"type": "auth", "access_token": self.token}) - elif self.ha_key is not None: - auth = json.dumps({"type": "auth", "api_password": self.ha_key}) - else: - raise ValueError("HASS requires authentication and none provided in plugin config") + self.ws = await self.create_websocket() - await utils.run_in_executor(self, self.ws.send, auth) - result = json.loads(self.ws.recv()) - if result["type"] != "auth_ok": - self.logger.warning("Error in authentication") - raise ValueError("Error in authentication") # # Subscribe to event stream #