-
-
Notifications
You must be signed in to change notification settings - Fork 31.1k
/
entity_component.py
349 lines (277 loc) · 11.6 KB
/
entity_component.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
"""Helpers for components that manage entities."""
from __future__ import annotations
import asyncio
from collections.abc import Callable, Iterable
from datetime import timedelta
from itertools import chain
import logging
from types import ModuleType
from typing import Any, Generic
from typing_extensions import TypeVar
import voluptuous as vol
from homeassistant import config as conf_util
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
CONF_ENTITY_NAMESPACE,
CONF_SCAN_INTERVAL,
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import Event, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import async_get_integration, bind_hass
from homeassistant.setup import async_prepare_setup_platform
from . import config_per_platform, config_validation as cv, discovery, entity, service
from .entity_platform import EntityPlatform
from .typing import ConfigType, DiscoveryInfoType
# Default value of EntityComponent's ``scan_interval``; platforms created
# without their own interval fall back to the component's interval.
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
# Key in ``hass.data`` holding a ``{domain: EntityComponent}`` mapping of
# every EntityComponent instance created in this process.
DATA_INSTANCES = "entity_components"
# Entity type managed by an EntityComponent; defaults to ``entity.Entity``
# when the component is used unparameterized.
_EntityT = TypeVar("_EntityT", bound=entity.Entity, default=entity.Entity)
@bind_hass
async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
    """Force a state refresh of a single entity.

    The owning EntityComponent is found through the domain prefix of
    ``entity_id`` (the part before the first ``.``). When either the component
    or the entity cannot be found, a warning is logged and nothing else
    happens; otherwise the entity's ``async_update_ha_state(True)`` is awaited.
    """
    domain, _sep, _rest = entity_id.partition(".")
    registered: dict[str, EntityComponent[entity.Entity]] = hass.data.get(
        DATA_INSTANCES, {}
    )
    component: EntityComponent[entity.Entity] | None = registered.get(domain)

    if component is None:
        logging.getLogger(__name__).warning(
            "Forced update failed. Component for %s not loaded.", entity_id
        )
        return

    target = component.get_entity(entity_id)
    if target is None:
        logging.getLogger(__name__).warning(
            "Forced update failed. Entity %s not found.", entity_id
        )
        return

    await target.async_update_ha_state(True)
