Skip to content

Commit

Permalink
Add support for Shelly Gen3 devices (#104874)
Browse files Browse the repository at this point in the history
* Add support for Gen3 devices

* Add RPC_GENERATIONS const

* Add gen3 to tests

* More tests

* Add BLOCK_GENERATIONS const

* Use *_GENERATIONS constants from aioshelly
  • Loading branch information
bieniu authored Dec 11, 2023
1 parent 662e199 commit bf93929
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 27 deletions.
5 changes: 3 additions & 2 deletions homeassistant/components/shelly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aioshelly.block_device import BlockDevice, BlockUpdateType
from aioshelly.common import ConnectionOptions
from aioshelly.const import RPC_GENERATIONS
from aioshelly.exceptions import (
DeviceConnectionError,
InvalidAuthError,
Expand Down Expand Up @@ -123,7 +124,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:

get_entry_data(hass)[entry.entry_id] = ShellyEntryData()

if get_device_entry_gen(entry) == 2:
if get_device_entry_gen(entry) in RPC_GENERATIONS:
return await _async_setup_rpc_entry(hass, entry)

return await _async_setup_block_entry(hass, entry)
Expand Down Expand Up @@ -313,7 +314,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if not entry.data.get(CONF_SLEEP_PERIOD):
platforms = RPC_PLATFORMS

if get_device_entry_gen(entry) == 2:
if get_device_entry_gen(entry) in RPC_GENERATIONS:
if unload_ok := await hass.config_entries.async_unload_platforms(
entry, platforms
):
Expand Down
4 changes: 3 additions & 1 deletion homeassistant/components/shelly/binary_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import Final, cast

from aioshelly.const import RPC_GENERATIONS

from homeassistant.components.binary_sensor import (
BinarySensorDeviceClass,
BinarySensorEntity,
Expand Down Expand Up @@ -224,7 +226,7 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up sensors for device."""
if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
if config_entry.data[CONF_SLEEP_PERIOD]:
async_setup_entry_rpc(
hass,
Expand Down
4 changes: 3 additions & 1 deletion homeassistant/components/shelly/button.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar

from aioshelly.const import RPC_GENERATIONS

from homeassistant.components.button import (
ButtonDeviceClass,
ButtonEntity,
Expand Down Expand Up @@ -126,7 +128,7 @@ def _async_migrate_unique_ids(
return async_migrate_unique_ids(entity_entry, coordinator)

coordinator: ShellyRpcCoordinator | ShellyBlockCoordinator | None = None
if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
coordinator = get_entry_data(hass)[config_entry.entry_id].rpc
else:
coordinator = get_entry_data(hass)[config_entry.entry_id].block
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/shelly/climate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, cast

from aioshelly.block_device import Block
from aioshelly.const import RPC_GENERATIONS
from aioshelly.exceptions import DeviceConnectionError, InvalidAuthError

from homeassistant.components.climate import (
Expand Down Expand Up @@ -51,7 +52,7 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up climate device."""
if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
return async_setup_rpc_entry(hass, config_entry, async_add_entities)

coordinator = get_entry_data(hass)[config_entry.entry_id].block
Expand Down
17 changes: 10 additions & 7 deletions homeassistant/components/shelly/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aioshelly.block_device import BlockDevice
from aioshelly.common import ConnectionOptions, get_info
from aioshelly.const import BLOCK_GENERATIONS, RPC_GENERATIONS
from aioshelly.exceptions import (
DeviceConnectionError,
FirmwareUnsupported,
Expand Down Expand Up @@ -66,7 +67,9 @@ async def validate_input(
"""
options = ConnectionOptions(host, data.get(CONF_USERNAME), data.get(CONF_PASSWORD))

if get_info_gen(info) == 2:
gen = get_info_gen(info)

if gen in RPC_GENERATIONS:
ws_context = await get_ws_context(hass)
rpc_device = await RpcDevice.create(
async_get_clientsession(hass),
Expand All @@ -81,7 +84,7 @@ async def validate_input(
"title": rpc_device.name,
CONF_SLEEP_PERIOD: sleep_period,
"model": rpc_device.shelly.get("model"),
"gen": 2,
"gen": gen,
}

# Gen1
Expand All @@ -96,7 +99,7 @@ async def validate_input(
"title": block_device.name,
CONF_SLEEP_PERIOD: get_block_device_sleep_period(block_device.settings),
"model": block_device.model,
"gen": 1,
"gen": gen,
}


Expand Down Expand Up @@ -165,7 +168,7 @@ async def async_step_credentials(
"""Handle the credentials step."""
errors: dict[str, str] = {}
if user_input is not None:
if get_info_gen(self.info) == 2:
if get_info_gen(self.info) in RPC_GENERATIONS:
user_input[CONF_USERNAME] = "admin"
try:
device_info = await validate_input(
Expand Down Expand Up @@ -194,7 +197,7 @@ async def async_step_credentials(
else:
user_input = {}

if get_info_gen(self.info) == 2:
if get_info_gen(self.info) in RPC_GENERATIONS:
schema = {
vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD)): str,
}
Expand Down Expand Up @@ -331,7 +334,7 @@ async def async_step_reauth_confirm(
await self.hass.config_entries.async_reload(self.entry.entry_id)
return self.async_abort(reason="reauth_successful")

if self.entry.data.get("gen", 1) == 1:
if self.entry.data.get("gen", 1) in BLOCK_GENERATIONS:
schema = {
vol.Required(CONF_USERNAME): str,
vol.Required(CONF_PASSWORD): str,
Expand Down Expand Up @@ -360,7 +363,7 @@ def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlowHandler:
def async_supports_options_flow(cls, config_entry: ConfigEntry) -> bool:
"""Return options flow support for this handler."""
return (
config_entry.data.get("gen") == 2
config_entry.data.get("gen") in RPC_GENERATIONS
and not config_entry.data.get(CONF_SLEEP_PERIOD)
and config_entry.data.get("model") != MODEL_WALL_DISPLAY
)
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/shelly/cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, cast

from aioshelly.block_device import Block
from aioshelly.const import RPC_GENERATIONS

from homeassistant.components.cover import (
ATTR_POSITION,
Expand All @@ -26,7 +27,7 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up covers for device."""
if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
return async_setup_rpc_entry(hass, config_entry, async_add_entities)

return async_setup_block_entry(hass, config_entry, async_add_entities)
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/shelly/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Any, Final

from aioshelly.block_device import Block
from aioshelly.const import MODEL_I3
from aioshelly.const import MODEL_I3, RPC_GENERATIONS

from homeassistant.components.event import (
DOMAIN as EVENT_DOMAIN,
Expand Down Expand Up @@ -80,7 +80,7 @@ async def async_setup_entry(

coordinator: ShellyRpcCoordinator | ShellyBlockCoordinator | None = None

if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
coordinator = get_entry_data(hass)[config_entry.entry_id].rpc
if TYPE_CHECKING:
assert coordinator
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/shelly/light.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, cast

from aioshelly.block_device import Block
from aioshelly.const import MODEL_BULB
from aioshelly.const import MODEL_BULB, RPC_GENERATIONS

from homeassistant.components.light import (
ATTR_BRIGHTNESS,
Expand Down Expand Up @@ -53,7 +53,7 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up lights for device."""
if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
return async_setup_rpc_entry(hass, config_entry, async_add_entities)

return async_setup_block_entry(hass, config_entry, async_add_entities)
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/shelly/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Final, cast

from aioshelly.block_device import Block
from aioshelly.const import RPC_GENERATIONS

from homeassistant.components.sensor import (
RestoreSensor,
Expand Down Expand Up @@ -925,7 +926,7 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up sensors for device."""
if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
if config_entry.data[CONF_SLEEP_PERIOD]:
async_setup_entry_rpc(
hass,
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/shelly/switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, cast

from aioshelly.block_device import Block
from aioshelly.const import MODEL_2, MODEL_25, MODEL_GAS
from aioshelly.const import MODEL_2, MODEL_25, MODEL_GAS, RPC_GENERATIONS

from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
from homeassistant.config_entries import ConfigEntry
Expand Down Expand Up @@ -49,7 +49,7 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up switches for device."""
if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
return async_setup_rpc_entry(hass, config_entry, async_add_entities)

return async_setup_block_entry(hass, config_entry, async_add_entities)
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/shelly/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
from typing import Any, Final, cast

from aioshelly.const import RPC_GENERATIONS
from aioshelly.exceptions import DeviceConnectionError, InvalidAuthError, RpcCallError

from homeassistant.components.update import (
Expand Down Expand Up @@ -119,7 +120,7 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up update entities for Shelly component."""
if get_device_entry_gen(config_entry) == 2:
if get_device_entry_gen(config_entry) in RPC_GENERATIONS:
if config_entry.data[CONF_SLEEP_PERIOD]:
async_setup_entry_rpc(
hass,
Expand Down
6 changes: 4 additions & 2 deletions homeassistant/components/shelly/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from aiohttp.web import Request, WebSocketResponse
from aioshelly.block_device import COAP, Block, BlockDevice
from aioshelly.const import (
BLOCK_GENERATIONS,
MODEL_1L,
MODEL_DIMMER,
MODEL_DIMMER_2,
MODEL_EM3,
MODEL_I3,
MODEL_NAMES,
RPC_GENERATIONS,
)
from aioshelly.rpc_device import RpcDevice, WsServer

Expand Down Expand Up @@ -284,7 +286,7 @@ def get_info_gen(info: dict[str, Any]) -> int:

def get_model_name(info: dict[str, Any]) -> str:
"""Return the device model name."""
if get_info_gen(info) == 2:
if get_info_gen(info) in RPC_GENERATIONS:
return cast(str, MODEL_NAMES.get(info["model"], info["model"]))

return cast(str, MODEL_NAMES.get(info["type"], info["type"]))
Expand Down Expand Up @@ -420,4 +422,4 @@ def get_release_url(gen: int, model: str, beta: bool) -> str | None:
if beta or model in DEVICES_WITHOUT_FIRMWARE_CHANGELOG:
return None

return GEN1_RELEASE_URL if gen == 1 else GEN2_RELEASE_URL
return GEN1_RELEASE_URL if gen in BLOCK_GENERATIONS else GEN2_RELEASE_URL
14 changes: 14 additions & 0 deletions tests/components/shelly/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
[
(1, MODEL_1),
(2, MODEL_PLUS_2PM),
(3, MODEL_PLUS_2PM),
],
)
async def test_form(
Expand Down Expand Up @@ -109,6 +110,12 @@ async def test_form(
{"password": "test2 password"},
"admin",
),
(
3,
MODEL_PLUS_2PM,
{"password": "test2 password"},
"admin",
),
],
)
async def test_form_auth(
Expand Down Expand Up @@ -465,6 +472,11 @@ async def test_form_auth_errors_test_connection_gen2(
MODEL_PLUS_2PM,
{"mac": "test-mac", "model": MODEL_PLUS_2PM, "auth": False, "gen": 2},
),
(
3,
MODEL_PLUS_2PM,
{"mac": "test-mac", "model": MODEL_PLUS_2PM, "auth": False, "gen": 3},
),
],
)
async def test_zeroconf(
Expand Down Expand Up @@ -742,6 +754,7 @@ async def test_zeroconf_require_auth(hass: HomeAssistant, mock_block_device) ->
[
(1, {"username": "test user", "password": "test1 password"}),
(2, {"password": "test2 password"}),
(3, {"password": "test2 password"}),
],
)
async def test_reauth_successful(
Expand Down Expand Up @@ -780,6 +793,7 @@ async def test_reauth_successful(
[
(1, {"username": "test user", "password": "test1 password"}),
(2, {"password": "test2 password"}),
(3, {"password": "test2 password"}),
],
)
async def test_reauth_unsuccessful(hass: HomeAssistant, gen, user_input) -> None:
Expand Down
8 changes: 4 additions & 4 deletions tests/components/shelly/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_custom_coap_port(
assert "Starting CoAP context with UDP port 7632" in caplog.text


@pytest.mark.parametrize("gen", [1, 2])
@pytest.mark.parametrize("gen", [1, 2, 3])
async def test_shared_device_mac(
hass: HomeAssistant,
gen,
Expand Down Expand Up @@ -74,7 +74,7 @@ async def test_setup_entry_not_shelly(
assert "probably comes from a custom integration" in caplog.text


@pytest.mark.parametrize("gen", [1, 2])
@pytest.mark.parametrize("gen", [1, 2, 3])
async def test_device_connection_error(
hass: HomeAssistant, gen, mock_block_device, mock_rpc_device, monkeypatch
) -> None:
Expand All @@ -90,7 +90,7 @@ async def test_device_connection_error(
assert entry.state == ConfigEntryState.SETUP_RETRY


@pytest.mark.parametrize("gen", [1, 2])
@pytest.mark.parametrize("gen", [1, 2, 3])
async def test_mac_mismatch_error(
hass: HomeAssistant, gen, mock_block_device, mock_rpc_device, monkeypatch
) -> None:
Expand All @@ -106,7 +106,7 @@ async def test_mac_mismatch_error(
assert entry.state == ConfigEntryState.SETUP_RETRY


@pytest.mark.parametrize("gen", [1, 2])
@pytest.mark.parametrize("gen", [1, 2, 3])
async def test_device_auth_error(
hass: HomeAssistant, gen, mock_block_device, mock_rpc_device, monkeypatch
) -> None:
Expand Down

0 comments on commit bf93929

Please sign in to comment.