Skip to content

Commit

Permalink
Change type definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
Miauwkeru committed Oct 9, 2024
1 parent ce1e705 commit dd149df
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions flow/record/fieldtypes/net/ip.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
from __future__ import annotations

from ipaddress import ip_address, ip_network
from ipaddress import (
IPv4Address,
IPv4Network,
IPv6Address,
IPv6Network,
ip_address,
ip_network,
)
from typing import Union

from flow.record.base import FieldType
from flow.record.fieldtypes import defang

_IPNetwork = Union[IPv4Network, IPv6Network]
_IPAddress = Union[IPv4Address, IPv6Address]


class ipaddress(FieldType):
val = None
_type = "net.ipaddress"

def __init__(self, addr: Union[str, int]):
def __init__(self, addr: str | int | bytes):
self.val = ip_address(addr)

def __eq__(self, b: Union[str, int]) -> bool:
def __eq__(self, b: str | int | bytes) -> bool:
try:
return self.val == ip_address(b)
except ValueError:
Expand Down Expand Up @@ -46,10 +56,10 @@ class ipnetwork(FieldType):
val = None
_type = "net.ipnetwork"

def __init__(self, addr: Union[str, int]):
def __init__(self, addr: str | int | bytes):
self.val = ip_network(addr)

def __eq__(self, b: Union[str, int]) -> bool:
def __eq__(self, b: str | int | bytes) -> bool:
try:
return self.val == ip_network(b)
except ValueError:
Expand All @@ -59,7 +69,7 @@ def __hash__(self) -> int:
return hash(self.val)

@staticmethod
def _is_subnet_of(a: ip_network, b: ip_network) -> bool:
def _is_subnet_of(a: _IPNetwork, b: _IPNetwork) -> bool:
try:
# Always false if one is v4 and the other is v6.
if a._version != b._version:
Expand All @@ -68,7 +78,7 @@ def _is_subnet_of(a: ip_network, b: ip_network) -> bool:
except AttributeError:
raise TypeError("Unable to test subnet containment " "between {} and {}".format(a, b))

def __contains__(self, b: Union[str, int, ip_address]) -> bool:
def __contains__(self, b: str | int | bytes | _IPAddress) -> bool:
try:
return self._is_subnet_of(ip_network(b), self.val)
except (ValueError, TypeError):
Expand Down

0 comments on commit dd149df

Please sign in to comment.