Skip to content

Commit

Permalink
handle CIDR notation
Browse files Browse the repository at this point in the history
  • Loading branch information
SalemHarrache committed Oct 3, 2024
1 parent e0a75b8 commit 3d58f18
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ To use this, [install with HACS](https://hacs.xyz/) as [a custom repository](htt
Then add to your `configuration.yaml` something like the following:
```
ban_allowlist:
ip_addresses: ["my.ip.address", "another.ip.address"]
ip_addresses: ["my.ip.address", "another.network.address/24"]
```
23 changes: 13 additions & 10 deletions custom_components/ban_allowlist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from ipaddress import IPv4Address, IPv6Address
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network
from typing import List

import voluptuous as vol
Expand Down Expand Up @@ -38,25 +38,28 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
return True
_LOGGER.debug("Ban manager %s", ban_manager)
allowlist: List[str] = config.get(DOMAIN, {}).get("ip_addresses", [])
allowlist: List[IPv4Network | IPv6Network] = [
ip_network(ip) for ip in config.get(DOMAIN, {}).get("ip_addresses", [])
]
if len(allowlist) == 0:
_LOGGER.info("Not setting allowlist, as no IPs set")
else:
_LOGGER.info("Setting allowlist with %s", allowlist)
_LOGGER.info("Setting allowlist with %s", [str(ip) for ip in allowlist])

original_async_add_ban = IpBanManager.async_add_ban

async def allowlist_async_add_ban(
remote_addr: IPv4Address | IPv6Address,
) -> None:
if str(remote_addr) in allowlist:
_LOGGER.info(
"Not adding %s to ban list, as it's in the allowlist", remote_addr
)
return
else:
_LOGGER.info("Banning IP %s", remote_addr)
for allowed_network in allowlist:
if remote_addr in allowed_network:
_LOGGER.info(
"Not adding %s to ban list, as it's in the allowlist",
remote_addr,
)
return

_LOGGER.info("Banning IP %s", remote_addr)
await original_async_add_ban(ban_manager, remote_addr)

ban_manager.async_add_ban = ( # type:ignore[method-assign]
Expand Down
14 changes: 12 additions & 2 deletions tests/ban_allowlist/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ async def setup_ban_allowlist(hass: HomeAssistant) -> None:
assert list((await async_get_custom_components(hass)).keys()) == ["ban_allowlist"]
await async_setup_component(hass, "http", {})
await async_setup_component(
hass, DOMAIN, {DOMAIN: {"ip_addresses": ["192.168.1.1"]}, "foo": "bar"}
hass,
DOMAIN,
{DOMAIN: {"ip_addresses": ["192.168.1.1", "172.17.0.0/24"]}, "foo": "bar"},
)


Expand All @@ -54,6 +56,12 @@ async def test_hit_allowlist(
await cast(IpBanManager, hass.http.app[KEY_BAN_MANAGER]).async_add_ban(
IPv4Address("10.0.0.1")
)
await cast(IpBanManager, hass.http.app[KEY_BAN_MANAGER]).async_add_ban(
IPv4Address("172.17.0.10")
)
await cast(IpBanManager, hass.http.app[KEY_BAN_MANAGER]).async_add_ban(
IPv4Address("172.17.1.10")
)
check_records(caplog.records)

messages = []
Expand All @@ -67,7 +75,9 @@ async def test_hit_allowlist(
messages.append(record.getMessage())

assert messages == [
"Setting allowlist with ['192.168.1.1']",
"Setting allowlist with ['192.168.1.1/32', '172.17.0.0/24']",
"Not adding 192.168.1.1 to ban list, as it's in the allowlist",
"Banning IP 10.0.0.1",
"Not adding 172.17.0.10 to ban list, as it's in the allowlist",
"Banning IP 172.17.1.10",
]

0 comments on commit 3d58f18

Please sign in to comment.