From 621932cbc86fcb1e3fe7b52a075a5fd84c19d190 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Mon, 1 Apr 2024 23:19:29 +0100 Subject: [PATCH] Actually test the whitelist --- Makefile | 2 +- custom_components/ban_whitelist/__init__.py | 10 ++-- tests/ban_whitelist/test_setup.py | 63 +++++++++++++++++---- 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/Makefile b/Makefile index 71201fc..c0409f0 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ sync: requirements.test ./uv pip sync requirements.test watch-tests: sync - PYTHONPATH=. ptw . --now -vvv -s + PYTHONPATH=. ptw . --now -vvv pre-commit: sync pre-commit run -a diff --git a/custom_components/ban_whitelist/__init__.py b/custom_components/ban_whitelist/__init__.py index 9f7d032..6794a37 100644 --- a/custom_components/ban_whitelist/__init__.py +++ b/custom_components/ban_whitelist/__init__.py @@ -31,7 +31,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up Ban Whitelist from a config entry.""" ban_manager: IpBanManager = hass.http.app[KEY_BAN_MANAGER] - _LOGGER.info("Ban manager %s", ban_manager) + _LOGGER.debug("Ban manager %s", ban_manager) whitelist: List[str] = config.get(DOMAIN, {}).get("ip_addresses", []) if len(whitelist) == 0: _LOGGER.info("Not setting whitelist, as no IPs set") @@ -41,15 +41,17 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: original_async_add_ban = IpBanManager.async_add_ban async def whitelist_async_add_ban( - self: IpBanManager, remote_addr: IPv4Address | IPv6Address + remote_addr: IPv4Address | IPv6Address, ) -> None: - if remote_addr in whitelist: + if str(remote_addr) in whitelist: _LOGGER.info( "Not adding %s to ban list, as it's in the whitelist", remote_addr ) return + else: + _LOGGER.info("Banning IP %s", remote_addr) - await original_async_add_ban(self, remote_addr) + await original_async_add_ban(ban_manager, remote_addr) ban_manager.async_add_ban = ( # type:ignore[method-assign] whitelist_async_add_ban # type:ignore[assignment] diff --git a/tests/ban_whitelist/test_setup.py b/tests/ban_whitelist/test_setup.py index 468ec22..bad6e6d 100644 --- a/tests/ban_whitelist/test_setup.py +++ b/tests/ban_whitelist/test_setup.py @@ -1,8 +1,11 @@ """Test Ban Whitelist setup.""" import logging +from ipaddress import IPv4Address +from typing import cast import pytest +from homeassistant.components.http.ban import KEY_BAN_MANAGER, IpBanManager from homeassistant.core import HomeAssistant from homeassistant.loader import DATA_CUSTOM_COMPONENTS, async_get_custom_components from homeassistant.setup import async_setup_component @@ -10,9 +13,20 @@ from custom_components.ban_whitelist.const import DOMAIN -@pytest.mark.anyio -async def test_setup(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: - """Test setup of ban whitelist.""" +def check_records(records: list[logging.LogRecord]) -> None: + """Check log records don't have any warnings/errors.""" + for record in records: + if record.levelno >= logging.WARNING: + msg = record.getMessage() + if msg.startswith( + "We found a custom integration ban_whitelist which has not been tested by Home Assistant" + ): + continue + raise Exception(msg) + + +async def setup_ban_whitelist(hass: HomeAssistant) -> None: + """Configure ban_whitelist and dependencies.""" hass.data[DATA_CUSTOM_COMPONENTS] = None assert list((await async_get_custom_components(hass)).keys()) == ["ban_whitelist"] await async_setup_component(hass, "http", {}) @@ -20,11 +34,40 @@ async def test_setup(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> N hass, DOMAIN, {DOMAIN: {"ip_addresses": ["192.168.1.1"]}, "foo": "bar"} ) + +@pytest.mark.anyio +async def test_setup(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: + """Test setup of ban whitelist.""" + await setup_ban_whitelist(hass) + check_records(caplog.records) + + +@pytest.mark.anyio +async def test_hit_whitelist( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test hitting the whitelist.""" + await setup_ban_whitelist(hass) + await cast(IpBanManager, hass.http.app[KEY_BAN_MANAGER]).async_add_ban( + IPv4Address("192.168.1.1") + ) + await cast(IpBanManager, hass.http.app[KEY_BAN_MANAGER]).async_add_ban( + IPv4Address("10.0.0.1") + ) + check_records(caplog.records) + + messages = [] + for record in caplog.records: - if record.levelno >= logging.WARNING: - msg = record.getMessage() - if msg.startswith( - "We found a custom integration ban_whitelist which has not been tested by Home Assistant" - ): - continue - raise Exception(record.getMessage()) + if record.levelno < logging.INFO or not record.name.startswith( + "custom_components.ban_whitelist" + ): + continue + + messages.append(record.getMessage()) + + assert messages == [ + "Setting whitelist with ['192.168.1.1']", + "Not adding 192.168.1.1 to ban list, as it's in the whitelist", + "Banning IP 10.0.0.1", + ]