class EntityComponent(Generic[_EntityT]):
    """The EntityComponent manages platforms that manage entities.

    This class has the following responsibilities:
     - Process the configuration and set up a platform based component.
     - Manage the platforms and their entities.
     - Help extract the entities from a service call.
     - Listen for discovery events for platforms related to the domain.
    """

    def __init__(
        self,
        logger: logging.Logger,
        domain: str,
        hass: HomeAssistant,
        scan_interval: timedelta = DEFAULT_SCAN_INTERVAL,
    ) -> None:
        """Initialize an entity component.

        Args:
            logger: Logger used by the component and by every EntityPlatform
                it creates.
            domain: Domain of the integration this component manages.
            hass: The Home Assistant instance.
            scan_interval: Interval used for any platform that does not
                provide its own scan interval.
        """
        self.logger = logger
        self.hass = hass
        self.domain = domain
        self.scan_interval = scan_interval

        # Set by async_setup; cleared again by _async_reset.
        self.config: ConfigType | None = None

        # Platforms are keyed by one of:
        #  - ``domain`` itself: the built-in platform created right here;
        #  - a config entry id (see async_setup_entry);
        #  - ``(platform_type, scan_interval, entity_namespace)`` for platforms
        #    set up from config or discovery (see async_setup_platform).
        self._platforms: dict[
            str | tuple[str, timedelta | None, str | None], EntityPlatform
        ] = {domain: self._async_init_entity_platform(domain, None)}

        # Re-export the built-in platform's add helpers on the component.
        self.async_add_entities = self._platforms[domain].async_add_entities
        self.add_entities = self._platforms[domain].add_entities

        # Register this instance by domain so module-level helpers such as
        # async_update_entity can look it up.
        hass.data.setdefault(DATA_INSTANCES, {})[domain] = self

    @property
    def entities(self) -> Iterable[_EntityT]:
        """Return an iterable that returns all entities.

        As the underlying dicts may change when async context is lost,
        callers that iterate over this asynchronously should make a copy
        using list() before iterating.
        """
        return chain.from_iterable(
            platform.entities.values()  # type: ignore[misc]
            for platform in self._platforms.values()
        )

    def get_entity(self, entity_id: str) -> _EntityT | None:
        """Return the entity with ``entity_id`` from any platform, or None."""
        for platform in self._platforms.values():
            entity_obj = platform.entities.get(entity_id)
            if entity_obj is not None:
                return entity_obj  # type: ignore[return-value]
        return None

    def register_shutdown(self) -> None:
        """Register shutdown on Home Assistant STOP event.

        Note: this is only required if the integration never calls
        `setup` or `async_setup`.
        """
        self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown)

    def setup(self, config: ConfigType) -> None:
        """Set up a full entity component.

        Only schedules ``async_setup`` as a task instead of waiting for it,
        so this doesn't block the executor; this protects from deadlocks.
        """
        self.hass.create_task(
            self.async_setup(config), f"EntityComponent setup {self.domain}"
        )

    async def async_setup(self, config: ConfigType) -> None:
        """Set up a full entity component.

        Loads the platforms from the config and will listen for supported
        discovered platforms.

        This method must be run in the event loop.
        """
        self.register_shutdown()

        self.config = config

        # Look in config for Domain, Domain 2, Domain 3 etc and load them.
        # Each platform is set up in its own task rather than awaited here.
        for p_type, p_config in config_per_platform(config, self.domain):
            if p_type is not None:
                self.hass.async_create_task(
                    self.async_setup_platform(p_type, p_config),
                    f"EntityComponent setup platform {p_type} {self.domain}",
                )

        # Generic discovery listener for loading platform dynamically
        # Refer to: homeassistant.helpers.discovery.async_load_platform()
        async def component_platform_discovered(
            platform: str, info: dict[str, Any] | None
        ) -> None:
            """Handle the loading of a platform."""
            await self.async_setup_platform(platform, {}, info)

        discovery.async_listen_platform(
            self.hass, self.domain, component_platform_discovered
        )

    async def async_setup_entry(self, config_entry: ConfigEntry) -> bool:
        """Set up a config entry.

        Returns:
            False if the platform for ``config_entry.domain`` could not be
            prepared, otherwise the result of the new platform's
            ``async_setup_entry``.

        Raises:
            ValueError: if this config entry has already been set up.
        """
        platform_type = config_entry.domain
        platform = await async_prepare_setup_platform(
            self.hass,
            # In future PR we should make hass_config part of the constructor
            # params.
            self.config or {},
            self.domain,
            platform_type,
        )

        if platform is None:
            return False

        key = config_entry.entry_id

        if key in self._platforms:
            raise ValueError("Config entry has already been setup!")

        self._platforms[key] = self._async_init_entity_platform(
            platform_type,
            platform,
            scan_interval=getattr(platform, "SCAN_INTERVAL", None),
        )

        return await self._platforms[key].async_setup_entry(config_entry)

    async def async_unload_entry(self, config_entry: ConfigEntry) -> bool:
        """Unload a config entry.

        Removes the entry's platform and resets it; always returns True.

        Raises:
            ValueError: if this config entry was never loaded.
        """
        key = config_entry.entry_id

        if (platform := self._platforms.pop(key, None)) is None:
            raise ValueError("Config entry was never loaded!")

        await platform.async_reset()
        return True

    async def async_extract_from_service(
        self, service_call: ServiceCall, expand_group: bool = True
    ) -> list[_EntityT]:
        """Extract all known and available entities from a service call.

        Will return an empty list if entities specified but unknown.

        This method must be run in the event loop.
        """
        return await service.async_extract_entities(
            self.hass, self.entities, service_call, expand_group
        )

    @callback
    def async_register_entity_service(
        self,
        name: str,
        schema: dict[str | vol.Marker, Any] | vol.Schema,
        func: str | Callable[..., Any],
        required_features: list[int] | None = None,
    ) -> None:
        """Register an entity service under this component's domain.

        Args:
            name: Service name.
            schema: A voluptuous schema, or a plain dict that is converted
                with ``cv.make_entity_service_schema``.
            func: Passed through to ``service.entity_service_call``.
            required_features: Passed through to ``service.entity_service_call``.
        """
        if isinstance(schema, dict):
            schema = cv.make_entity_service_schema(schema)

        async def handle_service(call: ServiceCall) -> None:
            """Handle the service."""
            # Reads self._platforms at call time, so platforms added after
            # registration are included.
            await service.entity_service_call(
                self.hass, self._platforms.values(), func, call, required_features
            )

        self.hass.services.async_register(self.domain, name, handle_service, schema)

    async def async_setup_platform(
        self,
        platform_type: str,
        platform_config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> None:
        """Set up a platform for this component.

        Platform setups that share the same platform type, scan interval and
        entity namespace reuse a single EntityPlatform.

        Raises:
            RuntimeError: if ``async_setup`` has not been called yet.
        """
        if self.config is None:
            raise RuntimeError("async_setup needs to be called first")

        platform = await async_prepare_setup_platform(
            self.hass, self.config, self.domain, platform_type
        )

        if platform is None:
            return

        # Use config scan interval, fallback to platform if none set
        scan_interval = platform_config.get(
            CONF_SCAN_INTERVAL, getattr(platform, "SCAN_INTERVAL", None)
        )
        entity_namespace = platform_config.get(CONF_ENTITY_NAMESPACE)

        key = (platform_type, scan_interval, entity_namespace)

        if key not in self._platforms:
            self._platforms[key] = self._async_init_entity_platform(
                platform_type, platform, scan_interval, entity_namespace
            )

        await self._platforms[key].async_setup(platform_config, discovery_info)

    async def _async_reset(self) -> None:
        """Remove entities and reset the entity component to initial values.

        The built-in domain platform is reset and kept; every other platform
        is destroyed and dropped. ``self.config`` is cleared.

        This method must be run in the event loop.
        """
        tasks = []

        for key, platform in self._platforms.items():
            if key == self.domain:
                tasks.append(platform.async_reset())
            else:
                tasks.append(platform.async_destroy())

        if tasks:
            await asyncio.gather(*tasks)

        self._platforms = {self.domain: self._platforms[self.domain]}
        self.config = None

    async def async_remove_entity(self, entity_id: str) -> None:
        """Remove an entity managed by one of the platforms.

        Does nothing if no platform currently holds ``entity_id``.
        """
        owner = next(
            (
                platform
                for platform in self._platforms.values()
                if entity_id in platform.entities
            ),
            None,
        )
        # Compare against None explicitly rather than relying on the
        # platform object's truthiness.
        if owner is not None:
            await owner.async_remove_entity(entity_id)

    async def async_prepare_reload(
        self, *, skip_reset: bool = False
    ) -> ConfigType | None:
        """Prepare reloading this entity component.

        Re-reads the YAML configuration and processes it for this domain.

        Args:
            skip_reset: When True, the component is not reset before
                returning the new configuration.

        Returns:
            The processed configuration, or None if the YAML could not be
            loaded (the error is logged) or the component config could not
            be processed.

        This method must be run in the event loop.
        """
        try:
            conf = await conf_util.async_hass_config_yaml(self.hass)
        except HomeAssistantError as err:
            self.logger.error(err)
            return None

        integration = await async_get_integration(self.hass, self.domain)

        processed_conf = await conf_util.async_process_component_config(
            self.hass, conf, integration
        )

        if processed_conf is None:
            return None

        if not skip_reset:
            await self._async_reset()

        return processed_conf

    @callback
    def _async_init_entity_platform(
        self,
        platform_type: str,
        platform: ModuleType | None,
        scan_interval: timedelta | None = None,
        entity_namespace: str | None = None,
    ) -> EntityPlatform:
        """Initialize an entity platform.

        A ``scan_interval`` of None falls back to the component's own
        ``scan_interval``.
        """
        if scan_interval is None:
            scan_interval = self.scan_interval

        return EntityPlatform(
            hass=self.hass,
            logger=self.logger,
            domain=self.domain,
            platform_name=platform_type,
            platform=platform,
            scan_interval=scan_interval,
            entity_namespace=entity_namespace,
        )

    async def _async_shutdown(self, event: Event) -> None:
        """Call when Home Assistant is stopping.

        Shuts down all platforms concurrently.
        """
        await asyncio.gather(
            *(platform.async_shutdown() for platform in self._platforms.values())
        )