Skip to content

Commit

Permalink
Actually test the whitelist
Browse files Browse the repository at this point in the history
  • Loading branch information
palfrey committed Apr 1, 2024
1 parent 970b9d7 commit 621932c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ sync: requirements.test
./uv pip sync requirements.test

watch-tests: sync
PYTHONPATH=. ptw . --now -vvv -s
PYTHONPATH=. ptw . --now -vvv

pre-commit: sync
pre-commit run -a
Expand Down
10 changes: 6 additions & 4 deletions custom_components/ban_whitelist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Ban Whitelist from a config entry."""
ban_manager: IpBanManager = hass.http.app[KEY_BAN_MANAGER]
_LOGGER.info("Ban manager %s", ban_manager)
_LOGGER.debug("Ban manager %s", ban_manager)
whitelist: List[str] = config.get(DOMAIN, {}).get("ip_addresses", [])
if len(whitelist) == 0:
_LOGGER.info("Not setting whitelist, as no IPs set")
Expand All @@ -41,15 +41,17 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
original_async_add_ban = IpBanManager.async_add_ban

async def whitelist_async_add_ban(
self: IpBanManager, remote_addr: IPv4Address | IPv6Address
remote_addr: IPv4Address | IPv6Address,
) -> None:
if remote_addr in whitelist:
if str(remote_addr) in whitelist:
_LOGGER.info(
"Not adding %s to ban list, as it's in the whitelist", remote_addr
)
return
else:
_LOGGER.info("Banning IP %s", remote_addr)

await original_async_add_ban(self, remote_addr)
await original_async_add_ban(ban_manager, remote_addr)

ban_manager.async_add_ban = ( # type:ignore[method-assign]
whitelist_async_add_ban # type:ignore[assignment]
Expand Down
63 changes: 53 additions & 10 deletions tests/ban_whitelist/test_setup.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,73 @@
"""Test Ban Whitelist setup."""

import logging
from ipaddress import IPv4Address
from typing import cast

import pytest
from homeassistant.components.http.ban import KEY_BAN_MANAGER, IpBanManager
from homeassistant.core import HomeAssistant
from homeassistant.loader import DATA_CUSTOM_COMPONENTS, async_get_custom_components
from homeassistant.setup import async_setup_component

from custom_components.ban_whitelist.const import DOMAIN


@pytest.mark.anyio
async def test_setup(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None:
"""Test setup of ban whitelist."""
def check_records(records: list[logging.LogRecord]) -> None:
"""Check log records don't have any warnings/errors."""
for record in records:
if record.levelno >= logging.WARNING:
msg = record.getMessage()
if msg.startswith(
"We found a custom integration ban_whitelist which has not been tested by Home Assistant"
):
continue
raise Exception(msg)


async def setup_ban_whitelist(hass: HomeAssistant) -> None:
"""Configure ban_whitelist and dependencies."""
hass.data[DATA_CUSTOM_COMPONENTS] = None
assert list((await async_get_custom_components(hass)).keys()) == ["ban_whitelist"]
await async_setup_component(hass, "http", {})
await async_setup_component(
hass, DOMAIN, {DOMAIN: {"ip_addresses": ["192.168.1.1"]}, "foo": "bar"}
)


@pytest.mark.anyio
async def test_setup(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None:
"""Test setup of ban whitelist."""
await setup_ban_whitelist(hass)
check_records(caplog.records)


@pytest.mark.anyio
async def test_hit_whitelist(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test hitting the whitelist."""
await setup_ban_whitelist(hass)
await cast(IpBanManager, hass.http.app[KEY_BAN_MANAGER]).async_add_ban(
IPv4Address("192.168.1.1")
)
await cast(IpBanManager, hass.http.app[KEY_BAN_MANAGER]).async_add_ban(
IPv4Address("10.0.0.1")
)
check_records(caplog.records)

messages = []

for record in caplog.records:
if record.levelno >= logging.WARNING:
msg = record.getMessage()
if msg.startswith(
"We found a custom integration ban_whitelist which has not been tested by Home Assistant"
):
continue
raise Exception(record.getMessage())
if record.levelno < logging.INFO or not record.name.startswith(
"custom_components.ban_whitelist"
):
continue

messages.append(record.getMessage())

assert messages == [
"Setting whitelist with ['192.168.1.1']",
"Not adding 192.168.1.1 to ban list, as it's in the whitelist",
"Banning IP 10.0.0.1",
]

0 comments on commit 621932c

Please sign in to comment.