From dd149dfe5d01c658b53c27e7b422cf307c2f0758 Mon Sep 17 00:00:00 2001 From: Miauwkeru Date: Wed, 9 Oct 2024 11:22:34 +0000 Subject: [PATCH] Change type definitions --- flow/record/fieldtypes/net/ip.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/flow/record/fieldtypes/net/ip.py b/flow/record/fieldtypes/net/ip.py index 660118c..3a661e1 100644 --- a/flow/record/fieldtypes/net/ip.py +++ b/flow/record/fieldtypes/net/ip.py @@ -1,20 +1,30 @@ from __future__ import annotations -from ipaddress import ip_address, ip_network +from ipaddress import ( + IPv4Address, + IPv4Network, + IPv6Address, + IPv6Network, + ip_address, + ip_network, +) from typing import Union from flow.record.base import FieldType from flow.record.fieldtypes import defang +_IPNetwork = Union[IPv4Network, IPv6Network] +_IPAddress = Union[IPv4Address, IPv6Address] + class ipaddress(FieldType): val = None _type = "net.ipaddress" - def __init__(self, addr: Union[str, int]): + def __init__(self, addr: str | int | bytes): self.val = ip_address(addr) - def __eq__(self, b: Union[str, int]) -> bool: + def __eq__(self, b: str | int | bytes) -> bool: try: return self.val == ip_address(b) except ValueError: @@ -46,10 +56,10 @@ class ipnetwork(FieldType): val = None _type = "net.ipnetwork" - def __init__(self, addr: Union[str, int]): + def __init__(self, addr: str | int | bytes): self.val = ip_network(addr) - def __eq__(self, b: Union[str, int]) -> bool: + def __eq__(self, b: str | int | bytes) -> bool: try: return self.val == ip_network(b) except ValueError: @@ -59,7 +69,7 @@ def __hash__(self) -> int: return hash(self.val) @staticmethod - def _is_subnet_of(a: ip_network, b: ip_network) -> bool: + def _is_subnet_of(a: _IPNetwork, b: _IPNetwork) -> bool: try: # Always false if one is v4 and the other is v6. if a._version != b._version: @@ -68,7 +78,7 @@ def _is_subnet_of(a: ip_network, b: ip_network) -> bool: except AttributeError: raise TypeError("Unable to test subnet containment " "between {} and {}".format(a, b)) - def __contains__(self, b: Union[str, int, ip_address]) -> bool: + def __contains__(self, b: str | int | bytes | _IPAddress) -> bool: try: return self._is_subnet_of(ip_network(b), self.val) except (ValueError, TypeError):