Skip to content

Commit

Permalink
Merge pull request #1 from leeyuentuen/branch2
Browse files Browse the repository at this point in the history
Update gateway connection to connect at load time instead of waiting for a call to async_connect
  • Loading branch information
knifehandz authored Feb 9, 2022
2 parents 453b77a + 6898579 commit 11e9893
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 39 deletions.
11 changes: 4 additions & 7 deletions custom_components/localtuya/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):

if not entry.data.get(CONF_IS_GATEWAY):
async def setup_entities():
platforms = set(entity[CONF_PLATFORM] for entity in entry.data[CONF_ENTITIES])
platforms = set(
entity[CONF_PLATFORM] for entity in entry.data[CONF_ENTITIES]
)
await asyncio.gather(
*[
hass.config_entries.async_forward_entry_setup(entry, platform)
Expand All @@ -286,12 +288,9 @@ async def setup_entities():
device.async_connect()

await async_remove_orphan_entities(hass, entry)

hass.async_create_task(setup_entities())

return True


async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
"""Unload a config entry."""
unload_ok = all(
Expand All @@ -310,9 +309,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)

return True


return True
async def update_listener(hass, config_entry):
"""Update listener."""
await hass.config_entries.async_reload(config_entry.entry_id)
Expand Down
42 changes: 24 additions & 18 deletions custom_components/localtuya/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import logging
from datetime import timedelta
from homeassistant.config_entries import ConfigEntry

from homeassistant.const import (
CONF_DEVICE_ID,
Expand All @@ -12,7 +13,9 @@
CONF_PLATFORM,
CONF_SCAN_INTERVAL,
)
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
Expand Down Expand Up @@ -62,7 +65,12 @@ def prepare_setup_entities(hass, config_entry, platform):


async def async_setup_entry(
domain, entity_class, flow_schema, hass, config_entry, async_add_entities
domain: str,
entity_class: type,
flow_schema,
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
):
"""Set up a Tuya platform based on a config entry.
Expand All @@ -72,28 +80,27 @@ async def async_setup_entry(
tuyainterface, entities_to_setup = prepare_setup_entities(
hass, config_entry, domain
)

if not entities_to_setup:
return

dps_config_fields = list(get_dps_for_platform(flow_schema))

entities = []
for device_config in entities_to_setup:
# Add DPS used by this platform to the request list
for dp_conf in dps_config_fields:
if dp_conf in device_config:
tuyainterface.dps_to_request[device_config[dp_conf]] = None

entities.append(
entity_class(
tuyainterface,
config_entry,
device_config[CONF_ID],
)
async_add_entities(
[
entity_class(
tuyainterface,
config_entry,
device_config[CONF_ID],
)
],
True,
)

async_add_entities(entities)


def get_dps_for_platform(flow_schema):
"""Return config keys for all platform keys that depends on a datapoint."""
Expand Down Expand Up @@ -276,7 +283,7 @@ def __init__(self, hass, config_entry):
self._config_entry = config_entry
self._interface = None
self._is_closing = False
self._connect_task = None
self._connect_task = asyncio.create_task(self._make_connection())
self._disconnect_task = None
self._retry_sub_conn_interval = None
self._sub_devices = {}
Expand All @@ -299,7 +306,6 @@ def async_connect(self):
async def _make_connection(self):
"""Subscribe localtuya entity events."""
self.debug("Connecting to gateway %s", self._config_entry[CONF_HOST])

try:
self._interface = await pytuya.connect(
self._config_entry[CONF_HOST],
Expand Down Expand Up @@ -329,7 +335,7 @@ async def _make_connection(self):
self._retry_sub_conn_interval = async_track_time_interval(
self._hass,
self._retry_sub_device_connection,
timedelta(seconds=SUB_DEVICE_RECONNECT_INTERVAL)
timedelta(seconds=SUB_DEVICE_RECONNECT_INTERVAL),
)

except Exception: # pylint: disable=broad-except
Expand Down Expand Up @@ -537,7 +543,7 @@ async def _gateway_request_task(self, request, content):
self.debug(
"Unable to dispatch request %s due to pending request %s",
request,
self._pending_request["request"]
self._pending_request["request"],
)

return
Expand All @@ -562,7 +568,7 @@ async def _gateway_request_task(self, request, content):
"request": self._pending_request["request"],
"cid": self._config_entry[CONF_DEVICE_ID],
"content": self._pending_request["content"],
}
},
)

self._pending_request["retry_count"] += 1
Expand Down
3 changes: 2 additions & 1 deletion custom_components/localtuya/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ def __init__(self, config_entry):
"""Initialize localtuya options flow."""
self.config_entry = config_entry
self.dps_strings = config_entry.data.get(CONF_DPS_STRINGS, gen_dps_strings())
self.entities = config_entry.data[CONF_ENTITIES]
if not config_entry.data.get(CONF_IS_GATEWAY):
self.entities = config_entry.data[CONF_ENTITIES]
self.data = None

async def async_step_init(self, user_input=None):
Expand Down
25 changes: 12 additions & 13 deletions custom_components/localtuya/pytuya/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def disconnected(self):
class TuyaProtocol(asyncio.Protocol, ContextualLogger):
"""Implementation of the Tuya protocol."""


def __init__(self, dev_id, local_key, protocol_version, on_connected, listener, is_gateway):
"""
Initialize a new TuyaInterface.
Expand Down Expand Up @@ -387,15 +388,15 @@ def __init__(self, dev_id, local_key, protocol_version, on_connected, listener,
self.sub_devices = []

def _setup_dispatcher(self):
def _status_update(msg):
decoded_message = self._decode_payload(msg.payload)
self._update_dps_cache(decoded_message)
return MessageDispatcher(self.id, self._status_update)

listener = self.listener and self.listener()
if listener is not None:
listener.status_updated(self.dps_cache)
def _status_update(self, msg):
decoded_message = self._decode_payload(msg.payload)
self._update_dps_cache(decoded_message)

return MessageDispatcher(self.id, _status_update)
listener = self.listener and self.listener()
if listener is not None:
listener.status_updated(self.dps_cache)

def connection_made(self, transport):
"""Did connect to the device."""
Expand Down Expand Up @@ -551,7 +552,6 @@ async def set_dp(self, value, dp_index, cid=None):
raise Exception("Sub-device cid not specified for gateway")
if cid not in self.sub_devices:
raise Exception("Unexpected sub-device cid", cid)

return await self.exchange(SET, {str(dp_index): value}, cid)

async def set_dps(self, dps, cid=None):
Expand All @@ -561,9 +561,7 @@ async def set_dps(self, dps, cid=None):
raise Exception("Sub-device cid not specified for gateway")
if cid not in self.sub_devices:
raise Exception("Unexpected sub-device cid", cid)

return await self.exchange(SET, dps, cid)

async def detect_available_dps(self, cid=None):
"""Return which datapoints are supported by the device."""

Expand Down Expand Up @@ -619,7 +617,6 @@ async def detect_available_dps(self, cid=None):

def add_dps_to_request(self, dp_indicies, cid=None):
"""Add a datapoint (DP) to be included in requests."""

if self.is_gateway:
if not cid:
raise Exception("Sub-device cid not specified for gateway")
Expand All @@ -630,7 +627,6 @@ def add_dps_to_request(self, dp_indicies, cid=None):
self.dps_to_request[cid][str(dp_indicies)] = None
else:
self.dps_to_request[cid].update({str(index): None for index in dp_indicies})

else:
if isinstance(dp_indicies, int):
self.dps_to_request[str(dp_indicies)] = None
Expand Down Expand Up @@ -731,7 +727,10 @@ def _generate_payload(self, command, data=None, cid=None):
else:
json_data["dps"] = data
elif command_hb == COMMAND_CONTROL_NEW:
json_data["dps"] = self.dps_to_request
if cid:
json_data["dps"] = self.dps_to_request[cid]
else:
json_data["dps"] = self.dps_to_request

payload = json.dumps(json_data).replace(" ", "").encode("utf-8")
self.debug("Send payload: %s", payload)
Expand Down

0 comments on commit 11e9893

Please sign in to comment.