Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support adding and removing entities to devices #191

Merged
merged 2 commits into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions custom_components/localtuya/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@
import logging

import homeassistant.helpers.config_validation as cv
import homeassistant.helpers.entity_registry as er
import voluptuous as vol
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.const import (
CONF_DEVICE_ID,
CONF_ENTITIES,
CONF_HOST,
CONF_ID,
CONF_PLATFORM,
EVENT_HOMEASSISTANT_STOP,
SERVICE_RELOAD,
Expand All @@ -72,7 +74,7 @@
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.reload import async_integration_yaml_config

from .common import TuyaDevice
from .common import TuyaDevice, async_config_entry_by_device_id
from .config_flow import config_schema
from .const import CONF_PRODUCT_KEY, DATA_DISCOVERY, DOMAIN, TUYA_DEVICE
from .discovery import TuyaDiscovery
Expand Down Expand Up @@ -112,14 +114,6 @@ async def async_setup(hass: HomeAssistant, config: dict):

device_cache = {}

def _entry_by_device_id(device_id):
"""Look up config entry by device id."""
current_entries = hass.config_entries.async_entries(DOMAIN)
for entry in current_entries:
if entry.data[CONF_DEVICE_ID] == device_id:
return entry
return None

async def _handle_reload(service):
"""Handle reload service call."""
config = await async_integration_yaml_config(hass, DOMAIN)
Expand All @@ -142,7 +136,7 @@ async def _handle_reload(service):

async def _handle_set_dp(event):
"""Handle set_dp service call."""
entry = _entry_by_device_id(event.data[CONF_DEVICE_ID])
entry = async_config_entry_by_device_id(hass, event.data[CONF_DEVICE_ID])
if not entry:
raise HomeAssistantError("unknown device id")

Expand All @@ -163,7 +157,7 @@ def _device_discovered(device):

# If device is not in cache, check if a config entry exists
if device_id not in device_cache:
entry = _entry_by_device_id(device_id)
entry = async_config_entry_by_device_id(hass, device_id)
if entry:
# Save address from config entry in cache to trigger
# potential update below
Expand All @@ -172,7 +166,7 @@ def _device_discovered(device):
if device_id not in device_cache:
return

entry = _entry_by_device_id(device_id)
entry = async_config_entry_by_device_id(hass, device_id)
if entry is None:
return

Expand Down Expand Up @@ -252,6 +246,8 @@ async def setup_entities():
]
)

await async_remove_orphan_entities(hass, entry)

hass.async_create_task(setup_entities())

return True
Expand Down Expand Up @@ -281,3 +277,19 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
async def update_listener(hass, config_entry):
"""Update listener."""
await hass.config_entries.async_reload(config_entry.entry_id)


async def async_remove_orphan_entities(hass, entry):
"""Remove entities associated with config entry that has been removed."""
ent_reg = await er.async_get_registry(hass)
entities = {
int(ent.unique_id.split("_")[-1]): ent.entity_id
for ent in er.async_entries_for_config_entry(ent_reg, entry.entry_id)
}

for entity in entry.data[CONF_ENTITIES]:
if entity[CONF_ID] in entities:
del entities[entity[CONF_ID]]

for entity_id in entities.values():
ent_reg.async_remove(entity_id)
110 changes: 79 additions & 31 deletions custom_components/localtuya/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import homeassistant.helpers.config_validation as cv
import voluptuous as vol
from homeassistant import config_entries, core, exceptions
from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import (
CONF_DEVICE_ID,
CONF_ENTITIES,
Expand All @@ -16,7 +17,7 @@
)
from homeassistant.core import callback

from . import pytuya
from .common import async_config_entry_by_device_id, pytuya
from .const import CONF_DPS_STRINGS # pylint: disable=unused-import
from .const import (
CONF_LOCAL_KEY,
Expand Down Expand Up @@ -46,14 +47,6 @@
}
)

OPTIONS_SCHEMA = vol.Schema(
{
vol.Required(CONF_FRIENDLY_NAME): str,
vol.Required(CONF_HOST): str,
vol.Required(CONF_LOCAL_KEY): str,
vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]),
}
)

DEVICE_SCHEMA = vol.Schema(
{
Expand All @@ -70,11 +63,37 @@
)


def user_schema(devices):
def user_schema(devices, entries):
"""Create schema for user step."""
devices = [f"{ip} ({dev['gwId']})" for ip, dev in devices.items()]
devices = {dev_id: dev["ip"] for dev_id, dev in devices.items()}
devices.update(
{
ent.data[CONF_DEVICE_ID]: ent.data[CONF_FRIENDLY_NAME]
for ent in entries
if ent.source != SOURCE_IMPORT
}
)
device_list = [f"{key} ({value})" for key, value in devices.items()]
return vol.Schema(
{vol.Required(DISCOVERED_DEVICE): vol.In(device_list + [CUSTOM_DEVICE])}
)


def options_schema(entities):
"""Create schema for options."""
entity_names = [
f"{entity[CONF_ID]} {entity[CONF_FRIENDLY_NAME]}" for entity in entities
]
return vol.Schema(
{vol.Required(DISCOVERED_DEVICE): vol.In(devices + [CUSTOM_DEVICE])}
{
vol.Required(CONF_FRIENDLY_NAME): str,
vol.Required(CONF_HOST): str,
vol.Required(CONF_LOCAL_KEY): str,
vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]),
vol.Required(
CONF_ENTITIES, description={"suggested_value": entity_names}
): cv.multi_select(entity_names),
}
)


Expand Down Expand Up @@ -213,8 +232,7 @@ async def async_step_user(self, user_input=None):
errors = {}
if user_input is not None:
if user_input[DISCOVERED_DEVICE] != CUSTOM_DEVICE:
device = user_input[DISCOVERED_DEVICE].split(" ")[0]
self.selected_device = self.devices[device]
self.selected_device = user_input[DISCOVERED_DEVICE].split(" ")[0]
return await self.async_step_basic_info()

