Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dispatcher typing #106872

Merged
merged 7 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions homeassistant/components/cast/const.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
"""Consts for Cast integration."""
from __future__ import annotations

from typing import TYPE_CHECKING

from pychromecast.controllers.homeassistant import HomeAssistantController

from homeassistant.helpers.dispatcher import SignalType

if TYPE_CHECKING:
from .helpers import ChromecastInfo


DOMAIN = "cast"

Expand All @@ -14,14 +25,16 @@

# Dispatcher signal fired with a ChromecastInfo every time we discover a new
# Chromecast or receive it through configuration
SIGNAL_CAST_DISCOVERED = "cast_discovered"
SIGNAL_CAST_DISCOVERED: SignalType[ChromecastInfo] = SignalType("cast_discovered")

# Dispatcher signal fired with a ChromecastInfo every time a Chromecast is
# removed
SIGNAL_CAST_REMOVED = "cast_removed"
SIGNAL_CAST_REMOVED: SignalType[ChromecastInfo] = SignalType("cast_removed")

# Dispatcher signal fired when a Chromecast should show a Home Assistant Cast view.
SIGNAL_HASS_CAST_SHOW_VIEW = "cast_show_view"
SIGNAL_HASS_CAST_SHOW_VIEW: SignalType[
HomeAssistantController, str, str, str | None
] = SignalType("cast_show_view")

CONF_IGNORE_CEC = "ignore_cec"
CONF_KNOWN_HOSTS = "known_hosts"
5 changes: 4 additions & 1 deletion homeassistant/components/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.dispatcher import (
SignalType,
async_dispatcher_connect,
async_dispatcher_send,
)
Expand Down Expand Up @@ -68,7 +69,9 @@
SERVICE_REMOTE_CONNECT = "remote_connect"
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"

SIGNAL_CLOUD_CONNECTION_STATE = "CLOUD_CONNECTION_STATE"
SIGNAL_CLOUD_CONNECTION_STATE: SignalType[CloudConnectionState] = SignalType(
"CLOUD_CONNECTION_STATE"
)

STARTUP_REPAIR_DELAY = 1 # 1 hour

Expand Down
8 changes: 7 additions & 1 deletion homeassistant/components/cloud/const.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
"""Constants for the cloud component."""
from __future__ import annotations

from typing import Any

from homeassistant.helpers.dispatcher import SignalType

DOMAIN = "cloud"
DATA_PLATFORMS_SETUP = "cloud_platforms_setup"
REQUEST_TIMEOUT = 10
Expand Down Expand Up @@ -64,6 +70,6 @@
MODE_DEV = "development"
MODE_PROD = "production"

DISPATCHER_REMOTE_UPDATE = "cloud_remote_update"
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")

STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
117 changes: 107 additions & 10 deletions homeassistant/helpers/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,73 @@
from __future__ import annotations

from collections.abc import Callable, Coroutine
from dataclasses import dataclass
from functools import partial
import logging
from typing import Any
from typing import Any, Generic, TypeVarTuple, overload

from homeassistant.core import HassJob, HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.logging import catch_log_exception

_Ts = TypeVarTuple("_Ts")

_LOGGER = logging.getLogger(__name__)
DATA_DISPATCHER = "dispatcher"


@dataclass(frozen=True)
class SignalType(Generic[*_Ts]):
"""Generic string class for signal to improve typing."""

name: str

def __hash__(self) -> int:
"""Return hash of name."""

return hash(self.name)

def __eq__(self, other: Any) -> bool:
"""Check equality for dict keys to be compatible with str."""

if isinstance(other, str):
return self.name == other
if isinstance(other, SignalType):
return self.name == other.name
return False

Check warning on line 39 in homeassistant/helpers/dispatcher.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/helpers/dispatcher.py#L37-L39

Added lines #L37 - L39 were not covered by tests


_DispatcherDataType = dict[
str,
SignalType[*_Ts] | str,
dict[
Callable[..., Any],
Callable[[*_Ts], Any] | Callable[..., Any],
HassJob[..., None | Coroutine[Any, Any, None]] | None,
],
]


@overload
@bind_hass
def dispatcher_connect(
hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], None]
) -> Callable[[], None]:
...


@overload
@bind_hass
def dispatcher_connect(
hass: HomeAssistant, signal: str, target: Callable[..., None]
) -> Callable[[], None]:
...


