diff --git a/.gitignore b/.gitignore index ecd6feb54..96dd13d6d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ __pycache__ .tox tuyadebug/ -.pre-commit-config.yaml \ No newline at end of file +.pre-commit-config.yaml +.idea diff --git a/README.md b/README.md index 01cbba98f..87a72b9b0 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ The following Tuya device types are currently supported: * Fans * Climates * Vacuums +* * Zigbee and Bluetooth gateways and their attached devices Energy monitoring (voltage, current, watts, etc.) is supported for compatible devices. diff --git a/custom_components/localtuya/__init__.py b/custom_components/localtuya/__init__.py index 9a996c31f..23da5db74 100644 --- a/custom_components/localtuya/__init__.py +++ b/custom_components/localtuya/__init__.py @@ -19,12 +19,14 @@ CONF_PLATFORM, CONF_REGION, CONF_USERNAME, + CONF_FRIENDLY_NAME, + CONF_MODEL, EVENT_HOMEASSISTANT_STOP, SERVICE_RELOAD, ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers.device_registry import DeviceEntry +from homeassistant.helpers import device_registry from homeassistant.helpers.event import async_track_time_interval from .cloud_api import TuyaCloudApi @@ -39,6 +41,9 @@ DATA_DISCOVERY, DOMAIN, TUYA_DEVICES, + CONF_GATEWAY_DEVICE_ID, + CONF_IS_GATEWAY, + CONF_PROTOCOL_VERSION, ) from .discovery import TuyaDiscovery @@ -139,10 +144,8 @@ def _device_discovered(device): ) new_data[ATTR_UPDATED_AT] = str(int(time.time() * 1000)) hass.config_entries.async_update_entry(entry, data=new_data) - device = hass.data[DOMAIN][TUYA_DEVICES][device_id] - if not device.connected: - device.async_connect() - elif device_id in hass.data[DOMAIN][TUYA_DEVICES]: + + if device_id in hass.data[DOMAIN][TUYA_DEVICES]: # _LOGGER.debug("Device %s found with IP %s", device_id, device_ip) device = hass.data[DOMAIN][TUYA_DEVICES][device_id] @@ -261,12 +264,34 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): async def setup_entities(device_ids): platforms = set() + sub_devices = [] for dev_id in device_ids: - entities = entry.data[CONF_DEVICES][dev_id][CONF_ENTITIES] + device_entry = entry.data[CONF_DEVICES][dev_id] + entities = device_entry[CONF_ENTITIES] platforms = platforms.union( set(entity[CONF_PLATFORM] for entity in entities) ) - hass.data[DOMAIN][TUYA_DEVICES][dev_id] = TuyaDevice(hass, entry, dev_id) + device = TuyaDevice(hass, entry, dev_id) + hass.data[DOMAIN][TUYA_DEVICES][dev_id] = device + + # Register gateway device manually to HA + if device_entry.get(CONF_IS_GATEWAY, False): + dr = device_registry.async_get(hass) + dr.async_get_or_create( + config_entry_id=entry.entry_id, + identifiers={(DOMAIN, f"local_{dev_id}")}, + name=device_entry[CONF_FRIENDLY_NAME], + manufacturer="Tuya", + model=f"{device_entry[CONF_MODEL]} ({dev_id})", + sw_version=device_entry[CONF_PROTOCOL_VERSION], + ) + + if device.gateway_device_id: + sub_devices.append(device) + # at this point, the gateway should have been created, + # and we can register the sub-device + for sub_device in sub_devices: + sub_device.register_in_gateway() await asyncio.gather( *[ @@ -322,7 +347,7 @@ async def update_listener(hass, config_entry): async def async_remove_config_entry_device( - hass: HomeAssistant, config_entry: ConfigEntry, device_entry: DeviceEntry + hass: HomeAssistant, config_entry: ConfigEntry, device_entry: device_registry.DeviceEntry ) -> bool: """Remove a config entry from a device.""" dev_id = list(device_entry.identifiers)[0][1].split("_")[-1] diff --git a/custom_components/localtuya/common.py b/custom_components/localtuya/common.py index 79eadc9b1..08a8597da 100644 --- a/custom_components/localtuya/common.py +++ b/custom_components/localtuya/common.py @@ -13,6 +13,7 @@ CONF_ID, CONF_PLATFORM, CONF_SCAN_INTERVAL, + CONF_CLIENT_ID, ) from homeassistant.core import callback from homeassistant.helpers.dispatcher import ( @@ -31,6 +32,8 @@ DATA_CLOUD, DOMAIN, TUYA_DEVICES, + CONF_GATEWAY_DEVICE_ID, + CONF_IS_GATEWAY, ) _LOGGER = logging.getLogger(__name__) @@ -128,6 +131,7 @@ def __init__(self, hass, config_entry, dev_id): self._hass = hass self._config_entry = config_entry self._dev_config_entry = config_entry.data[CONF_DEVICES][dev_id].copy() + self.device_id = self._dev_config_entry[CONF_DEVICE_ID] self._interface = None self._status = {} self.dps_to_request = {} @@ -136,7 +140,15 @@ def __init__(self, hass, config_entry, dev_id): self._disconnect_task = None self._unsub_interval = None self._local_key = self._dev_config_entry[CONF_LOCAL_KEY] - self.set_logger(_LOGGER, self._dev_config_entry[CONF_DEVICE_ID]) + self.set_logger(_LOGGER, self.device_id) + self.is_gateway = self._dev_config_entry.get(CONF_IS_GATEWAY, False) + + # handling sub-devices + self._connected = False + self.gateway_device_id = self._dev_config_entry.get(CONF_GATEWAY_DEVICE_ID) + self.cid = self._dev_config_entry.get(CONF_CLIENT_ID) + self.gateway_device = None + self.sub_devices = {} # This has to be done in case the device type is type_0d for entity in self._dev_config_entry[CONF_ENTITIES]: @@ -145,33 +157,61 @@ def __init__(self, hass, config_entry, dev_id): @property def connected(self): """Return if connected to device.""" - return self._interface is not None + if self.gateway_device: + return self.gateway_device.connected + else: + return self._interface is not None + + def register_in_gateway(self): + self.gateway_device = self._hass.data[DOMAIN][TUYA_DEVICES][self.gateway_device_id] + self.gateway_device.sub_devices[self.cid] = self def async_connect(self): """Connect to device if not already connected.""" + if self.gateway_device and not self.gateway_device.connected: + # early return in case this is a sub-device, + # connect will be triggered by the gateway + return if not self._is_closing and self._connect_task is None and not self._interface: self._connect_task = asyncio.create_task(self._make_connection()) + async def _get_interface(self): + if self.gateway_device: + return self.gateway_device._interface + else: + return await pytuya.connect( + self._dev_config_entry[CONF_HOST], + self.device_id, + self._local_key, + float(self._dev_config_entry[CONF_PROTOCOL_VERSION]), + listener=self, + is_gateway=self.is_gateway, + ) + async def _make_connection(self): """Subscribe localtuya entity events.""" self.debug("Connecting to %s", self._dev_config_entry[CONF_HOST]) try: - self._interface = await pytuya.connect( - self._dev_config_entry[CONF_HOST], - self._dev_config_entry[CONF_DEVICE_ID], - self._local_key, - float(self._dev_config_entry[CONF_PROTOCOL_VERSION]), - self, - ) - self._interface.add_dps_to_request(self.dps_to_request) + self._interface = await self._get_interface() + if self.cid: + self._interface.add_sub_device(self.cid) - self.debug("Retrieving initial state") - status = await self._interface.status() - if status is None: - raise Exception("Failed to retrieve status") + # Query initial state except for gateways + if not self.is_gateway: + self._interface.add_dps_to_request(self.dps_to_request, cid=self.cid) + + self.debug("Retrieving initial state") + status = await self._interface.status(cid=self.cid) + if status is None: + raise Exception("Failed to retrieve status") - self.status_updated(status) + self.status_updated(status) + + self._connected = True + + if self._disconnect_task is not None: + self._disconnect_task() def _new_entity_handler(entity_id): self.debug( @@ -181,7 +221,7 @@ def _new_entity_handler(entity_id): ) self._dispatch_status() - signal = f"localtuya_entity_{self._dev_config_entry[CONF_DEVICE_ID]}" + signal = f"localtuya_entity_{self.device_id}" self._disconnect_task = async_dispatcher_connect( self._hass, signal, _new_entity_handler ) @@ -199,23 +239,31 @@ def _new_entity_handler(entity_id): self.exception( f"Connect to {self._dev_config_entry[CONF_HOST]} failed: %s", type(e) ) + self._connected = False if self._interface is not None: - await self._interface.close() + if not self.gateway_device: + await self._interface.close() self._interface = None except Exception as e: # pylint: disable=broad-except self.exception(f"Connect to {self._dev_config_entry[CONF_HOST]} failed") + self._connected = False if "json.decode" in str(type(e)): await self.update_local_key() if self._interface is not None: - await self._interface.close() + if not self.gateway_device: + await self._interface.close() self._interface = None self._connect_task = None + # if there are sub-devices, we have to connect them as well + for sub_device in self.sub_devices.values(): + sub_device._connect_task = asyncio.create_task(sub_device._make_connection()) + async def update_local_key(self): """Retrieve updated local_key from Cloud API and update the config_entry.""" - dev_id = self._dev_config_entry[CONF_DEVICE_ID] + dev_id = self.device_id await self._hass.data[DOMAIN][DATA_CLOUD].async_get_devices_list() cloud_devs = self._hass.data[DOMAIN][DATA_CLOUD].device_list if dev_id in cloud_devs: @@ -239,10 +287,11 @@ async def close(self): if self._connect_task is not None: self._connect_task.cancel() await self._connect_task - if self._interface is not None: + if self._interface is not None and not self.gateway_device: await self._interface.close() if self._disconnect_task is not None: self._disconnect_task() + self._connected = False self.debug( "Closed connection with device %s.", self._dev_config_entry[CONF_FRIENDLY_NAME], @@ -252,7 +301,7 @@ async def set_dp(self, state, dp_index): """Change value of a DP of the Tuya device.""" if self._interface is not None: try: - await self._interface.set_dp(state, dp_index) + await self._interface.set_dp(state, dp_index, cid=self._dev_config_entry.get(CONF_CLIENT_ID)) except Exception: # pylint: disable=broad-except self.exception("Failed to set DP %d to %d", dp_index, state) else: @@ -264,7 +313,7 @@ async def set_dps(self, states): """Change value of a DPs of the Tuya device.""" if self._interface is not None: try: - await self._interface.set_dps(states) + await self._interface.set_dps(states, cid=self._dev_config_entry.get(CONF_CLIENT_ID)) except Exception: # pylint: disable=broad-except self.exception("Failed to set DPs %r", states) else: @@ -275,17 +324,35 @@ async def set_dps(self, states): @callback def status_updated(self, status): """Device updated status.""" - self._status.update(status) - self._dispatch_status() + for device, device_status in status.items(): + if device == '_default': + # usual case for devices that are not using a gateway + self._status.update(device_status) + self._dispatch_status() + elif self.cid: + # check if this is a sub-device, and if so, update based on the cid + if self.cid == device: + self._status.update(device_status) + self._dispatch_status() + else: + _LOGGER.warning('Sub-device received a status update for a unknown cid: %s', device) + else: + # otherwise, we are dealing with a status for a sub-device in a gateway instance + sub_device = self.sub_devices.get(device) + if not sub_device: + _LOGGER.warning('Gateway received a status update for a unknown cid: %s', device) + return + sub_device._status.update(device_status) + sub_device._dispatch_status() def _dispatch_status(self): - signal = f"localtuya_{self._dev_config_entry[CONF_DEVICE_ID]}" + signal = f"localtuya_{self.device_id}" async_dispatcher_send(self._hass, signal, self._status) @callback def disconnected(self): """Device disconnected.""" - signal = f"localtuya_{self._dev_config_entry[CONF_DEVICE_ID]}" + signal = f"localtuya_{self.device_id}" async_dispatcher_send(self._hass, signal, None) if self._unsub_interval is not None: self._unsub_interval() @@ -301,11 +368,12 @@ def __init__(self, device, config_entry, dp_id, logger, **kwargs): """Initialize the Tuya entity.""" super().__init__() self._device = device + self.device_id = device.device_id self._dev_config_entry = config_entry self._config = get_entity_config(config_entry, dp_id) self._dp_id = dp_id self._status = {} - self.set_logger(logger, self._dev_config_entry[CONF_DEVICE_ID]) + self.set_logger(logger, self.device_id) async def async_added_to_hass(self): """Subscribe localtuya events.""" @@ -327,13 +395,13 @@ def _update_handler(status): self.status_updated() self.schedule_update_ha_state() - signal = f"localtuya_{self._dev_config_entry[CONF_DEVICE_ID]}" + signal = f"localtuya_{self.device_id}" self.async_on_remove( async_dispatcher_connect(self.hass, signal, _update_handler) ) - signal = f"localtuya_entity_{self._dev_config_entry[CONF_DEVICE_ID]}" + signal = f"localtuya_entity_{self.device_id}" async_dispatcher_send(self.hass, signal, self.entity_id) @property @@ -343,11 +411,11 @@ def device_info(self): return { "identifiers": { # Serial numbers are unique identifiers within a specific domain - (DOMAIN, f"local_{self._dev_config_entry[CONF_DEVICE_ID]}") + (DOMAIN, f"local_{self.device_id}") }, "name": self._dev_config_entry[CONF_FRIENDLY_NAME], "manufacturer": "Tuya", - "model": f"{model} ({self._dev_config_entry[CONF_DEVICE_ID]})", + "model": f"{model} ({self.device_id})", "sw_version": self._dev_config_entry[CONF_PROTOCOL_VERSION], } @@ -364,7 +432,7 @@ def should_poll(self): @property def unique_id(self): """Return unique device identifier.""" - return f"local_{self._dev_config_entry[CONF_DEVICE_ID]}_{self._dp_id}" + return f"local_{self.device_id}_{self._dp_id}" def has_config(self, attr): """Return if a config parameter has a valid value.""" diff --git a/custom_components/localtuya/config_flow.py b/custom_components/localtuya/config_flow.py index 695baa91c..3e3a76525 100644 --- a/custom_components/localtuya/config_flow.py +++ b/custom_components/localtuya/config_flow.py @@ -2,6 +2,7 @@ import errno import logging import time +import copy from importlib import import_module import homeassistant.helpers.config_validation as cv @@ -40,6 +41,8 @@ CONF_PROTOCOL_VERSION, CONF_SETUP_CLOUD, CONF_USER_ID, + CONF_GATEWAY_DEVICE_ID, + CONF_IS_GATEWAY, DATA_CLOUD, DATA_DISCOVERY, DOMAIN, @@ -82,12 +85,15 @@ CONFIGURE_DEVICE_SCHEMA = vol.Schema( { - vol.Required(CONF_FRIENDLY_NAME): str, - vol.Required(CONF_LOCAL_KEY): str, - vol.Required(CONF_HOST): str, - vol.Required(CONF_DEVICE_ID): str, + vol.Required(CONF_FRIENDLY_NAME): cv.string, + vol.Required(CONF_LOCAL_KEY): cv.string, + vol.Required(CONF_HOST): cv.string, + vol.Required(CONF_DEVICE_ID): cv.string, vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]), + vol.Optional(CONF_GATEWAY_DEVICE_ID): cv.string, + vol.Optional(CONF_CLIENT_ID): cv.string, vol.Optional(CONF_SCAN_INTERVAL): int, + vol.Optional(CONF_IS_GATEWAY): cv.boolean, } ) @@ -99,7 +105,7 @@ vol.Required(CONF_FRIENDLY_NAME): cv.string, vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]), vol.Optional(CONF_SCAN_INTERVAL): int, - } + }, ) PICK_ENTITY_SCHEMA = vol.Schema( @@ -119,12 +125,6 @@ def devices_schema(discovered_devices, cloud_devices_list, add_custom_device=Tru if add_custom_device: devices.update({CUSTOM_DEVICE: CUSTOM_DEVICE}) - # devices.update( - # { - # ent.data[CONF_DEVICE_ID]: ent.data[CONF_FRIENDLY_NAME] - # for ent in entries - # } - # ) return vol.Schema({vol.Required(SELECTED_DEVICE): vol.In(devices)}) @@ -135,11 +135,14 @@ def options_schema(entities): ] return vol.Schema( { - vol.Required(CONF_FRIENDLY_NAME): str, - vol.Required(CONF_HOST): str, - vol.Required(CONF_LOCAL_KEY): str, + vol.Required(CONF_FRIENDLY_NAME): cv.string, + vol.Required(CONF_LOCAL_KEY): cv.string, + vol.Required(CONF_HOST): cv.string, vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]), + vol.Optional(CONF_GATEWAY_DEVICE_ID): cv.string, + vol.Optional(CONF_CLIENT_ID): cv.string, vol.Optional(CONF_SCAN_INTERVAL): int, + vol.Optional(CONF_IS_GATEWAY): cv.boolean, vol.Required( CONF_ENTITIES, description={"suggested_value": entity_names} ): cv.multi_select(entity_names), @@ -228,18 +231,26 @@ def config_schema(): async def validate_input(hass: core.HomeAssistant, data): """Validate the user input allows us to connect.""" - detected_dps = {} interface = None try: + device_id = data.get(CONF_GATEWAY_DEVICE_ID) + if not device_id or device_id == "": + device_id = data[CONF_DEVICE_ID] + interface = await pytuya.connect( data[CONF_HOST], - data[CONF_DEVICE_ID], + device_id, data[CONF_LOCAL_KEY], float(data[CONF_PROTOCOL_VERSION]), + is_gateway=data.get(CONF_IS_GATEWAY) or data.get(CONF_CLIENT_ID), ) + if data.get(CONF_CLIENT_ID): + interface.add_sub_device(data[CONF_CLIENT_ID]) - detected_dps = await interface.detect_available_dps() + # Do not detect DPS for gateways + if not data.get(CONF_IS_GATEWAY): + detected_dps = await interface.detect_available_dps(cid=data.get(CONF_CLIENT_ID)) except (ConnectionRefusedError, ConnectionResetError) as ex: raise CannotConnect from ex except ValueError as ex: @@ -249,11 +260,14 @@ async def validate_input(hass: core.HomeAssistant, data): await interface.close() # Indicate an error if no datapoints found as the rest of the flow - # won't work in this case - if not detected_dps: - raise EmptyDpsList + # won't work in this case, except for gateways as they have no DPS + if data.get(CONF_IS_GATEWAY): + return + else: + if not detected_dps: + raise EmptyDpsList - return dps_string_list(detected_dps) + return dps_string_list(detected_dps) async def attempt_cloud_connection(hass, user_input): @@ -437,14 +451,14 @@ async def async_step_add_device(self, user_input=None): return await self.async_step_configure_device() - self.discovered_devices = {} + discovered_devices = {} data = self.hass.data.get(DOMAIN) if data and DATA_DISCOVERY in data: - self.discovered_devices = data[DATA_DISCOVERY].devices + discovered_devices = copy.deepcopy(data[DATA_DISCOVERY].devices) else: try: - self.discovered_devices = await discover() + discovered_devices = await discover() except OSError as ex: if ex.errno == errno.EADDRINUSE: errors["base"] = "address_in_use" @@ -454,11 +468,32 @@ async def async_step_add_device(self, user_input=None): _LOGGER.exception("discovery failed") errors["base"] = "discovery_failed" - devices = { - dev_id: dev["ip"] - for dev_id, dev in self.discovered_devices.items() - if dev["gwId"] not in self.config_entry.data[CONF_DEVICES] - } + # device category reference + # https://developer.tuya.com/en/docs/cloud/303a03de7e?id=Kb2us379ab2mi + gateways = {} + for dev_id, dev in data[DATA_CLOUD].device_list.items(): + if dev['category'] == 'wg2': + # identify gateways by local key since sub-devices under + # this gateway use the gateway local key + gateways[dev['local_key']] = dev_id + + devices = {} + for dev_id, dev in data[DATA_CLOUD].device_list.items(): + if dev_id not in self.config_entry.data[CONF_DEVICES]: + if dev_id in discovered_devices: + devices[dev_id] = discovered_devices[dev_id]["ip"] + elif dev['local_key'] in gateways: + gateway_id = gateways[dev['local_key']] + # this device uses the gateway api for communication + discovered_devices[dev_id] = discovered_devices[gateway_id].copy() + discovered_devices[dev_id][CONF_GATEWAY_DEVICE_ID] = discovered_devices[gateway_id]['gwId'] + devices[dev_id] = discovered_devices[gateway_id]["ip"] + # this should be the mac address of the sub device + discovered_devices[dev_id][CONF_CLIENT_ID] = dev['node_id'] + # keep the original device id in the gwId + discovered_devices[dev_id]['gwId'] = dev_id + + self.discovered_devices = discovered_devices return self.async_show_form( step_id="add_device", @@ -500,10 +535,8 @@ async def async_step_configure_device(self, user_input=None): if user_input is not None: try: self.device_data = user_input.copy() + if dev_id is not None: - # self.device_data[CONF_PRODUCT_KEY] = self.devices[ - # self.selected_device - # ]["productKey"] cloud_devs = self.hass.data[DOMAIN][DATA_CLOUD].device_list if dev_id in cloud_devs: self.device_data[CONF_MODEL] = cloud_devs[dev_id].get( @@ -530,8 +563,14 @@ async def async_step_configure_device(self, user_input=None): ] return await self.async_step_configure_entity() - self.dps_strings = await validate_input(self.hass, user_input) - return await self.async_step_pick_entity_type() + # Skip DPS assignment & entity addition if we're adding a gateway + if self.device_data.get(CONF_IS_GATEWAY): + await validate_input(self.hass, self.device_data) + return await self.async_step_pick_entity_type({ NO_ADDITIONAL_ENTITIES: True }) + else: + self.dps_strings = await validate_input(self.hass, self.device_data) + return await self.async_step_pick_entity_type() + except CannotConnect: errors["base"] = "cannot_connect" except InvalidAuth: @@ -554,12 +593,16 @@ async def async_step_configure_device(self, user_input=None): defaults[CONF_DEVICE_ID] = "" defaults[CONF_LOCAL_KEY] = "" defaults[CONF_FRIENDLY_NAME] = "" + defaults[CONF_GATEWAY_DEVICE_ID] = "" + defaults[CONF_CLIENT_ID] = "" if dev_id is not None: # Insert default values from discovery and cloud if present device = self.discovered_devices[dev_id] defaults[CONF_HOST] = device.get("ip") defaults[CONF_DEVICE_ID] = device.get("gwId") defaults[CONF_PROTOCOL_VERSION] = device.get("version") + defaults[CONF_GATEWAY_DEVICE_ID] = device.get(CONF_GATEWAY_DEVICE_ID, "") + defaults[CONF_CLIENT_ID] = device.get(CONF_CLIENT_ID, "") cloud_devs = self.hass.data[DOMAIN][DATA_CLOUD].device_list if dev_id in cloud_devs: defaults[CONF_LOCAL_KEY] = cloud_devs[dev_id].get(CONF_LOCAL_KEY) @@ -679,7 +722,7 @@ async def async_step_configure_entity(self, user_input=None): new_data = self.config_entry.data.copy() entry_id = self.config_entry.entry_id # removing entities from registry (they will be recreated) - ent_reg = await er.async_get_registry(self.hass) + ent_reg = er.async_get(self.hass) reg_entities = { ent.unique_id: ent.entity_id for ent in er.async_entries_for_config_entry(ent_reg, entry_id) diff --git a/custom_components/localtuya/const.py b/custom_components/localtuya/const.py index c94030432..6c832fb89 100644 --- a/custom_components/localtuya/const.py +++ b/custom_components/localtuya/const.py @@ -40,6 +40,8 @@ CONF_EDIT_DEVICE = "edit_device" CONF_SETUP_CLOUD = "setup_cloud" CONF_NO_CLOUD = "no_cloud" +CONF_GATEWAY_DEVICE_ID = "gateway_device_id" +CONF_IS_GATEWAY = "is_gateway" # light CONF_BRIGHTNESS_LOWER = "brightness_lower" diff --git a/custom_components/localtuya/pytuya/__init__.py b/custom_components/localtuya/pytuya/__init__.py index b7645ec25..56bf03867 100644 --- a/custom_components/localtuya/pytuya/__init__.py +++ b/custom_components/localtuya/pytuya/__init__.py @@ -1,7 +1,7 @@ # PyTuya Module # -*- coding: utf-8 -*- """ -Python module to interface with Tuya WiFi smart devices. +Python module to interface with Tuya WiFi, Zigbee, or Bluetooth smart devices. Mostly derived from Shenzhen Xenon ESP8266MOD WiFi smart devices E.g. https://wikidevi.com/wiki/Xenon_SM-PW701U @@ -12,20 +12,25 @@ For more information see https://github.com/clach04/python-tuya Classes - TuyaInterface(dev_id, address, local_key=None) + TuyaProtocol(dev_id, local_key, protocol_version, on_connected, listener, is_gateway) dev_id (str): Device ID e.g. 01234567891234567890 - address (str): Device Network IP Address e.g. 10.0.1.99 - local_key (str, optional): The encryption key. Defaults to None. + local_key (str): The encryption key, obtainable via iot.tuya.com + protocol_version (float): The protocol version (3.1 or 3.3). + on_connected (object): Callback when connected. + listener (object): Listener for events such as status updates. + is_gateway (bool): Specifies if this is a gateway. Functions - json = status() # returns json payload - set_version(version) # 3.1 [default] or 3.3 - detect_available_dps() # returns a list of available dps provided by the device - update_dps(dps) # sends update dps command - add_dps_to_request(dp_index) # adds dp_index to the list of dps used by the - # device (to be queried in the payload) - set_dp(on, dp_index) # Set value of any dps index. - + json = status() # returns json payload for current dps status + detect_available_dps() # returns a list of available dps provided by the device + update_dps(dps) # sends update dps command + add_dps_to_request(dp_index, cid) # adds dp_index to the list of dps used by the + # device (to be queried in the payload), optionally + # with sub-device cid if this is a gateway + set_dp(on, dp_index, cid) # Set value of any dps index, optionally with cid if this is a gateway + set_dps(dps, cid) # Set values of a set of dps, optionally with cid if this is a gateway + add_sub_device(cid) # Adds a sub-device to a gateway + remove_sub_device(cid) # Removes a sub-device Credits * TuyaAPI https://github.com/codetheweb/tuyapi by codetheweb and blackrozes @@ -57,12 +62,12 @@ _LOGGER = logging.getLogger(__name__) -TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc") +TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc crcpassed") -SET = "set" -STATUS = "status" -HEARTBEAT = "heartbeat" -UPDATEDPS = "updatedps" # Request refresh of DPS +ACTION_SET = "set" +ACTION_STATUS = "status" +ACTION_HEARTBEAT = "heartbeat" +ACTION_UPDATEDPS = "updatedps" # Request refresh of DPS PROTOCOL_VERSION_BYTES_31 = b"3.1" PROTOCOL_VERSION_BYTES_33 = b"3.3" @@ -81,6 +86,17 @@ # DPS that are known to be safe to use with update_dps (0x12) command UPDATE_DPS_WHITELIST = [18, 19, 20] # Socket (Wi-Fi) +DEV_TYPE_0A = "type_0a" # DP_QUERY +DEV_TYPE_0D = "type_0d" # CONTROL_NEW + +COMMAND_DP_QUERY = 0x0A +COMMAND_CONTROL_NEW = 0x0D +COMMAND_SET = 0x07 +PUSH_STATUS = 0x08 +COMMAND_HEARTBEAT = 0x09 +COMMAND_DP_QUERY_NEW = 0x10 +COMMAND_UPDATE_DPS = 0x12 + # This is intended to match requests.json payload at # https://github.com/codetheweb/tuyapi : # type_0a devices require the 0a command as the status request @@ -90,18 +106,26 @@ # prefix: # Next byte is command byte ("hexByte") some zero padding, then length # of remaining payload, i.e. command + suffix (unclear if multiple bytes used for # length, zero padding implies could be more than one byte) +GATEWAY_PAYLOAD_DICT = { + # TYPE_0A should never be used with gateways + DEV_TYPE_0D: { + ACTION_STATUS: {"hexByte": COMMAND_DP_QUERY_NEW, "command": {"cid": ""}}, + ACTION_SET: {"hexByte": COMMAND_CONTROL_NEW, "command": {"cid": "", "t": ""}}, + ACTION_HEARTBEAT: {"hexByte": COMMAND_HEARTBEAT, "command": {}}, + }, +} PAYLOAD_DICT = { - "type_0a": { - STATUS: {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}}, - SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}}, - HEARTBEAT: {"hexByte": 0x09, "command": {}}, - UPDATEDPS: {"hexByte": 0x12, "command": {"dpId": [18, 19, 20]}}, + DEV_TYPE_0A: { + ACTION_STATUS: {"hexByte": COMMAND_DP_QUERY, "command": {"gwId": "", "devId": "", "uid": ""}}, + ACTION_SET: {"hexByte": COMMAND_SET, "command": {"devId": "", "uid": "", "t": ""}}, + ACTION_HEARTBEAT: {"hexByte": COMMAND_HEARTBEAT, "command": {}}, + ACTION_UPDATEDPS: {"hexByte": COMMAND_UPDATE_DPS, "command": {"dpId": [18, 19, 20]}}, }, - "type_0d": { - STATUS: {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}}, - SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}}, - HEARTBEAT: {"hexByte": 0x09, "command": {}}, - UPDATEDPS: {"hexByte": 0x12, "command": {"dpId": [18, 19, 20]}}, + DEV_TYPE_0D: { + ACTION_STATUS: {"hexByte": COMMAND_DP_QUERY_NEW, "command": {"gwId": "", "devId": "", "uid": ""}}, + ACTION_SET: {"hexByte": COMMAND_SET, "command": {"devId": "", "uid": "", "t": ""}}, + ACTION_HEARTBEAT: {"hexByte": COMMAND_HEARTBEAT, "command": {}}, + ACTION_UPDATEDPS: {"hexByte": COMMAND_UPDATE_DPS, "command": {"dpId": [18, 19, 20]}}, }, } @@ -151,14 +175,14 @@ def pack_message(msg): """Pack a TuyaMessage into bytes.""" # Create full message excluding CRC and suffix buffer = ( - struct.pack( - MESSAGE_HEADER_FMT, - PREFIX_VALUE, - msg.seqno, - msg.cmd, - len(msg.payload) + struct.calcsize(MESSAGE_END_FMT), - ) - + msg.payload + struct.pack( + MESSAGE_HEADER_FMT, + PREFIX_VALUE, + msg.seqno, + msg.cmd, + len(msg.payload) + struct.calcsize(MESSAGE_END_FMT), + ) + + msg.payload ) # Calculate CRC, add it together with suffix @@ -167,19 +191,6 @@ def pack_message(msg): return buffer -def unpack_message(data): - """Unpack bytes into a TuyaMessage.""" - header_len = struct.calcsize(MESSAGE_RECV_HEADER_FMT) - end_len = struct.calcsize(MESSAGE_END_FMT) - - _, seqno, cmd, _, retcode = struct.unpack( - MESSAGE_RECV_HEADER_FMT, data[:header_len] - ) - payload = data[header_len:-end_len] - crc, _ = struct.unpack(MESSAGE_END_FMT, data[-end_len:]) - return TuyaMessage(seqno, cmd, retcode, payload, crc) - - class AESCipher: """Cipher module for Tuya communication.""" @@ -208,7 +219,7 @@ def _pad(self, data): @staticmethod def _unpad(data): - return data[: -ord(data[len(data) - 1 :])] + return data[: -ord(data[len(data) - 1:])] class MessageDispatcher(ContextualLogger): @@ -257,7 +268,7 @@ def add_data(self, data): header_len = struct.calcsize(MESSAGE_RECV_HEADER_FMT) while self.buffer: - # Check if enough data for measage header + # Check if enough data for message header if len(self.buffer) < header_len: break @@ -265,7 +276,7 @@ def add_data(self, data): _, seqno, cmd, length, retcode = struct.unpack_from( MESSAGE_RECV_HEADER_FMT, self.buffer ) - if len(self.buffer[header_len - 4 :]) < length: + if len(self.buffer[header_len - 4:]) < length: break # length includes payload length, retcode, crc and suffix @@ -275,15 +286,19 @@ def add_data(self, data): else: payload_start = header_len payload_length = length - 4 - struct.calcsize(MESSAGE_END_FMT) - payload = self.buffer[payload_start : payload_start + payload_length] + payload = self.buffer[payload_start: payload_start + payload_length] crc, _ = struct.unpack_from( MESSAGE_END_FMT, - self.buffer[payload_start + payload_length : payload_start + length], + self.buffer[payload_start + payload_length: payload_start + length], ) - self.buffer = self.buffer[header_len - 4 + length :] - self._dispatch(TuyaMessage(seqno, cmd, retcode, payload, crc)) + # CRC calculated from prefix to end of payload + crc_calc = binascii.crc32(self.buffer[:header_len + payload_length]) + + self.buffer = self.buffer[header_len - 4 + length:] + + self._dispatch(TuyaMessage(seqno, cmd, retcode, payload, crc, crc == crc_calc)) def _dispatch(self, msg): """Dispatch a message to someone that is listening.""" @@ -293,17 +308,21 @@ def _dispatch(self, msg): sem = self.listeners[msg.seqno] self.listeners[msg.seqno] = msg sem.release() - elif msg.cmd == 0x09: + elif msg.cmd == COMMAND_HEARTBEAT: self.debug("Got heartbeat response") if self.HEARTBEAT_SEQNO in self.listeners: sem = self.listeners[self.HEARTBEAT_SEQNO] self.listeners[self.HEARTBEAT_SEQNO] = msg sem.release() - elif msg.cmd == 0x12: + elif msg.cmd == COMMAND_UPDATE_DPS: self.debug("Got normal updatedps response") - elif msg.cmd == 0x08: + elif msg.cmd == PUSH_STATUS: self.debug("Got status update") self.listener(msg) + elif msg.cmd == COMMAND_DP_QUERY_NEW: + self.debug("Got dp_query_new response") + elif msg.cmd == COMMAND_CONTROL_NEW: + self.debug("Got control_new response") else: self.debug( "Got message type %d for unknown listener %d: %s", @@ -338,26 +357,27 @@ def disconnected(self): class TuyaProtocol(asyncio.Protocol, ContextualLogger): """Implementation of the Tuya protocol.""" - def __init__(self, dev_id, local_key, protocol_version, on_connected, listener): + def __init__(self, dev_id, local_key, protocol_version, on_connected, listener, is_gateway): """ Initialize a new TuyaInterface. Args: dev_id (str): The device id. - address (str): The network address. - local_key (str, optional): The encryption key. Defaults to None. - - Attributes: - port (int): The port to connect to. + local_key (str): The encryption key. + protocol_version (float): The protocol version (3.1 or 3.3). + on_connected (object): Callback when connected. + listener (object): Listener for events such as status updates. + is_gateway (bool): Specifies if this is a gateway. """ super().__init__() self.loop = asyncio.get_running_loop() self.set_logger(_LOGGER, dev_id) self.id = dev_id + self.is_gateway = is_gateway self.local_key = local_key.encode("latin1") self.version = protocol_version - self.dev_type = "type_0a" - self.dps_to_request = {} + self.dev_type = DEV_TYPE_0D if is_gateway else DEV_TYPE_0A + self.dps_to_request = {"_default": {}} self.cipher = AESCipher(self.local_key) self.seqno = 0 self.transport = None @@ -365,19 +385,36 @@ def __init__(self, dev_id, local_key, protocol_version, on_connected, listener): self.dispatcher = self._setup_dispatcher() self.on_connected = on_connected self.heartbeater = None - self.dps_cache = {} + self.dps_cache = {"_default": {}} + self.sub_devices = [] + + def _parse_cid(self, cid): + if cid: + if cid not in self.sub_devices: + raise Exception("Unexpected sub-device cid", cid) + return cid or '_default' def _setup_dispatcher(self): - def _status_update(msg): - decoded_message = self._decode_payload(msg.payload) - if "dps" in decoded_message: - self.dps_cache.update(decoded_message["dps"]) + """Sets up message dispatcher for this pytuya instance""" + return MessageDispatcher(self.id, self._status_update) - listener = self.listener and self.listener() - if listener is not None: - listener.status_updated(self.dps_cache) + def _status_update(self, msg): + """Handle status updates""" + decoded_message = self._decode_payload(msg.payload) + + # This could happen if not all Zigbee sub-devices have been + # added to a gateway + if self.is_gateway: + cid = decoded_message.get("cid") + if cid not in self.sub_devices: + return - return MessageDispatcher(self.id, _status_update) + self._update_dps_cache(decoded_message) + + listener = self.listener and self.listener() + if listener is not None: + device = decoded_message.get('cid', '_default') + listener.status_updated({device: self.dps_cache[device]}) def connection_made(self, transport): """Did connect to the device.""" @@ -439,20 +476,20 @@ async def close(self): self.transport = None transport.close() - async def exchange(self, command, dps=None): + async def exchange(self, command, dps=None, cid=None): """Send and receive a message, returning response from device.""" self.debug( "Sending command %s (device type: %s)", command, self.dev_type, ) - payload = self._generate_payload(command, dps) + payload = self._generate_payload(command, dps, cid) dev_type = self.dev_type # Wait for special sequence number if heartbeat seqno = ( MessageDispatcher.HEARTBEAT_SEQNO - if command == HEARTBEAT + if command == ACTION_HEARTBEAT else (self.seqno - 1) ) @@ -462,7 +499,10 @@ async def exchange(self, command, dps=None): self.debug("Wait was aborted for seqno %d", seqno) return None - # TODO: Verify stuff, e.g. CRC sequence number? + if not msg.crcpassed: + self.debug("CRC for sequence number %d failed, resending command %s", seqno, command) + return await self.exchange(command, dps, cid) + payload = self._decode_payload(msg.payload) # Perform a new exchange (once) if we switched device type @@ -473,19 +513,22 @@ async def exchange(self, command, dps=None): dev_type, self.dev_type, ) - return await self.exchange(command, dps) + return await self.exchange(command, dps, cid) + return payload - async def status(self): + async def status(self, cid=None): """Return device status.""" - status = await self.exchange(STATUS) - if status and "dps" in status: - self.dps_cache.update(status["dps"]) - return self.dps_cache + device = self._parse_cid(cid) + status = await self.exchange(ACTION_STATUS, cid=cid) + if not status: # Happens when there's an error in decoding + return None + self._update_dps_cache(status) + return {device: self.dps_cache[device]} async def heartbeat(self): """Send a heartbeat message.""" - return await self.exchange(HEARTBEAT) + return await self.exchange(ACTION_HEARTBEAT) async def update_dps(self, dps=None): """ @@ -503,77 +546,107 @@ async def update_dps(self, dps=None): # filter non whitelisted dps dps = list(set(dps).intersection(set(UPDATE_DPS_WHITELIST))) self.debug("updatedps() entry (dps %s, dps_cache %s)", dps, self.dps_cache) - payload = self._generate_payload(UPDATEDPS, dps) + payload = self._generate_payload(ACTION_UPDATEDPS, dps) self.transport.write(payload) return True - async def set_dp(self, value, dp_index): + async def set_dp(self, value, dp_index, cid=None): """ Set value (may be any type: bool, int or string) of any dps index. Args: dp_index(int): dps index to set value: new value for the dps index + cid: Client ID of sub-device """ - return await self.exchange(SET, {str(dp_index): value}) + self._parse_cid(cid) + return await self.exchange(ACTION_SET, {str(dp_index): value}, cid) - async def set_dps(self, dps): + async def set_dps(self, dps, cid=None): """Set values for a set of datapoints.""" - return await self.exchange(SET, dps) + self._parse_cid(cid) + return await self.exchange(ACTION_SET, dps, cid) - async def detect_available_dps(self): + async def detect_available_dps(self, cid=None): """Return which datapoints are supported by the device.""" + # type_0d devices need a sort of bruteforce querying in order to detect the # list of available dps experience shows that the dps available are usually # in the ranges [1-25] and [100-110] need to split the bruteforcing in # different steps due to request payload limitation (max. length = 255) - self.dps_cache = {} ranges = [(2, 11), (11, 21), (21, 31), (100, 111)] + device = self._parse_cid(cid) + self.dps_cache[device] = {} for dps_range in ranges: # dps 1 must always be sent, otherwise it might fail in case no dps is found # in the requested range - self.dps_to_request = {"1": None} - self.add_dps_to_request(range(*dps_range)) + self.dps_to_request[device] = {"1": None} + self.add_dps_to_request(range(*dps_range), cid) try: - data = await self.status() + status = await self.status(cid) + self._update_dps_cache(status) except Exception as ex: - self.exception("Failed to get status: %s", ex) + self.exception("Failed to get status for %s: %s", device, ex) raise - if "dps" in data: - self.dps_cache.update(data["dps"]) - if self.dev_type == "type_0a": - return self.dps_cache - self.debug("Detected dps: %s", self.dps_cache) - return self.dps_cache + self.debug("Detected dps for %s: %s", device, self.dps_cache[device]) - def add_dps_to_request(self, dp_indicies): + return self.dps_cache[device] + + def add_dps_to_request(self, dp_indicies, cid=None): """Add a datapoint (DP) to be included in requests.""" + device = self._parse_cid(cid) if isinstance(dp_indicies, int): - self.dps_to_request[str(dp_indicies)] = None + self.dps_to_request[device][str(dp_indicies)] = None else: - self.dps_to_request.update({str(index): None for index in dp_indicies}) + self.dps_to_request[device].update({str(index): None for index in dp_indicies}) + + def add_sub_device(self, cid): + """Add a sub-device for a gateway device""" + + if not self.is_gateway: + raise Exception("Attempt to add sub-device to a non-gateway device") + + if cid in self.sub_devices: + return + + self.sub_devices.append(cid) + self.dps_to_request[cid] = {} + self.dps_cache[cid] = {} + + def remove_sub_device(self, cid): + """Removes a sub-device for a gateway device""" + if not self.is_gateway: + raise Exception("Attempt to remove sub-device from a non-gateway device") + + if cid not in self.sub_devices: + return + + self.sub_devices.remove(cid) + del self.dps_to_request[cid] + del self.dps_cache[cid] def _decode_payload(self, payload): + """Decodes payload received from a Tuya device""" if not payload: payload = "{}" elif payload.startswith(b"{"): pass elif payload.startswith(PROTOCOL_VERSION_BYTES_31): - payload = payload[len(PROTOCOL_VERSION_BYTES_31) :] # remove version header + payload = payload[len(PROTOCOL_VERSION_BYTES_31):] # remove version header # remove (what I'm guessing, but not confirmed is) 16-bytes of MD5 # hexdigest of payload payload = self.cipher.decrypt(payload[16:]) elif self.version == 3.3: - if self.dev_type != "type_0a" or payload.startswith( - PROTOCOL_VERSION_BYTES_33 + if payload.startswith( + PROTOCOL_VERSION_BYTES_33 ): - payload = payload[len(PROTOCOL_33_HEADER) :] + payload = payload[len(PROTOCOL_33_HEADER):] payload = self.cipher.decrypt(payload, False) if "data unvalid" in payload: - self.dev_type = "type_0d" + self.dev_type = DEV_TYPE_0D self.debug( "switching to dev_type %s", self.dev_type, @@ -587,7 +660,7 @@ def _decode_payload(self, payload): self.debug("Decrypted payload: %s", payload) return json.loads(payload) - def _generate_payload(self, command, data=None): + def _generate_payload(self, command, data=None, cid=None): """ Generate the payload to send. @@ -596,8 +669,19 @@ def _generate_payload(self, command, data=None): This is one of the entries from payload_dict data(dict, optional): The data to be send. This is what will be passed via the 'dps' entry + cid(str, optional): The sub-device CID to send """ - cmd_data = PAYLOAD_DICT[self.dev_type][command] + + if cid: + if command != ACTION_HEARTBEAT: + if cid not in self.sub_devices: + raise Exception("Unexpected sub-device cid", cid) + + payload_dict = GATEWAY_PAYLOAD_DICT + else: + payload_dict = PAYLOAD_DICT + + cmd_data = payload_dict[self.dev_type][command] json_data = cmd_data["command"] command_hb = cmd_data["hexByte"] @@ -607,6 +691,8 @@ def _generate_payload(self, command, data=None): json_data["devId"] = self.id if "uid" in json_data: json_data["uid"] = self.id # still use id, no separate uid + if "cid" in json_data: + json_data["cid"] = cid # for Zigbee gateways, cid specifies the sub-device if "t" in json_data: json_data["t"] = str(int(time.time())) @@ -615,53 +701,68 @@ def _generate_payload(self, command, data=None): json_data["dpId"] = data else: json_data["dps"] = data - elif command_hb == 0x0D: - json_data["dps"] = self.dps_to_request + elif command_hb == COMMAND_CONTROL_NEW: + if cid: + json_data["dps"] = self.dps_to_request[cid] + else: + json_data["dps"] = self.dps_to_request['_default'] payload = json.dumps(json_data).replace(" ", "").encode("utf-8") self.debug("Send payload: %s", payload) if self.version == 3.3: payload = self.cipher.encrypt(payload, False) - if command_hb not in [0x0A, 0x12]: + if command_hb not in [ + COMMAND_DP_QUERY, + COMMAND_DP_QUERY_NEW, + COMMAND_UPDATE_DPS + ]: # add the 3.3 header payload = PROTOCOL_33_HEADER + payload - elif command == SET: + elif command == ACTION_SET: payload = self.cipher.encrypt(payload) to_hash = ( - b"data=" - + payload - + b"||lpv=" - + PROTOCOL_VERSION_BYTES_31 - + b"||" - + self.local_key + b"data=" + + payload + + b"||lpv=" + + PROTOCOL_VERSION_BYTES_31 + + b"||" + + self.local_key ) hasher = md5() hasher.update(to_hash) hexdigest = hasher.hexdigest() payload = ( - PROTOCOL_VERSION_BYTES_31 - + hexdigest[8:][:16].encode("latin1") - + payload + PROTOCOL_VERSION_BYTES_31 + + hexdigest[8:][:16].encode("latin1") + + payload ) - msg = TuyaMessage(self.seqno, command_hb, 0, payload, 0) + msg = TuyaMessage(self.seqno, command_hb, 0, payload, 0, True) self.seqno += 1 return pack_message(msg) + def _update_dps_cache(self, status): + """Updates dps status cache""" + if not status or "dps" not in status: + return + device = status.get("cid", "_default") + self.dps_cache[device].update(status["dps"]) + def __repr__(self): """Return internal string representation of object.""" return self.id async def connect( - address, - device_id, - local_key, - protocol_version, - listener=None, - port=6668, - timeout=5, + address, + device_id, + local_key, + protocol_version, + listener=None, + port=6668, + timeout=5, + is_gateway=False, ): """Connect to a device.""" loop = asyncio.get_running_loop() @@ -673,7 +774,8 @@ async def connect( protocol_version, on_connected, listener or EmptyListener(), - ), + is_gateway, + ), address, port, ) diff --git a/custom_components/localtuya/strings.json b/custom_components/localtuya/strings.json index 4db7e7078..4443bf9f8 100644 --- a/custom_components/localtuya/strings.json +++ b/custom_components/localtuya/strings.json @@ -23,7 +23,7 @@ }, "power_outlet": { "title": "Add subswitch", - "description": "You are about to add subswitch number `{number}`. If you want to add another, tick `Add another switch` before continuing.", + "description": "You are about to add subswitch number `{number}`. If you want to add another, tick `Add another switch` before continuing.", "data": { "id": "ID", "name": "Name", diff --git a/custom_components/localtuya/translations/en.json b/custom_components/localtuya/translations/en.json index 82f406446..d6ccb3f5c 100644 --- a/custom_components/localtuya/translations/en.json +++ b/custom_components/localtuya/translations/en.json @@ -95,8 +95,10 @@ "device_id": "Device ID", "local_key": "Local key", "protocol_version": "Protocol Version", + "gateway_device_id": "Gateway device ID", + "client_id": "Mac Address of the gateway sub-device", "scan_interval": "Scan interval (seconds, only when not updating automatically)", - "entities": "Entities (uncheck an entity to remove it)" + "is_gateway": "This is a gateway (Zigbee / Bluetooth)" } }, "pick_entity_type": {