# Use cache if available or fallback to manual discovery
Expand All @@ -241,22 +259,23 @@ async def async_step_user(self, user_input=None):
}

return self.async_show_form(
step_id="user", errors=errors, data_schema=user_schema(self.devices)
step_id="user",
errors=errors,
data_schema=user_schema(self.devices, self._async_current_entries()),
)

async def async_step_basic_info(self, user_input=None):
"""Handle input of basic info."""
errors = {}
if user_input is not None:
await self.async_set_unique_id(user_input[CONF_DEVICE_ID])
self._abort_if_unique_id_configured()

try:
self.basic_info = user_input
if self.selected_device is not None:
self.basic_info[CONF_PRODUCT_KEY] = self.selected_device[
"productKey"
]
self.basic_info[CONF_PRODUCT_KEY] = self.devices[
self.selected_device
]["productKey"]
self.dps_strings = await validate_input(self.hass, user_input)
return await self.async_step_pick_entity_type()
except CannotConnect:
Expand All @@ -269,12 +288,23 @@ async def async_step_basic_info(self, user_input=None):
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"

# If selected device exists as a config entry, load config from it
if self.selected_device in self._async_current_ids():
entry = async_config_entry_by_device_id(self.hass, self.selected_device)
await self.async_set_unique_id(entry.data[CONF_DEVICE_ID])
self.basic_info = entry.data.copy()
self.dps_strings = self.basic_info.pop(CONF_DPS_STRINGS).copy()
self.entities = self.basic_info.pop(CONF_ENTITIES).copy()
return await self.async_step_pick_entity_type()

# Insert default values from discovery if present
defaults = {}
defaults.update(user_input or {})
if self.selected_device is not None:
defaults[CONF_HOST] = self.selected_device.get("ip")
defaults[CONF_DEVICE_ID] = self.selected_device.get("gwId")
defaults[CONF_PROTOCOL_VERSION] = self.selected_device.get("version")
device = self.devices[self.selected_device]
defaults[CONF_HOST] = device.get("ip")
defaults[CONF_DEVICE_ID] = device.get("gwId")
defaults[CONF_PROTOCOL_VERSION] = device.get("version")

return self.async_show_form(
step_id="basic_info",
Expand All @@ -291,6 +321,10 @@ async def async_step_pick_entity_type(self, user_input=None):
CONF_DPS_STRINGS: self.dps_strings,
CONF_ENTITIES: self.entities,
}
entry = async_config_entry_by_device_id(self.hass, self.unique_id)
if entry:
self.hass.config_entries.async_update_entry(entry, data=config)
return self.async_abort(reason="device_updated")
return self.async_create_entry(
title=config[CONF_FRIENDLY_NAME], data=config
)
Expand All @@ -313,7 +347,8 @@ async def async_step_add_entity(self, user_input=None):
errors = {}
if user_input is not None:
already_configured = any(
switch[CONF_ID] == user_input[CONF_ID] for switch in self.entities
switch[CONF_ID] == int(user_input[CONF_ID].split(" ")[0])
for switch in self.entities
)
if not already_configured:
user_input[CONF_PLATFORM] = self.platform
Expand Down Expand Up @@ -352,21 +387,34 @@ async def async_step_init(self, user_input=None):
"""Manage basic options."""
device_id = self.config_entry.data[CONF_DEVICE_ID]
if user_input is not None:
self.data = {
CONF_DEVICE_ID: device_id,
CONF_DPS_STRINGS: self.dps_strings,
CONF_ENTITIES: [],
}
self.data.update(user_input)
return await self.async_step_entity()
self.data = user_input.copy()
self.data.update(
{
CONF_DEVICE_ID: device_id,
CONF_DPS_STRINGS: self.dps_strings,
CONF_ENTITIES: [],
}
)
if len(user_input[CONF_ENTITIES]) > 0:
entity_ids = [
int(entity.split(" ")[0]) for entity in user_input[CONF_ENTITIES]
]
self.entities = [
entity
for entity in self.config_entry.data[CONF_ENTITIES]
if entity[CONF_ID] in entity_ids
]
return await self.async_step_entity()

# Not supported for YAML imports
if self.config_entry.source == config_entries.SOURCE_IMPORT:
return await self.async_step_yaml_import()

return self.async_show_form(
step_id="init",
data_schema=schema_defaults(OPTIONS_SCHEMA, **self.config_entry.data),
data_schema=schema_defaults(
options_schema(self.entities), **self.config_entry.data
),
description_placeholders={"device_id": device_id},
)

Expand Down
2 changes: 1 addition & 1 deletion custom_components/localtuya/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def datagram_received(self, data, addr):
def device_found(self, device):
"""Discover a new device."""
if device.get("ip") not in self.devices:
self.devices[device.get("ip")] = device
self.devices[device.get("gwId")] = device
_LOGGER.debug("Discovered device: %s", device)

if self._callback:
Expand Down
6 changes: 4 additions & 2 deletions custom_components/localtuya/translations/en.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"config": {
"abort": {
"already_configured": "Device has already been configured."
"already_configured": "Device has already been configured.",
"device_updated": "Device configuration has been updated!"
},
"error": {
"cannot_connect": "Cannot connect to device. Verify that address is correct and try again.",
Expand Down Expand Up @@ -87,7 +88,8 @@
"friendly_name": "Friendly Name",
"host": "Host",
"local_key": "Local key",
"protocol_version": "Protocol Version"
"protocol_version": "Protocol Version",
"entities": "Entities (uncheck an entity to remove it)"
}
},
"entity": {
Expand Down