@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def
def dispatcher_connect(
hass: HomeAssistant,
signal: SignalType[*_Ts],
target: Callable[[*_Ts], None],
) -> Callable[[], None]:
"""Connect a callable function to a signal."""
async_unsub = run_callback_threadsafe(
Expand All @@ -41,9 +84,9 @@

@callback
def _async_remove_dispatcher(
dispatchers: _DispatcherDataType,
signal: str,
target: Callable[..., Any],
dispatchers: _DispatcherDataType[*_Ts],
signal: SignalType[*_Ts] | str,
target: Callable[[*_Ts], Any] | Callable[..., Any],
) -> None:
"""Remove signal listener."""
try:
Expand All @@ -59,10 +102,30 @@
_LOGGER.warning("Unable to remove unknown dispatcher %s", target)


@overload
@callback
@bind_hass
def async_dispatcher_connect(
hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], Any]
) -> Callable[[], None]:
...


@overload
@callback
@bind_hass
def async_dispatcher_connect(
hass: HomeAssistant, signal: str, target: Callable[..., Any]
) -> Callable[[], None]:
...


@callback
@bind_hass
def async_dispatcher_connect(
hass: HomeAssistant,
signal: SignalType[*_Ts] | str,
target: Callable[[*_Ts], Any] | Callable[..., Any],
) -> Callable[[], None]:
"""Connect a callable function to a signal.

Expand All @@ -71,7 +134,7 @@
if DATA_DISPATCHER not in hass.data:
hass.data[DATA_DISPATCHER] = {}

dispatchers: _DispatcherDataType = hass.data[DATA_DISPATCHER]
dispatchers: _DispatcherDataType[*_Ts] = hass.data[DATA_DISPATCHER]

if signal not in dispatchers:
dispatchers[signal] = {}
Expand All @@ -84,13 +147,29 @@
return partial(_async_remove_dispatcher, dispatchers, signal, target)


@overload
@bind_hass
def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None:
...


@overload
@bind_hass
def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
...


@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def
def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None:
"""Send signal and data."""
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)


def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
def _format_err(
signal: SignalType[*_Ts] | str,
target: Callable[[*_Ts], Any] | Callable[..., Any],
*args: Any,
) -> str:
"""Format error message."""
return "Exception in {} when dispatching '{}': {}".format(
# Functions wrapped in partial do not have a __name__
Expand All @@ -101,7 +180,7 @@


def _generate_job(
signal: str, target: Callable[..., Any]
signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any]
) -> HassJob[..., None | Coroutine[Any, Any, None]]:
"""Generate a HassJob for a signal and target."""
return HassJob(
Expand All @@ -110,16 +189,34 @@
)


@overload
@callback
@bind_hass
def async_dispatcher_send(
hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts
) -> None:
...


@overload
@callback
@bind_hass
def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
...


@callback
@bind_hass
def async_dispatcher_send(
hass: HomeAssistant, signal: SignalType[*_Ts] | str, *args: *_Ts
) -> None:
"""Send signal and data.

This method must be run in the event loop.
"""
if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None:
return
dispatchers: _DispatcherDataType = maybe_dispatchers
dispatchers: _DispatcherDataType[*_Ts] = maybe_dispatchers
if (target_list := dispatchers.get(signal)) is None:
return

Expand Down
27 changes: 27 additions & 0 deletions tests/helpers/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import (
SignalType,
async_dispatcher_connect,
async_dispatcher_send,
)
Expand All @@ -30,6 +31,32 @@ def test_funct(data):
assert calls == [3, "bla"]


async def test_signal_type(hass: HomeAssistant) -> None:
"""Test dispatcher with SignalType."""
signal: SignalType[str, int] = SignalType("test")
calls: list[tuple[str, int]] = []

def test_funct(data1: str, data2: int) -> None:
calls.append((data1, data2))

async_dispatcher_connect(hass, signal, test_funct)
async_dispatcher_send(hass, signal, "Hello", 2)
await hass.async_block_till_done()

assert calls == [("Hello", 2)]

async_dispatcher_send(hass, signal, "World", 3)
await hass.async_block_till_done()

assert calls == [("Hello", 2), ("World", 3)]

# Test compatibility with string keys
async_dispatcher_send(hass, "test", "x", 4)
await hass.async_block_till_done()

assert calls == [("Hello", 2), ("World", 3), ("x", 4)]


async def test_simple_function_unsub(hass: HomeAssistant) -> None:
"""Test simple function (executor) and unsub."""
calls1 = []
Expand Down