diff --git a/python/external/__init__.py b/python/external/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/external/ipaddress.py b/python/external/ipaddress.py deleted file mode 100644 index ac03c36ce0..0000000000 --- a/python/external/ipaddress.py +++ /dev/null @@ -1,2150 +0,0 @@ -# Copyright 2007 Google Inc. -# Licensed to PSF under a Contributor Agreement. - -"""A fast, lightweight IPv4/IPv6 manipulation library in Python. - -This library is used to create/poke/manipulate IPv4 and IPv6 addresses -and networks. - -""" - -__version__ = '1.0' - - -import functools - -IPV4LENGTH = 32 -IPV6LENGTH = 128 - -class AddressValueError(ValueError): - """A Value Error related to the address.""" - - -class NetmaskValueError(ValueError): - """A Value Error related to the netmask.""" - - -def ip_address(address): - """Take an IP string/int and return an object of the correct type. - - Args: - address: A string or integer, the IP address. Either IPv4 or - IPv6 addresses may be supplied; integers less than 2**32 will - be considered to be IPv4 by default. - - Returns: - An IPv4Address or IPv6Address object. - - Raises: - ValueError: if the *address* passed isn't either a v4 or a v6 - address - - """ - try: - return IPv4Address(address) - except (AddressValueError, NetmaskValueError): - pass - - try: - return IPv6Address(address) - except (AddressValueError, NetmaskValueError): - pass - - raise ValueError('%r does not appear to be an IPv4 or IPv6 address' % - address) - - -def ip_network(address, strict=True): - """Take an IP string/int and return an object of the correct type. - - Args: - address: A string or integer, the IP network. Either IPv4 or - IPv6 networks may be supplied; integers less than 2**32 will - be considered to be IPv4 by default. - - Returns: - An IPv4Network or IPv6Network object. - - Raises: - ValueError: if the string passed isn't either a v4 or a v6 - address. Or if the network has host bits set. - - """ - try: - return IPv4Network(address, strict) - except (AddressValueError, NetmaskValueError): - pass - - try: - return IPv6Network(address, strict) - except (AddressValueError, NetmaskValueError): - pass - - raise ValueError('%r does not appear to be an IPv4 or IPv6 network' % - address) - - -def ip_interface(address): - """Take an IP string/int and return an object of the correct type. - - Args: - address: A string or integer, the IP address. Either IPv4 or - IPv6 addresses may be supplied; integers less than 2**32 will - be considered to be IPv4 by default. - - Returns: - An IPv4Interface or IPv6Interface object. - - Raises: - ValueError: if the string passed isn't either a v4 or a v6 - address. - - Notes: - The IPv?Interface classes describe an Address on a particular - Network, so they're basically a combination of both the Address - and Network classes. - - """ - try: - return IPv4Interface(address) - except (AddressValueError, NetmaskValueError): - pass - - try: - return IPv6Interface(address) - except (AddressValueError, NetmaskValueError): - pass - - raise ValueError('%r does not appear to be an IPv4 or IPv6 interface' % - address) - - -def v4_int_to_packed(address): - """Represent an address as 4 packed bytes in network (big-endian) order. - - Args: - address: An integer representation of an IPv4 IP address. - - Returns: - The integer address packed as 4 bytes in network (big-endian) order. - - Raises: - ValueError: If the integer is negative or too large to be an - IPv4 IP address. - - """ - try: - return address.to_bytes(4, 'big') - except: - raise ValueError("Address negative or too large for IPv4") - - -def v6_int_to_packed(address): - """Represent an address as 16 packed bytes in network (big-endian) order. - - Args: - address: An integer representation of an IPv6 IP address. - - Returns: - The integer address packed as 16 bytes in network (big-endian) order. - - """ - try: - return address.to_bytes(16, 'big') - except: - raise ValueError("Address negative or too large for IPv6") - - -def _split_optional_netmask(address): - """Helper to split the netmask and raise AddressValueError if needed""" - addr = str(address).split('/') - if len(addr) > 2: - raise AddressValueError("Only one '/' permitted in %r" % address) - return addr - - -def _find_address_range(addresses): - """Find a sequence of IPv#Address. - - Args: - addresses: a list of IPv#Address objects. - - Returns: - A tuple containing the first and last IP addresses in the sequence. - - """ - first = last = addresses[0] - for ip in addresses[1:]: - if ip._ip == last._ip + 1: - last = ip - else: - break - return (first, last) - - -def _count_righthand_zero_bits(number, bits): - """Count the number of zero bits on the right hand side. - - Args: - number: an integer. - bits: maximum number of bits to count. - - Returns: - The number of zero bits on the right hand side of the number. - - """ - if number == 0: - return bits - for i in range(bits): - if (number >> i) & 1: - return i - # All bits of interest were zero, even if there are more in the number - return bits - - -def summarize_address_range(first, last): - """Summarize a network range given the first and last IP addresses. - - Example: - >>> list(summarize_address_range(IPv4Address('192.0.2.0'), - ... IPv4Address('192.0.2.130'))) - ... #doctest: +NORMALIZE_WHITESPACE - [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), - IPv4Network('192.0.2.130/32')] - - Args: - first: the first IPv4Address or IPv6Address in the range. - last: the last IPv4Address or IPv6Address in the range. - - Returns: - An iterator of the summarized IPv(4|6) network objects. - - Raise: - TypeError: - If the first and last objects are not IP addresses. - If the first and last objects are not the same version. - ValueError: - If the last object is not greater than the first. - If the version of the first address is not 4 or 6. - - """ - if (not (isinstance(first, _BaseAddress) and - isinstance(last, _BaseAddress))): - raise TypeError('first and last must be IP addresses, not networks') - if first.version != last.version: - raise TypeError("%s and %s are not of the same version" % ( - first, last)) - if first > last: - raise ValueError('last IP address must be greater than first') - - if first.version == 4: - ip = IPv4Network - elif first.version == 6: - ip = IPv6Network - else: - raise ValueError('unknown IP version') - - ip_bits = first._max_prefixlen - first_int = first._ip - last_int = last._ip - while first_int <= last_int: - nbits = min(_count_righthand_zero_bits(first_int, ip_bits), - (last_int - first_int + 1).bit_length() - 1) - net = ip('%s/%d' % (first, ip_bits - nbits)) - yield net - first_int += 1 << nbits - if first_int - 1 == ip._ALL_ONES: - break - first = first.__class__(first_int) - - -def _collapse_addresses_recursive(addresses): - """Loops through the addresses, collapsing concurrent netblocks. - - Example: - - ip1 = IPv4Network('192.0.2.0/26') - ip2 = IPv4Network('192.0.2.64/26') - ip3 = IPv4Network('192.0.2.128/26') - ip4 = IPv4Network('192.0.2.192/26') - - _collapse_addresses_recursive([ip1, ip2, ip3, ip4]) -> - [IPv4Network('192.0.2.0/24')] - - This shouldn't be called directly; it is called via - collapse_addresses([]). - - Args: - addresses: A list of IPv4Network's or IPv6Network's - - Returns: - A list of IPv4Network's or IPv6Network's depending on what we were - passed. - - """ - while True: - last_addr = None - ret_array = [] - optimized = False - - for cur_addr in addresses: - if not ret_array: - last_addr = cur_addr - ret_array.append(cur_addr) - elif (cur_addr.network_address >= last_addr.network_address and - cur_addr.broadcast_address <= last_addr.broadcast_address): - optimized = True - elif cur_addr == list(last_addr.supernet().subnets())[1]: - ret_array[-1] = last_addr = last_addr.supernet() - optimized = True - else: - last_addr = cur_addr - ret_array.append(cur_addr) - - addresses = ret_array - if not optimized: - return addresses - - -def collapse_addresses(addresses): - """Collapse a list of IP objects. - - Example: - collapse_addresses([IPv4Network('192.0.2.0/25'), - IPv4Network('192.0.2.128/25')]) -> - [IPv4Network('192.0.2.0/24')] - - Args: - addresses: An iterator of IPv4Network or IPv6Network objects. - - Returns: - An iterator of the collapsed IPv(4|6)Network objects. - - Raises: - TypeError: If passed a list of mixed version objects. - - """ - i = 0 - addrs = [] - ips = [] - nets = [] - - # split IP addresses and networks - for ip in addresses: - if isinstance(ip, _BaseAddress): - if ips and ips[-1]._version != ip._version: - raise TypeError("%s and %s are not of the same version" % ( - ip, ips[-1])) - ips.append(ip) - elif ip._prefixlen == ip._max_prefixlen: - if ips and ips[-1]._version != ip._version: - raise TypeError("%s and %s are not of the same version" % ( - ip, ips[-1])) - try: - ips.append(ip.ip) - except AttributeError: - ips.append(ip.network_address) - else: - if nets and nets[-1]._version != ip._version: - raise TypeError("%s and %s are not of the same version" % ( - ip, nets[-1])) - nets.append(ip) - - # sort and dedup - ips = sorted(set(ips)) - nets = sorted(set(nets)) - - while i < len(ips): - (first, last) = _find_address_range(ips[i:]) - i = ips.index(last) + 1 - addrs.extend(summarize_address_range(first, last)) - - return iter(_collapse_addresses_recursive(sorted( - addrs + nets, key=_BaseNetwork._get_networks_key))) - - -def get_mixed_type_key(obj): - """Return a key suitable for sorting between networks and addresses. - - Address and Network objects are not sortable by default; they're - fundamentally different so the expression - - IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') - - doesn't make any sense. There are some times however, where you may wish - to have ipaddress sort these for you anyway. If you need to do this, you - can use this function as the key= argument to sorted(). - - Args: - obj: either a Network or Address object. - Returns: - appropriate key. - - """ - if isinstance(obj, _BaseNetwork): - return obj._get_networks_key() - elif isinstance(obj, _BaseAddress): - return obj._get_address_key() - return NotImplemented - - -class _IPAddressBase: - - """The mother class.""" - - @property - def exploded(self): - """Return the longhand version of the IP address as a string.""" - return self._explode_shorthand_ip_string() - - @property - def compressed(self): - """Return the shorthand version of the IP address as a string.""" - return str(self) - - @property - def version(self): - msg = '%200s has no version specified' % (type(self),) - raise NotImplementedError(msg) - - def _check_int_address(self, address): - if address < 0: - msg = "%d (< 0) is not permitted as an IPv%d address" - raise AddressValueError(msg % (address, self._version)) - if address > self._ALL_ONES: - msg = "%d (>= 2**%d) is not permitted as an IPv%d address" - raise AddressValueError(msg % (address, self._max_prefixlen, - self._version)) - - def _check_packed_address(self, address, expected_len): - address_len = len(address) - if address_len != expected_len: - msg = "%r (len %d != %d) is not permitted as an IPv%d address" - raise AddressValueError(msg % (address, address_len, - expected_len, self._version)) - - def _ip_int_from_prefix(self, prefixlen): - """Turn the prefix length into a bitwise netmask - - Args: - prefixlen: An integer, the prefix length. - - Returns: - An integer. - - """ - return self._ALL_ONES ^ (self._ALL_ONES >> prefixlen) - - def _prefix_from_ip_int(self, ip_int): - """Return prefix length from the bitwise netmask. - - Args: - ip_int: An integer, the netmask in expanded bitwise format - - Returns: - An integer, the prefix length. - - Raises: - ValueError: If the input intermingles zeroes & ones - """ - trailing_zeroes = _count_righthand_zero_bits(ip_int, - self._max_prefixlen) - prefixlen = self._max_prefixlen - trailing_zeroes - leading_ones = ip_int >> trailing_zeroes - all_ones = (1 << prefixlen) - 1 - if leading_ones != all_ones: - byteslen = self._max_prefixlen // 8 - details = ip_int.to_bytes(byteslen, 'big') - msg = 'Netmask pattern %r mixes zeroes & ones' - raise ValueError(msg % details) - return prefixlen - - def _report_invalid_netmask(self, netmask_str): - msg = '%r is not a valid netmask' % netmask_str - raise NetmaskValueError(msg) from None - - def _prefix_from_prefix_string(self, prefixlen_str): - """Return prefix length from a numeric string - - Args: - prefixlen_str: The string to be converted - - Returns: - An integer, the prefix length. - - Raises: - NetmaskValueError: If the input is not a valid netmask - """ - # int allows a leading +/- as well as surrounding whitespace, - # so we ensure that isn't the case - if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): - self._report_invalid_netmask(prefixlen_str) - try: - prefixlen = int(prefixlen_str) - except ValueError: - self._report_invalid_netmask(prefixlen_str) - if not (0 <= prefixlen <= self._max_prefixlen): - self._report_invalid_netmask(prefixlen_str) - return prefixlen - - def _prefix_from_ip_string(self, ip_str): - """Turn a netmask/hostmask string into a prefix length - - Args: - ip_str: The netmask/hostmask to be converted - - Returns: - An integer, the prefix length. - - Raises: - NetmaskValueError: If the input is not a valid netmask/hostmask - """ - # Parse the netmask/hostmask like an IP address. - try: - ip_int = self._ip_int_from_string(ip_str) - except AddressValueError: - self._report_invalid_netmask(ip_str) - - # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). - # Note that the two ambiguous cases (all-ones and all-zeroes) are - # treated as netmasks. - try: - return self._prefix_from_ip_int(ip_int) - except ValueError: - pass - - # Invert the bits, and try matching a /0+1+/ hostmask instead. - ip_int ^= self._ALL_ONES - try: - return self._prefix_from_ip_int(ip_int) - except ValueError: - self._report_invalid_netmask(ip_str) - - -@functools.total_ordering -class _BaseAddress(_IPAddressBase): - - """A generic IP object. - - This IP class contains the version independent methods which are - used by single IP addresses. - """ - - def __init__(self, address): - if (not isinstance(address, bytes) - and '/' in str(address)): - raise AddressValueError("Unexpected '/' in %r" % address) - - def __int__(self): - return self._ip - - def __eq__(self, other): - try: - return (self._ip == other._ip - and self._version == other._version) - except AttributeError: - return NotImplemented - - def __lt__(self, other): - if not isinstance(other, _BaseAddress): - return NotImplemented - if self._version != other._version: - raise TypeError('%s and %s are not of the same version' % ( - self, other)) - if self._ip != other._ip: - return self._ip < other._ip - return False - - # Shorthand for Integer addition and subtraction. This is not - # meant to ever support addition/subtraction of addresses. - def __add__(self, other): - if not isinstance(other, int): - return NotImplemented - return self.__class__(int(self) + other) - - def __sub__(self, other): - if not isinstance(other, int): - return NotImplemented - return self.__class__(int(self) - other) - - def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, str(self)) - - def __str__(self): - return str(self._string_from_ip_int(self._ip)) - - def __hash__(self): - return hash(hex(int(self._ip))) - - def _get_address_key(self): - return (self._version, self) - - -@functools.total_ordering -class _BaseNetwork(_IPAddressBase): - - """A generic IP network object. - - This IP class contains the version independent methods which are - used by networks. - - """ - def __init__(self, address): - self._cache = {} - - def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, str(self)) - - def __str__(self): - return '%s/%d' % (self.network_address, self.prefixlen) - - def hosts(self): - """Generate Iterator over usable hosts in a network. - - This is like __iter__ except it doesn't return the network - or broadcast addresses. - - """ - network = int(self.network_address) - broadcast = int(self.broadcast_address) - for x in range(network + 1, broadcast): - yield self._address_class(x) - - def __iter__(self): - network = int(self.network_address) - broadcast = int(self.broadcast_address) - for x in range(network, broadcast + 1): - yield self._address_class(x) - - def __getitem__(self, n): - network = int(self.network_address) - broadcast = int(self.broadcast_address) - if n >= 0: - if network + n > broadcast: - raise IndexError - return self._address_class(network + n) - else: - n += 1 - if broadcast + n < network: - raise IndexError - return self._address_class(broadcast + n) - - def __lt__(self, other): - if not isinstance(other, _BaseNetwork): - return NotImplemented - if self._version != other._version: - raise TypeError('%s and %s are not of the same version' % ( - self, other)) - if self.network_address != other.network_address: - return self.network_address < other.network_address - if self.netmask != other.netmask: - return self.netmask < other.netmask - return False - - def __eq__(self, other): - try: - return (self._version == other._version and - self.network_address == other.network_address and - int(self.netmask) == int(other.netmask)) - except AttributeError: - return NotImplemented - - def __hash__(self): - return hash(int(self.network_address) ^ int(self.netmask)) - - def __contains__(self, other): - # always false if one is v4 and the other is v6. - if self._version != other._version: - return False - # dealing with another network. - if isinstance(other, _BaseNetwork): - return False - # dealing with another address - else: - # address - return (int(self.network_address) <= int(other._ip) <= - int(self.broadcast_address)) - - def overlaps(self, other): - """Tell if self is partly contained in other.""" - return self.network_address in other or ( - self.broadcast_address in other or ( - other.network_address in self or ( - other.broadcast_address in self))) - - @property - def broadcast_address(self): - x = self._cache.get('broadcast_address') - if x is None: - x = self._address_class(int(self.network_address) | - int(self.hostmask)) - self._cache['broadcast_address'] = x - return x - - @property - def hostmask(self): - x = self._cache.get('hostmask') - if x is None: - x = self._address_class(int(self.netmask) ^ self._ALL_ONES) - self._cache['hostmask'] = x - return x - - @property - def with_prefixlen(self): - return '%s/%d' % (self.network_address, self._prefixlen) - - @property - def with_netmask(self): - return '%s/%s' % (self.network_address, self.netmask) - - @property - def with_hostmask(self): - return '%s/%s' % (self.network_address, self.hostmask) - - @property - def num_addresses(self): - """Number of hosts in the current subnet.""" - return int(self.broadcast_address) - int(self.network_address) + 1 - - @property - def _address_class(self): - # Returning bare address objects (rather than interfaces) allows for - # more consistent behaviour across the network address, broadcast - # address and individual host addresses. - msg = '%200s has no associated address class' % (type(self),) - raise NotImplementedError(msg) - - @property - def prefixlen(self): - return self._prefixlen - - def address_exclude(self, other): - """Remove an address from a larger block. - - For example: - - addr1 = ip_network('192.0.2.0/28') - addr2 = ip_network('192.0.2.1/32') - addr1.address_exclude(addr2) = - [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), - IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] - - or IPv6: - - addr1 = ip_network('2001:db8::1/32') - addr2 = ip_network('2001:db8::1/128') - addr1.address_exclude(addr2) = - [ip_network('2001:db8::1/128'), - ip_network('2001:db8::2/127'), - ip_network('2001:db8::4/126'), - ip_network('2001:db8::8/125'), - ... - ip_network('2001:db8:8000::/33')] - - Args: - other: An IPv4Network or IPv6Network object of the same type. - - Returns: - An iterator of the IPv(4|6)Network objects which is self - minus other. - - Raises: - TypeError: If self and other are of differing address - versions, or if other is not a network object. - ValueError: If other is not completely contained by self. - - """ - if not self._version == other._version: - raise TypeError("%s and %s are not of the same version" % ( - self, other)) - - if not isinstance(other, _BaseNetwork): - raise TypeError("%s is not a network object" % other) - - if not (other.network_address >= self.network_address and - other.broadcast_address <= self.broadcast_address): - raise ValueError('%s not contained in %s' % (other, self)) - if other == self: - raise StopIteration - - # Make sure we're comparing the network of other. - other = other.__class__('%s/%s' % (other.network_address, - other.prefixlen)) - - s1, s2 = self.subnets() - while s1 != other and s2 != other: - if (other.network_address >= s1.network_address and - other.broadcast_address <= s1.broadcast_address): - yield s2 - s1, s2 = s1.subnets() - elif (other.network_address >= s2.network_address and - other.broadcast_address <= s2.broadcast_address): - yield s1 - s1, s2 = s2.subnets() - else: - # If we got here, there's a bug somewhere. - raise AssertionError('Error performing exclusion: ' - 's1: %s s2: %s other: %s' % - (s1, s2, other)) - if s1 == other: - yield s2 - elif s2 == other: - yield s1 - else: - # If we got here, there's a bug somewhere. - raise AssertionError('Error performing exclusion: ' - 's1: %s s2: %s other: %s' % - (s1, s2, other)) - - def compare_networks(self, other): - """Compare two IP objects. - - This is only concerned about the comparison of the integer - representation of the network addresses. This means that the - host bits aren't considered at all in this method. If you want - to compare host bits, you can easily enough do a - 'HostA._ip < HostB._ip' - - Args: - other: An IP object. - - Returns: - If the IP versions of self and other are the same, returns: - - -1 if self < other: - eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') - IPv6Network('2001:db8::1000/124') < - IPv6Network('2001:db8::2000/124') - 0 if self == other - eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') - IPv6Network('2001:db8::1000/124') == - IPv6Network('2001:db8::1000/124') - 1 if self > other - eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') - IPv6Network('2001:db8::2000/124') > - IPv6Network('2001:db8::1000/124') - - Raises: - TypeError if the IP versions are different. - - """ - # does this need to raise a ValueError? - if self._version != other._version: - raise TypeError('%s and %s are not of the same type' % ( - self, other)) - # self._version == other._version below here: - if self.network_address < other.network_address: - return -1 - if self.network_address > other.network_address: - return 1 - # self.network_address == other.network_address below here: - if self.netmask < other.netmask: - return -1 - if self.netmask > other.netmask: - return 1 - return 0 - - def _get_networks_key(self): - """Network-only key function. - - Returns an object that identifies this address' network and - netmask. This function is a suitable "key" argument for sorted() - and list.sort(). - - """ - return (self._version, self.network_address, self.netmask) - - def subnets(self, prefixlen_diff=1, new_prefix=None): - """The subnets which join to make the current subnet. - - In the case that self contains only one IP - (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 - for IPv6), yield an iterator with just ourself. - - Args: - prefixlen_diff: An integer, the amount the prefix length - should be increased by. This should not be set if - new_prefix is also set. - new_prefix: The desired new prefix length. This must be a - larger number (smaller prefix) than the existing prefix. - This should not be set if prefixlen_diff is also set. - - Returns: - An iterator of IPv(4|6) objects. - - Raises: - ValueError: The prefixlen_diff is too small or too large. - OR - prefixlen_diff and new_prefix are both set or new_prefix - is a smaller number than the current prefix (smaller - number means a larger network) - - """ - if self._prefixlen == self._max_prefixlen: - yield self - return - - if new_prefix is not None: - if new_prefix < self._prefixlen: - raise ValueError('new prefix must be longer') - if prefixlen_diff != 1: - raise ValueError('cannot set prefixlen_diff and new_prefix') - prefixlen_diff = new_prefix - self._prefixlen - - if prefixlen_diff < 0: - raise ValueError('prefix length diff must be > 0') - new_prefixlen = self._prefixlen + prefixlen_diff - - if new_prefixlen > self._max_prefixlen: - raise ValueError( - 'prefix length diff %d is invalid for netblock %s' % ( - new_prefixlen, self)) - - first = self.__class__('%s/%s' % - (self.network_address, - self._prefixlen + prefixlen_diff)) - - yield first - current = first - while True: - broadcast = current.broadcast_address - if broadcast == self.broadcast_address: - return - new_addr = self._address_class(int(broadcast) + 1) - current = self.__class__('%s/%s' % (new_addr, - new_prefixlen)) - - yield current - - def supernet(self, prefixlen_diff=1, new_prefix=None): - """The supernet containing the current network. - - Args: - prefixlen_diff: An integer, the amount the prefix length of - the network should be decreased by. For example, given a - /24 network and a prefixlen_diff of 3, a supernet with a - /21 netmask is returned. - - Returns: - An IPv4 network object. - - Raises: - ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have - a negative prefix length. - OR - If prefixlen_diff and new_prefix are both set or new_prefix is a - larger number than the current prefix (larger number means a - smaller network) - - """ - if self._prefixlen == 0: - return self - - if new_prefix is not None: - if new_prefix > self._prefixlen: - raise ValueError('new prefix must be shorter') - if prefixlen_diff != 1: - raise ValueError('cannot set prefixlen_diff and new_prefix') - prefixlen_diff = self._prefixlen - new_prefix - - if self.prefixlen - prefixlen_diff < 0: - raise ValueError( - 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % - (self.prefixlen, prefixlen_diff)) - # TODO (pmoody): optimize this. - t = self.__class__('%s/%d' % (self.network_address, - self.prefixlen - prefixlen_diff), - strict=False) - return t.__class__('%s/%d' % (t.network_address, t.prefixlen)) - - @property - def is_multicast(self): - """Test if the address is reserved for multicast use. - - Returns: - A boolean, True if the address is a multicast address. - See RFC 2373 2.7 for details. - - """ - return (self.network_address.is_multicast and - self.broadcast_address.is_multicast) - - @property - def is_reserved(self): - """Test if the address is otherwise IETF reserved. - - Returns: - A boolean, True if the address is within one of the - reserved IPv6 Network ranges. - - """ - return (self.network_address.is_reserved and - self.broadcast_address.is_reserved) - - @property - def is_link_local(self): - """Test if the address is reserved for link-local. - - Returns: - A boolean, True if the address is reserved per RFC 4291. - - """ - return (self.network_address.is_link_local and - self.broadcast_address.is_link_local) - - @property - def is_private(self): - """Test if this address is allocated for private networks. - - Returns: - A boolean, True if the address is reserved per - iana-ipv4-special-registry or iana-ipv6-special-registry. - - """ - return (self.network_address.is_private and - self.broadcast_address.is_private) - - @property - def is_global(self): - """Test if this address is allocated for public networks. - - Returns: - A boolean, True if the address is not reserved per - iana-ipv4-special-registry or iana-ipv6-special-registry. - - """ - return not self.is_private - - @property - def is_unspecified(self): - """Test if the address is unspecified. - - Returns: - A boolean, True if this is the unspecified address as defined in - RFC 2373 2.5.2. - - """ - return (self.network_address.is_unspecified and - self.broadcast_address.is_unspecified) - - @property - def is_loopback(self): - """Test if the address is a loopback address. - - Returns: - A boolean, True if the address is a loopback address as defined in - RFC 2373 2.5.3. - - """ - return (self.network_address.is_loopback and - self.broadcast_address.is_loopback) - - -class _BaseV4: - - """Base IPv4 object. - - The following methods are used by IPv4 objects in both single IP - addresses and networks. - - """ - - # Equivalent to 255.255.255.255 or 32 bits of 1's. - _ALL_ONES = (2**IPV4LENGTH) - 1 - _DECIMAL_DIGITS = frozenset('0123456789') - - # the valid octets for host and netmasks. only useful for IPv4. - _valid_mask_octets = frozenset((255, 254, 252, 248, 240, 224, 192, 128, 0)) - - def __init__(self, address): - self._version = 4 - self._max_prefixlen = IPV4LENGTH - - def _explode_shorthand_ip_string(self): - return str(self) - - def _ip_int_from_string(self, ip_str): - """Turn the given IP string into an integer for comparison. - - Args: - ip_str: A string, the IP ip_str. - - Returns: - The IP ip_str as an integer. - - Raises: - AddressValueError: if ip_str isn't a valid IPv4 Address. - - """ - if not ip_str: - raise AddressValueError('Address cannot be empty') - - octets = ip_str.split('.') - if len(octets) != 4: - raise AddressValueError("Expected 4 octets in %r" % ip_str) - - try: - return int.from_bytes(map(self._parse_octet, octets), 'big') - except ValueError as exc: - raise AddressValueError("%s in %r" % (exc, ip_str)) from None - - def _parse_octet(self, octet_str): - """Convert a decimal octet into an integer. - - Args: - octet_str: A string, the number to parse. - - Returns: - The octet as an integer. - - Raises: - ValueError: if the octet isn't strictly a decimal from [0..255]. - - """ - if not octet_str: - raise ValueError("Empty octet not permitted") - # Whitelist the characters, since int() allows a lot of bizarre stuff. - if not self._DECIMAL_DIGITS.issuperset(octet_str): - msg = "Only decimal digits permitted in %r" - raise ValueError(msg % octet_str) - # We do the length check second, since the invalid character error - # is likely to be more informative for the user - if len(octet_str) > 3: - msg = "At most 3 characters permitted in %r" - raise ValueError(msg % octet_str) - # Convert to integer (we know digits are legal) - octet_int = int(octet_str, 10) - # Any octets that look like they *might* be written in octal, - # and which don't look exactly the same in both octal and - # decimal are rejected as ambiguous - if octet_int > 7 and octet_str[0] == '0': - msg = "Ambiguous (octal/decimal) value in %r not permitted" - raise ValueError(msg % octet_str) - if octet_int > 255: - raise ValueError("Octet %d (> 255) not permitted" % octet_int) - return octet_int - - def _string_from_ip_int(self, ip_int): - """Turns a 32-bit integer into dotted decimal notation. - - Args: - ip_int: An integer, the IP address. - - Returns: - The IP address as a string in dotted decimal notation. - - """ - return '.'.join(map(str, ip_int.to_bytes(4, 'big'))) - - def _is_valid_netmask(self, netmask): - """Verify that the netmask is valid. - - Args: - netmask: A string, either a prefix or dotted decimal - netmask. - - Returns: - A boolean, True if the prefix represents a valid IPv4 - netmask. - - """ - mask = netmask.split('.') - if len(mask) == 4: - try: - for x in mask: - if int(x) not in self._valid_mask_octets: - return False - except ValueError: - # Found something that isn't an integer or isn't valid - return False - for idx, y in enumerate(mask): - if idx > 0 and y > mask[idx - 1]: - return False - return True - try: - netmask = int(netmask) - except ValueError: - return False - return 0 <= netmask <= self._max_prefixlen - - def _is_hostmask(self, ip_str): - """Test if the IP string is a hostmask (rather than a netmask). - - Args: - ip_str: A string, the potential hostmask. - - Returns: - A boolean, True if the IP string is a hostmask. - - """ - bits = ip_str.split('.') - try: - parts = [x for x in map(int, bits) if x in self._valid_mask_octets] - except ValueError: - return False - if len(parts) != len(bits): - return False - if parts[0] < parts[-1]: - return True - return False - - @property - def max_prefixlen(self): - return self._max_prefixlen - - @property - def version(self): - return self._version - - -class IPv4Address(_BaseV4, _BaseAddress): - - """Represent and manipulate single IPv4 Addresses.""" - - def __init__(self, address): - - """ - Args: - address: A string or integer representing the IP - - Additionally, an integer can be passed, so - IPv4Address('192.0.2.1') == IPv4Address(3221225985). - or, more generally - IPv4Address(int(IPv4Address('192.0.2.1'))) == - IPv4Address('192.0.2.1') - - Raises: - AddressValueError: If ipaddress isn't a valid IPv4 address. - - """ - _BaseAddress.__init__(self, address) - _BaseV4.__init__(self, address) - - # Efficient constructor from integer. - if isinstance(address, int): - self._check_int_address(address) - self._ip = address - return - - # Constructing from a packed address - if isinstance(address, bytes): - self._check_packed_address(address, 4) - self._ip = int.from_bytes(address, 'big') - return - - # Assume input argument to be string or any object representation - # which converts into a formatted IP string. - addr_str = str(address) - self._ip = self._ip_int_from_string(addr_str) - - @property - def packed(self): - """The binary representation of this address.""" - return v4_int_to_packed(self._ip) - - @property - def is_reserved(self): - """Test if the address is otherwise IETF reserved. - - Returns: - A boolean, True if the address is within the - reserved IPv4 Network range. - - """ - reserved_network = IPv4Network('240.0.0.0/4') - return self in reserved_network - - @property - @functools.lru_cache() - def is_private(self): - """Test if this address is allocated for private networks. - - Returns: - A boolean, True if the address is reserved per - iana-ipv4-special-registry. - - """ - return (self in IPv4Network('0.0.0.0/8') or - self in IPv4Network('10.0.0.0/8') or - self in IPv4Network('127.0.0.0/8') or - self in IPv4Network('169.254.0.0/16') or - self in IPv4Network('172.16.0.0/12') or - self in IPv4Network('192.0.0.0/29') or - self in IPv4Network('192.0.0.170/31') or - self in IPv4Network('192.0.2.0/24') or - self in IPv4Network('192.168.0.0/16') or - self in IPv4Network('198.18.0.0/15') or - self in IPv4Network('198.51.100.0/24') or - self in IPv4Network('203.0.113.0/24') or - self in IPv4Network('240.0.0.0/4') or - self in IPv4Network('255.255.255.255/32')) - - - @property - def is_multicast(self): - """Test if the address is reserved for multicast use. - - Returns: - A boolean, True if the address is multicast. - See RFC 3171 for details. - - """ - multicast_network = IPv4Network('224.0.0.0/4') - return self in multicast_network - - @property - def is_unspecified(self): - """Test if the address is unspecified. - - Returns: - A boolean, True if this is the unspecified address as defined in - RFC 5735 3. - - """ - unspecified_address = IPv4Address('0.0.0.0') - return self == unspecified_address - - @property - def is_loopback(self): - """Test if the address is a loopback address. - - Returns: - A boolean, True if the address is a loopback per RFC 3330. - - """ - loopback_network = IPv4Network('127.0.0.0/8') - return self in loopback_network - - @property - def is_link_local(self): - """Test if the address is reserved for link-local. - - Returns: - A boolean, True if the address is link-local per RFC 3927. - - """ - linklocal_network = IPv4Network('169.254.0.0/16') - return self in linklocal_network - - -class IPv4Interface(IPv4Address): - - def __init__(self, address): - if isinstance(address, (bytes, int)): - IPv4Address.__init__(self, address) - self.network = IPv4Network(self._ip) - self._prefixlen = self._max_prefixlen - return - - addr = _split_optional_netmask(address) - IPv4Address.__init__(self, addr[0]) - - self.network = IPv4Network(address, strict=False) - self._prefixlen = self.network._prefixlen - - self.netmask = self.network.netmask - self.hostmask = self.network.hostmask - - def __str__(self): - return '%s/%d' % (self._string_from_ip_int(self._ip), - self.network.prefixlen) - - def __eq__(self, other): - address_equal = IPv4Address.__eq__(self, other) - if not address_equal or address_equal is NotImplemented: - return address_equal - try: - return self.network == other.network - except AttributeError: - # An interface with an associated network is NOT the - # same as an unassociated address. That's why the hash - # takes the extra info into account. - return False - - def __lt__(self, other): - address_less = IPv4Address.__lt__(self, other) - if address_less is NotImplemented: - return NotImplemented - try: - return self.network < other.network - except AttributeError: - # We *do* allow addresses and interfaces to be sorted. The - # unassociated address is considered less than all interfaces. - return False - - def __hash__(self): - return self._ip ^ self._prefixlen ^ int(self.network.network_address) - - @property - def ip(self): - return IPv4Address(self._ip) - - @property - def with_prefixlen(self): - return '%s/%s' % (self._string_from_ip_int(self._ip), - self._prefixlen) - - @property - def with_netmask(self): - return '%s/%s' % (self._string_from_ip_int(self._ip), - self.netmask) - - @property - def with_hostmask(self): - return '%s/%s' % (self._string_from_ip_int(self._ip), - self.hostmask) - - -class IPv4Network(_BaseV4, _BaseNetwork): - - """This class represents and manipulates 32-bit IPv4 network + addresses.. - - Attributes: [examples for IPv4Network('192.0.2.0/27')] - .network_address: IPv4Address('192.0.2.0') - .hostmask: IPv4Address('0.0.0.31') - .broadcast_address: IPv4Address('192.0.2.32') - .netmask: IPv4Address('255.255.255.224') - .prefixlen: 27 - - """ - # Class to use when creating address objects - _address_class = IPv4Address - - def __init__(self, address, strict=True): - - """Instantiate a new IPv4 network object. - - Args: - address: A string or integer representing the IP [& network]. - '192.0.2.0/24' - '192.0.2.0/255.255.255.0' - '192.0.0.2/0.0.0.255' - are all functionally the same in IPv4. Similarly, - '192.0.2.1' - '192.0.2.1/255.255.255.255' - '192.0.2.1/32' - are also functionally equivalent. That is to say, failing to - provide a subnetmask will create an object with a mask of /32. - - If the mask (portion after the / in the argument) is given in - dotted quad form, it is treated as a netmask if it starts with a - non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it - starts with a zero field (e.g. 0.255.255.255 == /8), with the - single exception of an all-zero mask which is treated as a - netmask == /0. If no mask is given, a default of /32 is used. - - Additionally, an integer can be passed, so - IPv4Network('192.0.2.1') == IPv4Network(3221225985) - or, more generally - IPv4Interface(int(IPv4Interface('192.0.2.1'))) == - IPv4Interface('192.0.2.1') - - Raises: - AddressValueError: If ipaddress isn't a valid IPv4 address. - NetmaskValueError: If the netmask isn't valid for - an IPv4 address. - ValueError: If strict is True and a network address is not - supplied. - - """ - - _BaseV4.__init__(self, address) - _BaseNetwork.__init__(self, address) - - # Constructing from a packed address - if isinstance(address, bytes): - self.network_address = IPv4Address(address) - self._prefixlen = self._max_prefixlen - self.netmask = IPv4Address(self._ALL_ONES) - #fixme: address/network test here - return - - # Efficient constructor from integer. - if isinstance(address, int): - self.network_address = IPv4Address(address) - self._prefixlen = self._max_prefixlen - self.netmask = IPv4Address(self._ALL_ONES) - #fixme: address/network test here. - return - - # Assume input argument to be string or any object representation - # which converts into a formatted IP prefix string. - addr = _split_optional_netmask(address) - self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) - - if len(addr) == 2: - try: - # Check for a netmask in prefix length form - self._prefixlen = self._prefix_from_prefix_string(addr[1]) - except NetmaskValueError: - # Check for a netmask or hostmask in dotted-quad form. - # This may raise NetmaskValueError. - self._prefixlen = self._prefix_from_ip_string(addr[1]) - else: - self._prefixlen = self._max_prefixlen - self.netmask = IPv4Address(self._ip_int_from_prefix(self._prefixlen)) - - if strict: - if (IPv4Address(int(self.network_address) & int(self.netmask)) != - self.network_address): - raise ValueError('%s has host bits set' % self) - self.network_address = IPv4Address(int(self.network_address) & - int(self.netmask)) - - if self._prefixlen == (self._max_prefixlen - 1): - self.hosts = self.__iter__ - - @property - @functools.lru_cache() - def is_global(self): - """Test if this address is allocated for public networks. - - Returns: - A boolean, True if the address is not reserved per - iana-ipv4-special-registry. - - """ - return (not (self.network_address in IPv4Network('100.64.0.0/10') and - self.broadcast_address in IPv4Network('100.64.0.0/10')) and - not self.is_private) - - - -class _BaseV6: - - """Base IPv6 object. - - The following methods are used by IPv6 objects in both single IP - addresses and networks. - - """ - - _ALL_ONES = (2**IPV6LENGTH) - 1 - _HEXTET_COUNT = 8 - _HEX_DIGITS = frozenset('0123456789ABCDEFabcdef') - - def __init__(self, address): - self._version = 6 - self._max_prefixlen = IPV6LENGTH - - def _ip_int_from_string(self, ip_str): - """Turn an IPv6 ip_str into an integer. - - Args: - ip_str: A string, the IPv6 ip_str. - - Returns: - An int, the IPv6 address - - Raises: - AddressValueError: if ip_str isn't a valid IPv6 Address. - - """ - if not ip_str: - raise AddressValueError('Address cannot be empty') - - parts = ip_str.split(':') - - # An IPv6 address needs at least 2 colons (3 parts). - _min_parts = 3 - if len(parts) < _min_parts: - msg = "At least %d parts expected in %r" % (_min_parts, ip_str) - raise AddressValueError(msg) - - # If the address has an IPv4-style suffix, convert it to hexadecimal. - if '.' in parts[-1]: - try: - ipv4_int = IPv4Address(parts.pop())._ip - except AddressValueError as exc: - raise AddressValueError("%s in %r" % (exc, ip_str)) from None - parts.append('%x' % ((ipv4_int >> 16) & 0xFFFF)) - parts.append('%x' % (ipv4_int & 0xFFFF)) - - # An IPv6 address can't have more than 8 colons (9 parts). - # The extra colon comes from using the "::" notation for a single - # leading or trailing zero part. - _max_parts = self._HEXTET_COUNT + 1 - if len(parts) > _max_parts: - msg = "At most %d colons permitted in %r" % (_max_parts-1, ip_str) - raise AddressValueError(msg) - - # Disregarding the endpoints, find '::' with nothing in between. - # This indicates that a run of zeroes has been skipped. - skip_index = None - for i in range(1, len(parts) - 1): - if not parts[i]: - if skip_index is not None: - # Can't have more than one '::' - msg = "At most one '::' permitted in %r" % ip_str - raise AddressValueError(msg) - skip_index = i - - # parts_hi is the number of parts to copy from above/before the '::' - # parts_lo is the number of parts to copy from below/after the '::' - if skip_index is not None: - # If we found a '::', then check if it also covers the endpoints. - parts_hi = skip_index - parts_lo = len(parts) - skip_index - 1 - if not parts[0]: - parts_hi -= 1 - if parts_hi: - msg = "Leading ':' only permitted as part of '::' in %r" - raise AddressValueError(msg % ip_str) # ^: requires ^:: - if not parts[-1]: - parts_lo -= 1 - if parts_lo: - msg = "Trailing ':' only permitted as part of '::' in %r" - raise AddressValueError(msg % ip_str) # :$ requires ::$ - parts_skipped = self._HEXTET_COUNT - (parts_hi + parts_lo) - if parts_skipped < 1: - msg = "Expected at most %d other parts with '::' in %r" - raise AddressValueError(msg % (self._HEXTET_COUNT-1, ip_str)) - else: - # Otherwise, allocate the entire address to parts_hi. The - # endpoints could still be empty, but _parse_hextet() will check - # for that. - if len(parts) != self._HEXTET_COUNT: - msg = "Exactly %d parts expected without '::' in %r" - raise AddressValueError(msg % (self._HEXTET_COUNT, ip_str)) - if not parts[0]: - msg = "Leading ':' only permitted as part of '::' in %r" - raise AddressValueError(msg % ip_str) # ^: requires ^:: - if not parts[-1]: - msg = "Trailing ':' only permitted as part of '::' in %r" - raise AddressValueError(msg % ip_str) # :$ requires ::$ - parts_hi = len(parts) - parts_lo = 0 - parts_skipped = 0 - - try: - # Now, parse the hextets into a 128-bit integer. - ip_int = 0 - for i in range(parts_hi): - ip_int <<= 16 - ip_int |= self._parse_hextet(parts[i]) - ip_int <<= 16 * parts_skipped - for i in range(-parts_lo, 0): - ip_int <<= 16 - ip_int |= self._parse_hextet(parts[i]) - return ip_int - except ValueError as exc: - raise AddressValueError("%s in %r" % (exc, ip_str)) from None - - def _parse_hextet(self, hextet_str): - """Convert an IPv6 hextet string into an integer. - - Args: - hextet_str: A string, the number to parse. - - Returns: - The hextet as an integer. - - Raises: - ValueError: if the input isn't strictly a hex number from - [0..FFFF]. - - """ - # Whitelist the characters, since int() allows a lot of bizarre stuff. - if not self._HEX_DIGITS.issuperset(hextet_str): - raise ValueError("Only hex digits permitted in %r" % hextet_str) - # We do the length check second, since the invalid character error - # is likely to be more informative for the user - if len(hextet_str) > 4: - msg = "At most 4 characters permitted in %r" - raise ValueError(msg % hextet_str) - # Length check means we can skip checking the integer value - return int(hextet_str, 16) - - def _compress_hextets(self, hextets): - """Compresses a list of hextets. - - Compresses a list of strings, replacing the longest continuous - sequence of "0" in the list with "" and adding empty strings at - the beginning or at the end of the string such that subsequently - calling ":".join(hextets) will produce the compressed version of - the IPv6 address. - - Args: - hextets: A list of strings, the hextets to compress. - - Returns: - A list of strings. - - """ - best_doublecolon_start = -1 - best_doublecolon_len = 0 - doublecolon_start = -1 - doublecolon_len = 0 - for index, hextet in enumerate(hextets): - if hextet == '0': - doublecolon_len += 1 - if doublecolon_start == -1: - # Start of a sequence of zeros. - doublecolon_start = index - if doublecolon_len > best_doublecolon_len: - # This is the longest sequence of zeros so far. - best_doublecolon_len = doublecolon_len - best_doublecolon_start = doublecolon_start - else: - doublecolon_len = 0 - doublecolon_start = -1 - - if best_doublecolon_len > 1: - best_doublecolon_end = (best_doublecolon_start + - best_doublecolon_len) - # For zeros at the end of the address. - if best_doublecolon_end == len(hextets): - hextets += [''] - hextets[best_doublecolon_start:best_doublecolon_end] = [''] - # For zeros at the beginning of the address. - if best_doublecolon_start == 0: - hextets = [''] + hextets - - return hextets - - def _string_from_ip_int(self, ip_int=None): - """Turns a 128-bit integer into hexadecimal notation. - - Args: - ip_int: An integer, the IP address. - - Returns: - A string, the hexadecimal representation of the address. - - Raises: - ValueError: The address is bigger than 128 bits of all ones. - - """ - if ip_int is None: - ip_int = int(self._ip) - - if ip_int > self._ALL_ONES: - raise ValueError('IPv6 address is too large') - - hex_str = '%032x' % ip_int - hextets = ['%x' % int(hex_str[x:x+4], 16) for x in range(0, 32, 4)] - - hextets = self._compress_hextets(hextets) - return ':'.join(hextets) - - def _explode_shorthand_ip_string(self): - """Expand a shortened IPv6 address. - - Args: - ip_str: A string, the IPv6 address. - - Returns: - A string, the expanded IPv6 address. - - """ - if isinstance(self, IPv6Network): - ip_str = str(self.network_address) - elif isinstance(self, IPv6Interface): - ip_str = str(self.ip) - else: - ip_str = str(self) - - ip_int = self._ip_int_from_string(ip_str) - hex_str = '%032x' % ip_int - parts = [hex_str[x:x+4] for x in range(0, 32, 4)] - if isinstance(self, (_BaseNetwork, IPv6Interface)): - return '%s/%d' % (':'.join(parts), self._prefixlen) - return ':'.join(parts) - - @property - def max_prefixlen(self): - return self._max_prefixlen - - @property - def version(self): - return self._version - - -class IPv6Address(_BaseV6, _BaseAddress): - - """Represent and manipulate single IPv6 Addresses.""" - - def __init__(self, address): - """Instantiate a new IPv6 address object. - - Args: - address: A string or integer representing the IP - - Additionally, an integer can be passed, so - IPv6Address('2001:db8::') == - IPv6Address(42540766411282592856903984951653826560) - or, more generally - IPv6Address(int(IPv6Address('2001:db8::'))) == - IPv6Address('2001:db8::') - - Raises: - AddressValueError: If address isn't a valid IPv6 address. - - """ - _BaseAddress.__init__(self, address) - _BaseV6.__init__(self, address) - - # Efficient constructor from integer. - if isinstance(address, int): - self._check_int_address(address) - self._ip = address - return - - # Constructing from a packed address - if isinstance(address, bytes): - self._check_packed_address(address, 16) - self._ip = int.from_bytes(address, 'big') - return - - # Assume input argument to be string or any object representation - # which converts into a formatted IP string. - addr_str = str(address) - self._ip = self._ip_int_from_string(addr_str) - - @property - def packed(self): - """The binary representation of this address.""" - return v6_int_to_packed(self._ip) - - @property - def is_multicast(self): - """Test if the address is reserved for multicast use. - - Returns: - A boolean, True if the address is a multicast address. - See RFC 2373 2.7 for details. - - """ - multicast_network = IPv6Network('ff00::/8') - return self in multicast_network - - @property - def is_reserved(self): - """Test if the address is otherwise IETF reserved. - - Returns: - A boolean, True if the address is within one of the - reserved IPv6 Network ranges. - - """ - reserved_networks = [IPv6Network('::/8'), IPv6Network('100::/8'), - IPv6Network('200::/7'), IPv6Network('400::/6'), - IPv6Network('800::/5'), IPv6Network('1000::/4'), - IPv6Network('4000::/3'), IPv6Network('6000::/3'), - IPv6Network('8000::/3'), IPv6Network('A000::/3'), - IPv6Network('C000::/3'), IPv6Network('E000::/4'), - IPv6Network('F000::/5'), IPv6Network('F800::/6'), - IPv6Network('FE00::/9')] - - return any(self in x for x in reserved_networks) - - @property - def is_link_local(self): - """Test if the address is reserved for link-local. - - Returns: - A boolean, True if the address is reserved per RFC 4291. - - """ - linklocal_network = IPv6Network('fe80::/10') - return self in linklocal_network - - @property - def is_site_local(self): - """Test if the address is reserved for site-local. - - Note that the site-local address space has been deprecated by RFC 3879. - Use is_private to test if this address is in the space of unique local - addresses as defined by RFC 4193. - - Returns: - A boolean, True if the address is reserved per RFC 3513 2.5.6. - - """ - sitelocal_network = IPv6Network('fec0::/10') - return self in sitelocal_network - - @property - @functools.lru_cache() - def is_private(self): - """Test if this address is allocated for private networks. - - Returns: - A boolean, True if the address is reserved per - iana-ipv6-special-registry. - - """ - return (self in IPv6Network('::1/128') or - self in IPv6Network('::/128') or - self in IPv6Network('::ffff:0:0/96') or - self in IPv6Network('100::/64') or - self in IPv6Network('2001::/23') or - self in IPv6Network('2001:2::/48') or - self in IPv6Network('2001:db8::/32') or - self in IPv6Network('2001:10::/28') or - self in IPv6Network('fc00::/7') or - self in IPv6Network('fe80::/10')) - - @property - def is_global(self): - """Test if this address is allocated for public networks. - - Returns: - A boolean, true if the address is not reserved per - iana-ipv6-special-registry. - - """ - return not self.is_private - - @property - def is_unspecified(self): - """Test if the address is unspecified. - - Returns: - A boolean, True if this is the unspecified address as defined in - RFC 2373 2.5.2. - - """ - return self._ip == 0 - - @property - def is_loopback(self): - """Test if the address is a loopback address. - - Returns: - A boolean, True if the address is a loopback address as defined in - RFC 2373 2.5.3. - - """ - return self._ip == 1 - - @property - def ipv4_mapped(self): - """Return the IPv4 mapped address. - - Returns: - If the IPv6 address is a v4 mapped address, return the - IPv4 mapped address. Return None otherwise. - - """ - if (self._ip >> 32) != 0xFFFF: - return None - return IPv4Address(self._ip & 0xFFFFFFFF) - - @property - def teredo(self): - """Tuple of embedded teredo IPs. - - Returns: - Tuple of the (server, client) IPs or None if the address - doesn't appear to be a teredo address (doesn't start with - 2001::/32) - - """ - if (self._ip >> 96) != 0x20010000: - return None - return (IPv4Address((self._ip >> 64) & 0xFFFFFFFF), - IPv4Address(~self._ip & 0xFFFFFFFF)) - - @property - def sixtofour(self): - """Return the IPv4 6to4 embedded address. - - Returns: - The IPv4 6to4-embedded address if present or None if the - address doesn't appear to contain a 6to4 embedded address. - - """ - if (self._ip >> 112) != 0x2002: - return None - return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) - - -class IPv6Interface(IPv6Address): - - def __init__(self, address): - if isinstance(address, (bytes, int)): - IPv6Address.__init__(self, address) - self.network = IPv6Network(self._ip) - self._prefixlen = self._max_prefixlen - return - - addr = _split_optional_netmask(address) - IPv6Address.__init__(self, addr[0]) - self.network = IPv6Network(address, strict=False) - self.netmask = self.network.netmask - self._prefixlen = self.network._prefixlen - self.hostmask = self.network.hostmask - - def __str__(self): - return '%s/%d' % (self._string_from_ip_int(self._ip), - self.network.prefixlen) - - def __eq__(self, other): - address_equal = IPv6Address.__eq__(self, other) - if not address_equal or address_equal is NotImplemented: - return address_equal - try: - return self.network == other.network - except AttributeError: - # An interface with an associated network is NOT the - # same as an unassociated address. That's why the hash - # takes the extra info into account. - return False - - def __lt__(self, other): - address_less = IPv6Address.__lt__(self, other) - if address_less is NotImplemented: - return NotImplemented - try: - return self.network < other.network - except AttributeError: - # We *do* allow addresses and interfaces to be sorted. The - # unassociated address is considered less than all interfaces. - return False - - def __hash__(self): - return self._ip ^ self._prefixlen ^ int(self.network.network_address) - - @property - def ip(self): - return IPv6Address(self._ip) - - @property - def with_prefixlen(self): - return '%s/%s' % (self._string_from_ip_int(self._ip), - self._prefixlen) - - @property - def with_netmask(self): - return '%s/%s' % (self._string_from_ip_int(self._ip), - self.netmask) - - @property - def with_hostmask(self): - return '%s/%s' % (self._string_from_ip_int(self._ip), - self.hostmask) - - @property - def is_unspecified(self): - return self._ip == 0 and self.network.is_unspecified - - @property - def is_loopback(self): - return self._ip == 1 and self.network.is_loopback - - -class IPv6Network(_BaseV6, _BaseNetwork): - - """This class represents and manipulates 128-bit IPv6 networks. - - Attributes: [examples for IPv6('2001:db8::1000/124')] - .network_address: IPv6Address('2001:db8::1000') - .hostmask: IPv6Address('::f') - .broadcast_address: IPv6Address('2001:db8::100f') - .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') - .prefixlen: 124 - - """ - - # Class to use when creating address objects - _address_class = IPv6Address - - def __init__(self, address, strict=True): - """Instantiate a new IPv6 Network object. - - Args: - address: A string or integer representing the IPv6 network or the - IP and prefix/netmask. - '2001:db8::/128' - '2001:db8:0000:0000:0000:0000:0000:0000/128' - '2001:db8::' - are all functionally the same in IPv6. That is to say, - failing to provide a subnetmask will create an object with - a mask of /128. - - Additionally, an integer can be passed, so - IPv6Network('2001:db8::') == - IPv6Network(42540766411282592856903984951653826560) - or, more generally - IPv6Network(int(IPv6Network('2001:db8::'))) == - IPv6Network('2001:db8::') - - strict: A boolean. If true, ensure that we have been passed - A true network address, eg, 2001:db8::1000/124 and not an - IP address on a network, eg, 2001:db8::1/124. - - Raises: - AddressValueError: If address isn't a valid IPv6 address. - NetmaskValueError: If the netmask isn't valid for - an IPv6 address. - ValueError: If strict was True and a network address was not - supplied. - - """ - _BaseV6.__init__(self, address) - _BaseNetwork.__init__(self, address) - - # Efficient constructor from integer. - if isinstance(address, int): - self.network_address = IPv6Address(address) - self._prefixlen = self._max_prefixlen - self.netmask = IPv6Address(self._ALL_ONES) - return - - # Constructing from a packed address - if isinstance(address, bytes): - self.network_address = IPv6Address(address) - self._prefixlen = self._max_prefixlen - self.netmask = IPv6Address(self._ALL_ONES) - return - - # Assume input argument to be string or any object representation - # which converts into a formatted IP prefix string. - addr = _split_optional_netmask(address) - - self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) - - if len(addr) == 2: - # This may raise NetmaskValueError - self._prefixlen = self._prefix_from_prefix_string(addr[1]) - else: - self._prefixlen = self._max_prefixlen - - self.netmask = IPv6Address(self._ip_int_from_prefix(self._prefixlen)) - if strict: - if (IPv6Address(int(self.network_address) & int(self.netmask)) != - self.network_address): - raise ValueError('%s has host bits set' % self) - self.network_address = IPv6Address(int(self.network_address) & - int(self.netmask)) - - if self._prefixlen == (self._max_prefixlen - 1): - self.hosts = self.__iter__ - - def hosts(self): - """Generate Iterator over usable hosts in a network. - - This is like __iter__ except it doesn't return the - Subnet-Router anycast address. - - """ - network = int(self.network_address) - broadcast = int(self.broadcast_address) - for x in range(network + 1, broadcast + 1): - yield self._address_class(x) - - @property - def is_site_local(self): - """Test if the address is reserved for site-local. - - Note that the site-local address space has been deprecated by RFC 3879. - Use is_private to test if this address is in the space of unique local - addresses as defined by RFC 4193. - - Returns: - A boolean, True if the address is reserved per RFC 3513 2.5.6. - - """ - return (self.network_address.is_site_local and - self.broadcast_address.is_site_local) diff --git a/python/flake8.ini b/python/flake8.ini index 029dcdbdc3..7da1f9608e 100644 --- a/python/flake8.ini +++ b/python/flake8.ini @@ -1,3 +1,2 @@ [flake8] -exclude = external max-line-length = 100 diff --git a/python/lib/crypto/__init__.py b/python/lib/crypto/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/lib/crypto/asymcrypto.py b/python/lib/crypto/asymcrypto.py deleted file mode 100644 index 87c197ae5e..0000000000 --- a/python/lib/crypto/asymcrypto.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2014 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`asymcrypto` --- SCION asymmetric crypto functions -======================================================= -""" -# External -from nacl.exceptions import BadSignatureError -from nacl.public import Box, PrivateKey, PublicKey -from nacl.signing import SigningKey, VerifyKey -from nacl.utils import random - -# SCION -from lib.errors import SCIONVerificationError - - -def sign(msg, signing_key): - """ - Sign a message with a given signing key and return the signature. - - :param bytes msg: message to be signed. - :param bytes signing_key: signing key from generate_signature_keypair(). - :returns: ed25519 signature. - :rtype: bytes - """ - return SigningKey(signing_key).sign(msg)[:64] - - -def verify(msg, sig, verifying_key): - """ - Verify a signature. - - :param bytes msg: message that was signed. - :param bytes sig: signature to verify. - :param bytes verifying_key: verifying key from generate_signature_keypair(). - :returns: True or False whether the verification succeeds or fails. - :rtype: boolean - """ - try: - return msg == VerifyKey(verifying_key).verify(msg, sig) - except BadSignatureError: - raise SCIONVerificationError("Signature corrupt or forged.") from None - - -def encrypt(msg, private_key, public_key): - """ - Encrypt message. - - :param bytes msg: message to be encrypted. - :param bytes private_key: Private Key of encrypter. - :param bytes public_key: Public Key of decrypter. - :returns: The encrypted message. - :rtype: nacl.utils.EncryptedMessage - """ - return Box(PrivateKey(private_key), PublicKey(public_key)).encrypt(msg, random(Box.NONCE_SIZE)) - - -def decrypt(msg, private_key, public_key): - """ - Decrypt ciphertext. - - :param bytes msg: ciphertext to be decrypted. - :param bytes private_key: Private Key of decrypter. - :param bytes public_key: Public Key of encrypter. - :returns: The decrypted message. - :rtype: bytes - """ - return Box(PrivateKey(private_key), PublicKey(public_key)).decrypt(msg) diff --git a/python/lib/crypto/certificate.py b/python/lib/crypto/certificate.py deleted file mode 100644 index 18468065c3..0000000000 --- a/python/lib/crypto/certificate.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2014 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`certificate` --- SCION certificate parser -=============================================== -""" -# Stdlib -import base64 -import json -import time - -# SCION -from lib.crypto.asymcrypto import sign, verify -from lib.errors import SCIONVerificationError - -SUBJECT_STRING = 'Subject' -ISSUER_STRING = 'Issuer' -TRC_VERSION_STRING = "TRCVersion" -VERSION_STRING = 'Version' -COMMENT_STRING = 'Comment' -CAN_ISSUE_STRING = 'CanIssue' -ISSUING_TIME_STRING = 'IssuingTime' -EXPIRATION_TIME_STRING = 'ExpirationTime' -ENC_ALGORITHM_STRING = 'EncAlgorithm' -SUBJECT_ENC_KEY_STRING = 'SubjectEncKey' -SIGN_ALGORITHM_STRING = 'SignAlgorithm' -SUBJECT_SIG_KEY_STRING = 'SubjectSignKey' -SIGNATURE_STRING = 'Signature' - - -class Certificate(object): - """ - The Certificate class parses a certificate of an AS and stores such - information for further use. - - :ivar str subject: the certificate subject. - :ivar str issuer: the certificate issuer. It can only be a core AS. - :ivar int trc_version: the version of the issuing trc. - :ivar int version: the certificate version. - :ivar str comment: is an arbitrary and optional string used by the subject - to describe the certificate - :ivar bool can_issue: describes whether the subject is able to issue - certificates - :ivar int issuing_time: the time at which the certificate was created. - :ivar int expiration_time: the time at which the certificate expires. - :ivar str enc_algorithm: the algorithm used to encrypt messages. - :ivar bytes subject_enc_key: the public key used for decryption. - :ivar str sign_algorithm: the algorithm used to sign the certificate. - :ivar bytes subject_sig_key: the public key used for signing. - :ivar bytes signature: the certificate signature. It is computed over the - rest of the certificate. - :cvar int as_validity_period: - default validity period (in real seconds) of a new regular AS certificate. - :cvar int core_as_validity_period: - default validity period (in real seconds) of a new core AS certificate. - :cvar str sign_algortihm: default algorithm used to sign a certificate. - :cvar str enc_alorithm: default algorithm used to encrypt messages. - """ - SIGN_ALGORTIHM = 'ed25519' - ENC_ALGORITHM = 'curve25519xsalsa20poly1305' - FIELDS_MAP = { - SUBJECT_STRING: ("subject", str), - ISSUER_STRING: ("issuer", str), - TRC_VERSION_STRING: ("trc_version", int), - VERSION_STRING: ("version", int), - COMMENT_STRING: ("comment", str), - CAN_ISSUE_STRING: ("can_issue", bool), - ISSUING_TIME_STRING: ("issuing_time", int), - EXPIRATION_TIME_STRING: ("expiration_time", int), - ENC_ALGORITHM_STRING: ("enc_algorithm", str), - SUBJECT_ENC_KEY_STRING: ("subject_enc_key", bytes), - SIGN_ALGORITHM_STRING: ("sign_algorithm", str), - SUBJECT_SIG_KEY_STRING: ("subject_sig_key", bytes), - SIGNATURE_STRING: ("signature", bytes), - } - - def __init__(self, cert_dict): - """ - :param certificate_file: the name of the certificate file. - :type certificate_file: str - """ - for k, (name, type_) in self.FIELDS_MAP.items(): - val = cert_dict[k] - if type_ in (int,): - val = int(val) - setattr(self, name, val) - self.subject_enc_key_raw = base64.b64decode(self.subject_enc_key) - self.subject_sig_key_raw = base64.b64decode(self.subject_sig_key) - self.signature_raw = base64.b64decode(self.signature) - - def verify(self, subject, verifying_key): - """ - Perform one step verification. - - :param str subject: - the certificate subject. It can either be an AS, an email address or - a domain address. - :param bytes verifying_key: the key to be used for signature verification. - :raises: SCIONVerificationError if the verification fails. - """ - if self.version == 0: - raise SCIONVerificationError("Invalid certificate version 0:\n%s" % self) - if subject != self.subject: - raise SCIONVerificationError( - "The given subject (%s) doesn't match the certificate's subject (%s):\n%s" % - (subject, self.subject, self)) - if int(time.time()) >= self.expiration_time: - raise SCIONVerificationError("This certificate expired:\n%s" % self) - try: - self._verify_signature(self.signature_raw, verifying_key) - except SCIONVerificationError: - raise SCIONVerificationError("Signature verification failed:\n%s" % self) - - def _verify_signature(self, signature, public_key): - """ - Checks if the signature can be verified with the given public key - """ - verify(self._sig_input(), signature, public_key) - - def dict(self, with_signature=True): - """ - Return the certificate information. - - :param bool with_signature: - tells whether the signature must also be included in the returned - data. - :returns: the certificate information. - :rtype: dict - """ - cert_dict = {} - for k, (name, _) in self.FIELDS_MAP.items(): - cert_dict[k] = getattr(self, name) - if not with_signature: - del cert_dict[SIGNATURE_STRING] - return cert_dict - - def sign(self, iss_priv_key): - data = self._sig_input() - self.signature_raw = sign(data, iss_priv_key) - self.signature = base64.b64encode(self.signature_raw).decode('utf-8') - - @classmethod - def from_values(cls, subject, issuer, trc_version, version, comment, can_issue, validity_period, - subject_enc_key, subject_sig_key, iss_priv_key, issuing_time=0): - """ - Generate a Certificate instance. - - :param str subject: - the certificate subject. It can either be an AS, an email address or - a domain address. - :param str issuer: the certificate issuer. It can only be an AS. - :param int trc_version: the version of the issuing certificate/trc. - :param int version: the certificate version. - :param str comment: a comment describing the certificate. - :param bool can_issue: - states whether the subject is allowed to issue certificates for other ASes. - :param int validity_period: the validity period after creation of this certificate. - :param bytes iss_priv_key: - the issuer's signing key. It is used to sign the certificate. - :param bytes subject_sig_key: the public key of the subject. - :param bytes subject_enc_key: the public part of the encryption key. - :param int issuing_time: the certificate issuing time. - In case of 0, the current time is used. - :returns: the newly created Certificate instance. - :rtype: :class:`Certificate` - """ - if not issuing_time: - issuing_time = int(time.time()) - cert_dict = { - SUBJECT_STRING: subject, - ISSUER_STRING: issuer, - TRC_VERSION_STRING: trc_version, - VERSION_STRING: version, - COMMENT_STRING: comment, - CAN_ISSUE_STRING: can_issue, - ISSUING_TIME_STRING: issuing_time, - EXPIRATION_TIME_STRING: issuing_time + validity_period, - ENC_ALGORITHM_STRING: cls.ENC_ALGORITHM, - SUBJECT_ENC_KEY_STRING: - base64.b64encode(subject_enc_key).decode("utf-8"), - SIGN_ALGORITHM_STRING: cls.SIGN_ALGORTIHM, - SUBJECT_SIG_KEY_STRING: - base64.b64encode(subject_sig_key).decode("utf-8"), - SIGNATURE_STRING: "", - } - cert = Certificate(cert_dict) - cert.sign(iss_priv_key) - return cert - - def _sig_input(self): - d = self.dict(False) - j = json.dumps(d, sort_keys=True, separators=(',', ':'), ensure_ascii=False) - return j.encode('utf-8') - - def to_json(self, indent=4): - return json.dumps(self.dict(), sort_keys=True, indent=indent) - - def __str__(self): - return self.to_json(None) - - def __eq__(self, other): # pragma: no cover - return str(self) == str(other) diff --git a/python/lib/crypto/certificate_chain.py b/python/lib/crypto/certificate_chain.py deleted file mode 100644 index a08d7236e9..0000000000 --- a/python/lib/crypto/certificate_chain.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2014 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`certificate_chain` --- SCION certificate_chain parser -=========================================================== -""" -# Stdlib -import json - -# External -import lz4 - -# SCION -from lib.crypto.certificate import Certificate -from lib.crypto.trc import ( - ONLINE_KEY_STRING, -) -from lib.errors import SCIONVerificationError, SCIONParseError -from lib.packet.scion_addr import ISD_AS -from lib.util import iso_timestamp - - -class CertificateChain(object): - """ - The CertificateChain class contains an ordered sequence of certificates, in - which: the first certificate is the one at the end of a certificate chain - and the last is the certificate signed by the core ISD. Therefore, starting - from the first one, each certificate should be verified by the next one in - the sequence. - - :ivar list certs: (ordered) certificates forming the chain. - """ - - def __init__(self, cert_list): - """ - :param list(Certificate) cert_list: certificate chain as list. - """ - if len(cert_list) != 2: - raise SCIONParseError("Certificate chains must have length 2.") - self.as_cert = cert_list[0] - self.core_as_cert = cert_list[1] - - @classmethod - def from_raw(cls, chain_raw, lz4_=False): - if lz4_: - chain_raw = lz4.loads(chain_raw).decode("utf-8") - chain = json.loads(chain_raw) - certs = [] - for k in sorted(chain): - cert = Certificate(chain[k]) - certs.append(cert) - return CertificateChain(certs) - - def verify(self, subject, trc): - """ - Perform the entire chain verification. First verifies the AS certificate against the core AS - certificate, then verifies the core AS certificate against the TRC. - - :param str subject: - the subject of the first certificate in the certificate chain. - :param trc: TRC containing all root of trust certificates for one ISD. - :type trc: :class:`TRC` - :raises: SCIONVerificationError if the verification fails. - """ - # Verify AS certificate against core AS certificate - leaf = self.as_cert - core = self.core_as_cert - if leaf.issuing_time < core.issuing_time: - raise SCIONVerificationError( - "AS certificate verification failed: Leaf issued before core certificate. Leaf: %s " - "Core: %s" % (iso_timestamp(leaf.issuing_time), iso_timestamp(core.issuing_time))) - if leaf.expiration_time > core.expiration_time: - raise SCIONVerificationError( - "AS certificate verification failed: Leaf expires after core certificate. Leaf: %s " - "Core: %s" % (iso_timestamp(leaf.expiration_time), - iso_timestamp(core.expiration_time))) - if not core.can_issue: - raise SCIONVerificationError( - "AS certificate verification failed: Core certificate cannot issue certificates") - try: - leaf.verify(subject, core.subject_sig_key_raw) - except SCIONVerificationError as e: - raise SCIONVerificationError("AS certificate verification failed: %s" % e) - # Verify core AS certificate against TRC - if core.expiration_time > trc.exp_time: - raise SCIONVerificationError( - "Core AS certificate verification failed: Core certificate expires after TRC. " - "Core: %s TRC: %s" % (iso_timestamp(core.expiration_time), - iso_timestamp(trc.exp_time))) - try: - core.verify(leaf.issuer, trc.core_ases[core.issuer][ONLINE_KEY_STRING]) - except SCIONVerificationError as e: - raise SCIONVerificationError("Core AS certificate verification failed: %s" % e) - - def get_leaf_isd_as_ver(self): - isd_as = ISD_AS(self.as_cert.subject) - return isd_as, self.as_cert.version - - def to_json(self): - """ - Convert the instance to json format - - :returns: the CertificateChain information. - :rtype: str - """ - chain_dict = {} - index = 0 - for cert in (self.as_cert, self.core_as_cert): - chain_dict[index] = cert.dict(True) - index += 1 - chain_str = json.dumps(chain_dict, indent=4) - return chain_str - - def pack(self, lz4_=False): - ret = self.to_json().encode('utf-8') - if lz4_: - return lz4.dumps(ret) - return ret - - def __len__(self): - return len(self.pack()) - - def __str__(self): - return self.to_json() - - def __eq__(self, other): # pragma: no cover - return str(self) == str(other) diff --git a/python/lib/crypto/symcrypto.py b/python/lib/crypto/symcrypto.py deleted file mode 100644 index d65320efa6..0000000000 --- a/python/lib/crypto/symcrypto.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2014 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`symcrypto` --- SCION symmetric crypto functions -===================================================== -""" -# Stdlib -import hashlib - -# External packages -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.ciphers.algorithms import AES -from cryptography.hazmat.primitives.cmac import CMAC - - -def mac(key, msg): - """ - Default MAC function (CMAC using AES-128). - - Args: - key: key for MAC creation. - msg: Plaintext to be MACed, as a bytes object. - - Returns: - MAC output, as a bytes object. - - Raises: - ValueError: An error occurred when key is NULL or ciphertext is NULL. - """ - if key is None: - raise ValueError('Key is NULL.') - elif msg is None: - raise ValueError('Message is NULL.') - else: - cobj = CMAC(AES(key), backend=default_backend()) - cobj.update(msg) - return cobj.finalize() - - -def kdf(secret, phrase): - """ - Default key derivation function. - """ - return hashlib.pbkdf2_hmac('sha256', secret, phrase, 1000)[:16] - - -def sha256(data): - """ - Default hash function. - """ - digest = hashlib.sha256() - digest.update(data) - return digest.digest() - - -# Default hash function -crypto_hash = sha256 diff --git a/python/lib/crypto/trc.py b/python/lib/crypto/trc.py deleted file mode 100644 index c481d141f1..0000000000 --- a/python/lib/crypto/trc.py +++ /dev/null @@ -1,371 +0,0 @@ -# Copyright 2014 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`trc` --- SCION TRC parser -=============================================== -""" -# Stdlib -import base64 -import copy -import json -import time - -# External -import lz4 - -# SCION -from lib.crypto.asymcrypto import verify, sign -from lib.errors import SCIONParseError, SCIONVerificationError -from lib.packet.scion_addr import ISD_AS - -ISD_STRING = 'ISD' -DESCRIPTION_STRING = 'Description' -VERSION_STRING = 'Version' -CREATION_TIME_STRING = 'CreationTime' -EXPIRATION_TIME_STRING = 'ExpirationTime' -CORE_ASES_STRING = 'CoreASes' -ROOT_CAS_STRING = 'RootCAs' -CERT_LOGS_STRING = 'CertLogs' -THRESHOLD_EEPKI_STRING = 'ThresholdEEPKI' -RAINS_STRING = 'RAINS' -QUORUM_TRC_STRING = 'QuorumTRC' -QUORUM_CAS_STRING = 'QuorumCAs' -GRACE_PERIOD_STRING = 'GracePeriod' -QUARANTINE_STRING = 'Quarantine' -SIGNATURES_STRING = 'Signatures' - -ARPKI_KEY_STRING = 'ARPKIKey' -ARPKI_SRV_STRING = 'ARPKISrv' -CERTIFICATE_STRING = 'Certificate' -OFFLINE_KEY_ALG_STRING = 'OfflineKeyAlg' -OFFLINE_KEY_STRING = 'OfflineKey' -ONLINE_KEY_ALG_STRING = 'OnlineKeyAlg' -ONLINE_KEY_STRING = 'OnlineKey' -ROOT_RAINS_KEY_STRING = 'RootRAINSKey' -TRC_SRV_STRING = 'TRCSrv' - - -class TRC(object): - """ - The TRC class parses the TRC file of an ISD and stores such - information for further use. - - :ivar int isd: the ISD identifier. - :ivar str description: is a human readable description of an ISD. - :ivar int version: the TRC file version. - :ivar int create_time: the TRC file creation timestamp. - :ivar int exp_time: the TRC expiration timestamp. - :ivar dict core_ases: the set of core ASes and their certificates. - :ivar dict root_cas: the set of root CAs and their certificates. - :ivar dict cert_logs: is a dictionary of end entity certificate log servers of - form {name: {"isd_as IP": pub_key}} - :ivar int threshold_eepki: is a threshold number (nonnegative integer) of - CAs that have to sign a domain’s policy - :ivar dict rains: the RAINS section. - :ivar int quorum_trc: number of core ASes necessary to sign a new TRC. - :ivar int quorum_cas: number of CAs necessary to change CA entries - :ivar int grace_period: defines for how long this TRC is valid when a new - TRC is available - :ivar bool quarantine: flag defining whether TRC is valid(quarantine=false) - or an early announcement(quarantine=true) - :ivar dict signatures: signatures generated by a quorum of trust roots. - """ - - FIELDS_MAP = { - ISD_STRING: ("isd", int), - DESCRIPTION_STRING: ("description", str), - VERSION_STRING: ("version", int), - CREATION_TIME_STRING: ("create_time", int), - EXPIRATION_TIME_STRING: ("exp_time", int), - CORE_ASES_STRING: ("core_ases", dict), - ROOT_CAS_STRING: ("root_cas", dict), - CERT_LOGS_STRING: ("cert_logs", dict), - THRESHOLD_EEPKI_STRING: ("threshold_eepki", int), - RAINS_STRING: ("rains", dict), - QUORUM_TRC_STRING: ("quorum_trc", int), - QUORUM_CAS_STRING: ("quorum_cas", int), - QUARANTINE_STRING: ("quarantine", bool), - SIGNATURES_STRING: ("signatures", dict), - GRACE_PERIOD_STRING: ("grace_period", int), - } - - # list of fields in a dict of dicts which have to be encoded/decoded from base64 - MULTI_DICT_DECODE_FIELDS = { - CORE_ASES_STRING: [ONLINE_KEY_STRING, OFFLINE_KEY_STRING], - ROOT_CAS_STRING: [CERTIFICATE_STRING, ONLINE_KEY_STRING, ARPKI_KEY_STRING], - } - - # list of fields in a dict which have to be encoded/decoded - SIMPLE_DICT_DECODE_FIELDS = { - RAINS_STRING: [ROOT_RAINS_KEY_STRING, ONLINE_KEY_STRING], - SIGNATURES_STRING: [], - } - - def __init__(self, trc_dict): - """ - :param dict trc_dict: TRC as dict. - """ - for k, (name, type_) in self.FIELDS_MAP.items(): - val = trc_dict[k] - if type_ in (int,): - val = int(val) - elif type_ in (dict, ): - val = copy.deepcopy(val) - setattr(self, name, val) - - for attr, decode_list in self.MULTI_DICT_DECODE_FIELDS.items(): - field = getattr(self, self.FIELDS_MAP[attr][0]) - for entry in field.values(): - for key in decode_list: - entry[key] = base64.b64decode(entry[key].encode('utf-8')) - - for attr, decode_list in self.SIMPLE_DICT_DECODE_FIELDS.items(): - entry = getattr(self, self.FIELDS_MAP[attr][0]) - if not entry: - continue - for key in decode_list or entry: - entry[key] = base64.b64decode(entry[key].encode('utf-8')) - - for subject, entry in trc_dict[CERT_LOGS_STRING].items(): - try: - addr, pub_key = next(iter(entry.items())) - self.cert_logs[subject][addr] = base64.b64decode(pub_key.encode('utf-8')) - except StopIteration: - raise SCIONParseError("Invalid CertLogs entry for %s: %s", subject, entry) - - def get_isd_ver(self): - return self.isd, self.version - - def get_core_ases(self): - res = [] - for key in self.core_ases: - res.append(ISD_AS(key)) - return res - - def dict(self, with_signatures): - """ - Return the TRC information. - - :param bool with_signatures: - If True, include signatures in the return value. - :returns: the TRC information. - :rtype: dict - """ - trc_dict = {} - for k, (name, _) in self.FIELDS_MAP.items(): - trc_dict[k] = getattr(self, name) - if not with_signatures: - del trc_dict[SIGNATURES_STRING] - return trc_dict - - @classmethod - def from_raw(cls, trc_raw, lz4_=False): - if lz4_: - trc_raw = lz4.loads(trc_raw).decode("utf-8") - trc = json.loads(trc_raw) - return TRC(trc) - - @classmethod - def from_values(cls, isd, description, version, core_ases, root_cas, - cert_logs, threshold_eepki, rains, quorum_trc, - quorum_cas, grace_period, quarantine, signatures, validity_period): - """ - Generate a TRC instance. - """ - now = int(time.time()) - trc_dict = { - ISD_STRING: isd, - DESCRIPTION_STRING: description, - VERSION_STRING: version, - CREATION_TIME_STRING: now, - EXPIRATION_TIME_STRING: now + validity_period, - CORE_ASES_STRING: core_ases, - ROOT_CAS_STRING: root_cas, - CERT_LOGS_STRING: cert_logs, - THRESHOLD_EEPKI_STRING: threshold_eepki, - RAINS_STRING: rains, - QUORUM_TRC_STRING: quorum_trc, - QUORUM_CAS_STRING: quorum_cas, - GRACE_PERIOD_STRING: grace_period, - QUARANTINE_STRING: quarantine, - SIGNATURES_STRING: signatures, - } - trc = TRC(trc_dict) - return trc - - def sign(self, isd_as, sig_priv_key): - """ - Sign TRC and add computed signature to the TRC. - - :param ISD_AS isd_as: the ISD-AS of signer. - :param SigningKey sig_priv_key: the signing key of signer. - """ - data = self._sig_input() - self.signatures[str(isd_as)] = sign(data, sig_priv_key) - - def _sig_input(self): - d = self.dict(False) - for k in d: - if self.FIELDS_MAP[k][1] == dict: - d[k] = self._encode_dict(d[k]) - j = json.dumps(d, sort_keys=True, separators=(',', ':')) - return j.encode('utf-8') - - def _encode_dict(self, dict_): - encoded_dict = {} - for key, val in dict_.items(): - if type(val) is dict: - val = self._encode_sub_dict(val) - elif type(val) is bytes: - val = base64.b64encode(val).decode('utf-8') - encoded_dict[key] = val - return encoded_dict - - def _encode_sub_dict(self, dict_): - encoded_dict = {} - for key, val in dict_.items(): - if type(val) is bytes: - val = base64.b64encode(val).decode('utf-8') - encoded_dict[key] = val - return encoded_dict - - def to_json(self, with_signatures=True): - """ - Convert the instance to json format. - """ - trc_dict = copy.deepcopy(self.dict(with_signatures)) - for field, decode_list in self.MULTI_DICT_DECODE_FIELDS.items(): - for entry in trc_dict[field].values(): - for key in decode_list: - entry[key] = base64.b64encode(entry[key]).decode('utf-8') - for field, decode_list in self.SIMPLE_DICT_DECODE_FIELDS.items(): - entry = trc_dict.get(field, None) - if not entry or (field == SIGNATURES_STRING and not with_signatures): - continue - # Every value is decoded, if decode_list is empty - for key in decode_list or entry: - entry[key] = base64.b64encode(entry[key]).decode('utf-8') - cert_logs = {} - for subject, entry in trc_dict[CERT_LOGS_STRING].items(): - try: - addr = next(iter(entry.keys())) - entry[addr] = base64.b64encode(entry[addr]).decode('utf-8') - cert_logs[subject] = entry - except StopIteration: - pass - trc_dict[CERT_LOGS_STRING] = cert_logs - trc_str = json.dumps(trc_dict, sort_keys=True, indent=4) - return trc_str - - def pack(self, lz4_=False): - ret = self.to_json().encode('utf-8') - if lz4_: - return lz4.dumps(ret) - return ret - - def __str__(self): - return self.to_json() - - def __eq__(self, other): # pragma: no cover - return str(self) == str(other) - - def check_active(self, max_trc=None): - """ - Check if trusted TRC is active and can be used for certificate chain verification. - - :param TRC max_trc: newest available TRC for same ISD. (If none, self is newest TRC) - :raises: SCIONVerificationError - """ - if self.quarantine: - raise SCIONVerificationError("Early announcement") - now = int(time.time()) - if not (self.create_time <= now <= self.exp_time): - raise SCIONVerificationError("Current time outside of validity period. " - "Now %s Creation %s Expiration %s" % - (now, self.create_time, self.exp_time)) - if not max_trc or self.version == max_trc.version: - return - if self.version + 1 != max_trc.version: - raise SCIONVerificationError("Inactive TRC version: %s. Expected %s or %s" % ( - self.version, max_trc.version, max_trc.version - 1)) - if now > max_trc.create_time + max_trc.grace_period: - raise SCIONVerificationError("TRC grace period has passed. Now %s Expiration %s" % ( - now, max_trc.create_time + max_trc.grace_period)) - - def verify(self, trusted_trc): - """ - Verify TRC based on a trusted TRC. - - :param TRC trusted_trc: a verified TRC, used as a trust anchor. - :raises: SCIONVerificationError - """ - if self.version == 0: - raise SCIONVerificationError("Invalid TRC version 0") - if self.isd == trusted_trc.isd: - self.verify_update(trusted_trc) - else: - self.verify_xsig(trusted_trc) - - def verify_update(self, old_trc): - """ - Verify TRC update. - Unsuccessful verification raises an error. - - :param TRC old_trc: a verified TRC, used as a trust anchor. - :raises: SCIONVerificationError - """ - if old_trc.isd != self.isd: - raise SCIONVerificationError("Invalid TRC ISD %s. Expected %s" % ( - self.isd, old_trc.isd)) - if old_trc.version + 1 != self.version: - raise SCIONVerificationError("Invalid TRC version %s. Expected %s" % ( - self.version, old_trc.version + 1)) - if self.create_time < old_trc.create_time + old_trc.grace_period: - raise SCIONVerificationError("Invalid timestamp %s. Expected > %s " % ( - self.create_time, old_trc.create_time + old_trc.grace_period)) - if self.quarantine or old_trc.quarantine: - raise SCIONVerificationError("Early announcement") - self._verify_signatures(old_trc) - - def verify_xsig(self, neigh_trc): - """ - Verify cross signatures. - - :param TRC neigh_trc: neighbour TRC, used as a trust anchor. - :raises: SCIONVerificationError - """ - pass - - def _verify_signatures(self, old_trc): - """ - Perform signature verification for core signatures as defined - in old TRC. Raises an error if verification is unsuccessful. - - :param: TRC old_trc: the previous TRC which has already been verified. - :raises: SCIONVerificationError - """ - # Only look at signatures which are from core ASes as defined in old TRC - val_count = 0 - # Count number of verifiable signatures - for signer in old_trc.core_ases.keys(): - public_key = old_trc.core_ases[signer][ONLINE_KEY_STRING] - try: - verify(self._sig_input(), self.signatures[signer], public_key) - val_count += 1 - except (SCIONVerificationError, KeyError): - continue - # Check if enough valid signatures - if val_count < old_trc.quorum_trc: - raise SCIONVerificationError("Not enough valid signatures %s. Expected %s" % ( - val_count, old_trc.quorum_trc)) diff --git a/python/lib/crypto/util.py b/python/lib/crypto/util.py deleted file mode 100644 index d4a2894fe4..0000000000 --- a/python/lib/crypto/util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2017 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`util` --- SCION crypto utilities -=============================== - -Various utilities for SCION functionality. -""" -# Stdlib -import os - - -def get_ca_private_key_file_path(conf_dir, name): - """ - Return the ca private key file path - """ - return os.path.join(conf_dir, "%s.key" % name) - - -def get_ca_cert_file_path(conf_dir, name): - """ - Return the ca certificate file path - """ - return os.path.join(conf_dir, "%s.cert" % name) diff --git a/python/lib/defines.py b/python/lib/defines.py index e9d1908073..bef592788c 100644 --- a/python/lib/defines.py +++ b/python/lib/defines.py @@ -17,20 +17,6 @@ Contains constant definitions used throughout the codebase. """ -#: SCION protocol version -SCION_PROTO_VERSION = 0 - -#: Default TTL of a PathSegment in realtime seconds. -DEFAULT_SEGMENT_TTL = 6 * 60 * 60 -#: Max TTL of a PathSegment in realtime seconds. -MAX_SEGMENT_TTL = 24 * 60 * 60 -#: Time unit for HOF expiration. -EXP_TIME_UNIT = int(MAX_SEGMENT_TTL / 2 ** 8) -#: Max number of supported HopByHop extensions (does not include SCMP) -MAX_HOPBYHOP_EXT = 3 -#: Number of bytes per 'line'. Used for padding in many places. -LINE_LEN = 8 - #: Generated files directory GEN_PATH = 'gen' #: Topology configuration @@ -44,32 +30,13 @@ #: Prometheus config PROM_FILE = "prometheus.yml" -#: Buffer size for receiving packets -SCION_BUFLEN = 65535 -#: Default SCION endhost data port -SCION_UDP_EH_DATA_PORT = 30041 #: Default SCION router UDP port. SCION_ROUTER_PORT = 50000 -#: Default SCION dispatcher UNIX socket directory -DISPATCHER_DIR = "/run/shm/dispatcher" -#: Default SCION dispatcher ID -DEFAULT_DISPATCHER_ID = "default" - -#: Dispatcher registration timeout -DISPATCHER_TIMEOUT = 60.0 #: Default MTU - assumes overlay is ipv4+udp DEFAULT_MTU = 1500 - 20 - 8 #: IPv6 min value SCION_MIN_MTU = 1280 -#: Length of opaque fields -OPAQUE_FIELD_LEN = 8 - -PATH_FLAG_CACHEONLY = "CACHE_ONLY" - -# Minimum revocation TTL in seconds -MIN_REVOCATION_TTL = 10 -REVOCATION_GRACE = 1 # Default IPv6 network, our equivalent to 127.0.0.0/8 # https://en.wikipedia.org/wiki/Unique_local_address#Definition diff --git a/python/lib/errors.py b/python/lib/errors.py index a9766c320a..e896f0716c 100644 --- a/python/lib/errors.py +++ b/python/lib/errors.py @@ -37,25 +37,9 @@ class SCIONIOError(SCIONBaseError): """IO error""" -class SCIONIndexError(SCIONBaseError): - """Index error (accessing out of bound index on array)""" - - -class SCIONKeyError(SCIONBaseError): - """Key error (trying to access invalid entry in dictionary)""" - - class SCIONYAMLError(SCIONBaseError): """YAML parsing error""" class SCIONParseError(SCIONBaseError): """Parsing error""" - - -class SCIONTypeError(SCIONBaseError): - """Wrong type""" - - -class SCIONVerificationError(SCIONBaseError): - """MAC/Signature verification error""" diff --git a/python/lib/log.py b/python/lib/log.py deleted file mode 100644 index 6c5543103d..0000000000 --- a/python/lib/log.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2015 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`log` --- Logging utilites -=============================== -""" -# Stdlib -import logging -import logging.handlers -import traceback -from datetime import datetime, timezone - -# This file should not include other SCION libraries, to prevent circular import -# errors. - -#: Bytes -LOG_MAX_SIZE = 10 * 1024 * 1024 -LOG_BACKUP_COUNT = 1 - -# Logging handlers that will log logging exceptions, and then re-raise them. The -# default behaviour of python's logging handlers is to catch logging exceptions, -# which hides the problem. -# -# We don't try to use the normal logging system at this point because we don't -# know if that's working at all. If it is (e.g. when the exception is a -# formatting error), when we re-raise the exception, it'll get handled by the -# normal process. - -_dispatch_formatter = None - - -def _handleError(self, _): - self.stream.write("Exception in logging module:\n") - for line in traceback.format_exc().split("\n"): - self.stream.write(line+"\n") - self.flush() - raise - - -class _RotatingErrorHandler(logging.handlers.RotatingFileHandler): - handleError = _handleError - - -class _ConsoleErrorHandler(logging.StreamHandler): - handleError = _handleError - - -class Rfc3339Formatter(logging.Formatter): - def format(self, record): # pragma: no cover - lines = super().format(record).splitlines() - return "\n> ".join(lines) - - def formatTime(self, record, _): # pragma: no cover - # Not using lib.util.iso_timestamp here, to avoid potential import - # loops. - # Also, using str on a datetime object inserts a ":" into the time zone, - # which, while legal, is inconsistent with logging in Go. Fortunately, - # Python's strftime does the right thing. - return datetime.fromtimestamp( - record.created, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f%z") - - -class DispatchFormatter: # pragma: no cover - """ - A dispatching formatter that allows modules to install custom formatters for - their child loggers. - """ - def __init__(self, default_formatter, formatters=None): - self._default_formatter = default_formatter - self._formatters = formatters or {} - - def add_formatter(self, key, formatter): - self._formatters[key] = formatter - - def format(self, record): - formatter = self._formatters.get(record.name, self._default_formatter) - return formatter.format(record) - - -def add_formatter(name, formatter): # pragma: no cover - _dispatch_formatter.add_formatter(name, formatter) - - -def init_logging(log_base=None, file_level=logging.DEBUG, - console_level=logging.NOTSET): - """ - Configure logging for components (servers, routers, gateways). - """ - default_formatter = Rfc3339Formatter( - "%(asctime)s [%(levelname)s] (%(threadName)s) %(message)s") - global _dispatch_formatter - _dispatch_formatter = DispatchFormatter(default_formatter) - handlers = [] - if log_base: - for lvl in sorted(logging._levelToName): - if lvl < file_level: - continue - log_file = "%s.%s" % (log_base, logging._levelToName[lvl]) - h = _RotatingErrorHandler( - log_file, maxBytes=LOG_MAX_SIZE, backupCount=LOG_BACKUP_COUNT, - encoding="utf-8") - h.setLevel(lvl) - handlers.append(h) - if console_level: - h = _ConsoleErrorHandler() - h.setLevel(console_level) - handlers.append(h) - for h in handlers: - h.setFormatter(_dispatch_formatter) - # Use logging.DEBUG here, so that the handlers themselves can decide what to - # filter. - logging.basicConfig(level=logging.DEBUG, handlers=handlers) - - -def log_exception(msg, *args, level=logging.CRITICAL, **kwargs): - """ - Properly format an exception before logging. - """ - logging.log(level, msg, *args, **kwargs) - for line in traceback.format_exc().split("\n"): - logging.log(level, line) - - -def log_stack(level=logging.DEBUG): - logging.log(level, "".join(traceback.format_stack())) diff --git a/python/lib/packet/__init__.py b/python/lib/packet/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/lib/packet/host_addr.py b/python/lib/packet/host_addr.py deleted file mode 100644 index 5336cac4fd..0000000000 --- a/python/lib/packet/host_addr.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2015 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`host_addr` --- L2 host address library -============================================ -""" - -# Stdlib -import struct - -# External -from external.ipaddress import ( - AddressValueError, - IPV4LENGTH, - IPV6LENGTH, - IPv4Interface, - IPv6Interface, -) - -# SCION -from lib.errors import SCIONBaseError, SCIONParseError -from lib.packet.packet_base import Serializable -from lib.types import AddrType -from lib.util import Raw - - -class HostAddrInvalidType(SCIONBaseError): - """ - HostAddr type is invalid. - """ - pass - - -class HostAddrBase(Serializable): - """ - Base HostAddr class. Should not be used directly. - """ - TYPE = None - LEN = None - - def __init__(self, addr, raw=True): # pragma: no cover - """ - :param addr: Address to parse/store. - :param bool raw: Does the address need to be parsed? - """ - self.addr = None - if raw: - self._parse(addr) - else: - self.addr = addr - - def from_values(self, *args, **kwargs): - raise NotImplementedError - - @classmethod - def name(cls): - return AddrType.to_str(cls.TYPE) - - def __str__(self): # pragma: no cover - return str(self.addr) - - def __len__(self): # pragma: no cover - return self.LEN - - def __eq__(self, other): # pragma: no cover - if other is None: - return False - return (self.TYPE == other.TYPE) and (self.addr == other.addr) - - def __lt__(self, other): # pragma: no cover - return str(self) < str(other) - - def __hash__(self): - return hash(self.pack()) - - -class HostAddrNone(HostAddrBase): # pragma: no cover - """ - Host "None" address. Used to indicate there's no address. - """ - TYPE = AddrType.NONE - LEN = 0 - - def __init__(self): - self.addr = None - - def _parse(self, raw): - raise NotImplementedError - - def pack(self): - return b"" - - -class HostAddrIPv4(HostAddrBase): - """ - Host IPv4 address. - """ - TYPE = AddrType.IPV4 - LEN = IPV4LENGTH // 8 - - def _parse(self, raw): - """ - Parse IPv4 address - - :param raw: Can be either `bytes` or `str` - """ - try: - intf = IPv4Interface(raw) - except AddressValueError as e: - raise SCIONParseError("Unable to parse %s address: %s" % - (self.name(), e)) from None - self.addr = intf.ip - - def pack(self): # pragma: no cover - return self.addr.packed - - -class HostAddrIPv6(HostAddrBase): - """ - Host IPv6 address. - """ - TYPE = AddrType.IPV6 - LEN = IPV6LENGTH // 8 - - def _parse(self, raw): - """ - Parse IPv6 address - - :param raw: Can be either `bytes` or `str` - """ - try: - intf = IPv6Interface(raw) - except AddressValueError as e: - raise SCIONParseError("Unable to parse %s address: %s" % - (self.name(), e)) from None - self.addr = intf.ip - - def pack(self): # pragma: no cover - return self.addr.packed - - -class HostAddrSVC(HostAddrBase): - """ - Host "SVC" address. This is a pseudo- address type used for SCION services. - """ - TYPE = AddrType.SVC - LEN = 2 - NAME = "HostAddrSVC" - MCAST = 0x8000 - - def _parse(self, raw): - data = Raw(raw, self.NAME, self.LEN) - self.addr = struct.unpack("!H", data.pop(self.LEN))[0] - - def pack(self): # pragma: no cover - return struct.pack("!H", self.addr) - - def is_mcast(self): # pragma: no cover - return self.addr & self.MCAST - - def multicast(self): - return HostAddrSVC(self.addr | self.MCAST, raw=False) - - def anycast(self): - return HostAddrSVC(self.addr & ~self.MCAST, raw=False) - - def __str__(self): - s = "0x%02x" % (self.addr & ~self.MCAST) - if self.is_mcast(): - return s + " M" - return s + " A" - - -_map = { - # By type - AddrType.NONE: HostAddrNone, - AddrType.IPV4: HostAddrIPv4, - AddrType.IPV6: HostAddrIPv6, - AddrType.SVC: HostAddrSVC, - # By name - "NONE": HostAddrNone, - "IPV4": HostAddrIPv4, - "IPV6": HostAddrIPv6, - "SVC": HostAddrSVC, -} - - -def haddr_get_type(type_): # pragma: no cover - r""" - Look up host address class by type. - - :param type\_: host address type. E.g. ``1`` or ``"IPV4"``. - :type type\_: int or string - """ - try: - return _map[type_] - except KeyError: - raise HostAddrInvalidType("Unknown host addr type '%s'" % - type_) from None - - -def haddr_parse(type_, *args, **kwargs): # pragma: no cover - r""" - Parse host address and return object. - - :param type\_: host address type. E.g. ``1`` or ``"IPV4"``. - :type type\_: int or string - :param \*args: - Arguments to pass to the host address object constructor. E.g. - ``"127.0.0.1"``. - :param \*\*kwargs: - Keyword args to pass to the host address object constructor. E.g. - ``raw=False``. - """ - typecls = haddr_get_type(type_) - return typecls(*args, **kwargs) - - -def haddr_parse_interface(intf): - """ - Try to parse a string as either an ipv6 or ipv4 interface - - :param str interface: E.g. ``127.0.0.1/8``. - """ - for type_ in AddrType.IPV6, AddrType.IPV4: - try: - return haddr_parse(type_, intf) - except SCIONParseError: - pass - else: - raise SCIONParseError("Unable to parse interface '%s'" % intf) diff --git a/python/lib/packet/packet_base.py b/python/lib/packet/packet_base.py deleted file mode 100644 index b2a3ebb20f..0000000000 --- a/python/lib/packet/packet_base.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2014 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`packet_base` --- Packet base class -======================================== -""" -# Stdlib -from abc import ABCMeta, abstractmethod - - -class Serializable(object, metaclass=ABCMeta): # pragma: no cover - """ - Base class for all objects which serialize into raw bytes. - """ - def __init__(self, raw=None): - if raw: - self._parse(raw) - - @abstractmethod - def _parse(self, raw): - raise NotImplementedError - - @abstractmethod - def from_values(self, *args, **kwargs): - raise NotImplementedError - - @abstractmethod - def pack(self): - raise NotImplementedError - - @abstractmethod - def __len__(self): - raise NotImplementedError - - @abstractmethod - def __str__(self): - raise NotImplementedError diff --git a/python/lib/packet/scion_addr.py b/python/lib/scion_addr.py similarity index 56% rename from python/lib/packet/scion_addr.py rename to python/lib/scion_addr.py index f9d41050be..9e4b0fa3c9 100644 --- a/python/lib/packet/scion_addr.py +++ b/python/lib/scion_addr.py @@ -15,20 +15,12 @@ :mod:`scion_addr` --- SCION host address specifications ======================================================= """ -# Stdlib -import struct # SCION -from lib.errors import SCIONIndexError, SCIONParseError -from lib.packet.packet_base import Serializable -from lib.packet.host_addr import ( - HostAddrBase, - haddr_get_type, -) -from lib.util import Raw +from lib.errors import SCIONParseError -class ISD_AS(Serializable): +class ISD_AS: """ Class for representing ISD-AS pair. The underlying type is a 64-bit unsigned int; ISD is represented by the top 16 bits (though the top 4 bits are currently reserved), and AS by the @@ -36,8 +28,6 @@ class ISD_AS(Serializable): See formatting and allocations here: https://github.com/scionproto/scion/wiki/ISD-and-AS-numbering """ - NAME = "ISD_AS" - LEN = 8 ISD_BITS = 16 MAX_ISD = (1 << ISD_BITS) - 1 AS_BITS = 48 @@ -52,25 +42,10 @@ class ISD_AS(Serializable): def __init__(self, raw=None): self._isd = 0 self._as = 0 - super().__init__(raw) + if raw: + self._parse(raw) - def _parse(self, raw): # pragma: no cover - if isinstance(raw, bytes): - self._parse_bytes(raw) - elif isinstance(raw, int): - self._parse_int(raw) - else: - self._parse_str(raw) - - def _parse_bytes(self, raw): - """ - :param bytes raw: a byte string containing a 64-bit unsigned integer. - """ - data = Raw(raw, self.NAME, self.LEN) - isd_as = struct.unpack("!Q", data.pop())[0] - self._parse_int(isd_as) - - def _parse_str(self, raw): + def _parse(self, raw): """ :param str raw: a string of the format "isd-as". """ @@ -129,16 +104,6 @@ def _parse_int(self, raw): self._isd = raw >> self.AS_BITS self._as = raw & self.MAX_AS - @classmethod - def from_values(cls, isd, as_): # pragma: no cover - inst = cls() - inst._isd = isd - inst._as = as_ - return inst - - def pack(self): - return struct.pack("!Q", self.int()) - def int(self): isd_as = self._isd << self.AS_BITS isd_as |= self._as & self.MAX_AS @@ -150,32 +115,9 @@ def any_as(self): # pragma: no cover def is_zero(self): # pragma: no cover return self._isd == 0 and self._as == 0 - def params(self, name="first"): # pragma: no cover - """Provides parameters for querying PathSegmentDB""" - if self._as == 0: - return {"%s_isd" % name: self._isd} - else: - return {"%s_ia" % name: self} - def __eq__(self, other): # pragma: no cover return self._isd == other._isd and self._as == other._as - def __getitem__(self, idx): # pragma: no cover - if idx == 0: - return self._isd - elif idx == 1: - return self._as - else: - raise SCIONIndexError("Invalid index used on %s object: %s" % ( - (self.NAME, idx))) - - def __int__(self): # pragma: no cover - return self.int() - - def __iter__(self): # pragma: no cover - yield self._isd - yield self._as - def isd_str(self): s = str(self._isd) if self._isd > self.MAX_ISD: @@ -212,77 +154,3 @@ def __len__(self): # pragma: no cover def __hash__(self): # pragma: no cover return hash(str(self)) - - -class SCIONAddr(object): - """ - Class for complete SCION addresses. - - :ivar ISD_AS isd_as: ISD-AS identifier. - :ivar HostAddrBase host: host address. - :ivar int addr_len: address length. - """ - def __init__(self, addr_info=()): # pragma: no cover - """ - Initialize an instance of the class SCIONAddr. - - :param addr_info: Tuple of (addr_type, addr) for the host address - """ - self.isd_as = None - self.host = None - if addr_info: - self._parse(*addr_info) - - def _parse(self, addr_type, raw): - """ - Parse a raw byte string. - - :param int addr_type: Host address type - :param bytes raw: raw bytes. - """ - haddr_type = haddr_get_type(addr_type) - addr_len = ISD_AS.LEN + haddr_type.LEN - data = Raw(raw, "SCIONAddr", addr_len, min_=True) - self.isd_as = ISD_AS(data.pop(ISD_AS.LEN)) - self.host = haddr_type(data.pop(haddr_type.LEN)) - - @classmethod - def from_values(cls, isd_as, host): # pragma: no cover - """ - Create an instance of the class SCIONAddr. - - :param ISD_AS isd_as: ISD-AS identifier. - :param HostAddrBase host: host address - """ - assert isinstance(host, HostAddrBase), type(host) - addr = cls() - addr.isd_as = isd_as - addr.host = host - return addr - - def pack(self): # pragma: no cover - """ - Pack the class variables into a byte string. - - :returns: a byte string containing ISD ID, AS ID, and host address. - :rtype: bytes - """ - return self.isd_as.pack() + self.host.pack() - - @classmethod - def calc_len(cls, type_): # pragma: no cover - class_ = haddr_get_type(type_) - return ISD_AS.LEN + class_.LEN - - def __len__(self): # pragma: no cover - return len(self.isd_as) + len(self.host) - - def __eq__(self, other): # pragma: no cover - return (self.isd_as == other.isd_as and - self.host == other.host) - - def __str__(self): - """ - Return a string containing ISD-AS, and host address. - """ - return "(%s (%s) %s)" % (self.isd_as, self.host.name(), self.host) diff --git a/python/lib/topology.py b/python/lib/topology.py deleted file mode 100644 index b44461d6b2..0000000000 --- a/python/lib/topology.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright 2014 ETH Zurich -# Copyright 2018 ETH Zurich, Anapaya Systems -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`topology` --- SCION topology parser -=========================================== -""" -# Stdlib -import logging - -# SCION -from lib.errors import SCIONKeyError -from lib.packet.host_addr import haddr_parse_interface -from lib.packet.scion_addr import ISD_AS -from lib.types import ( - LinkType, - ServiceType -) -from lib.util import load_yaml_file - - -class Element(object): - """ - The Element class is the base class for elements specified in the topology - file. - - :ivar HostAddrBase addr: Host address of a server or border router. - :ivar str name: element name or id - """ - def __init__(self, addrs=None, name=None): - """ - :param dict addrs: - contains the public and bind addresses. Only one public/bind addresses pair - is chosen from all the available addresses in the map. - :param str name: element name or id - """ - public, bind = self._get_pub_bind(addrs) - self.public = self._parse_addrs(public) - self.bind = self._parse_addrs(bind) - self.name = None - if name is not None: - self.name = str(name) - - def _get_pub_bind(self, addrs): - if addrs is None: - return None, None - pub_bind = addrs.get('IPv6') - if pub_bind is not None: - return pub_bind['Public'], pub_bind.get('Bind') - pub_bind = addrs.get('IPv4') - if pub_bind is not None: - return pub_bind['Public'], pub_bind.get('Bind') - return None, None - - def _parse_addrs(self, value): - if not value: - return None - return (haddr_parse_interface(value['Addr']), value['L4Port']) - - -class ServerElement(Element): - """The ServerElement class represents one of the servers in the AS.""" - def __init__(self, server_dict, name=None): # pragma: no cover - """ - :param dict server_dict: contains information about a particular server. - :param str name: server element name or id - """ - super().__init__(server_dict['Addrs'], name) - - -class RouterAddrElement(object): - """ - The RouterAddrElement class is the base class for elements specified in the - Border router topology section. - - :ivar HostAddrBase addr: Host address of a border router. - :ivar str name: element name or id - """ - def __init__(self, addrs=None, name=None): - """ - :param dict public: - ((addr_type, address), overlay_port) of the element's public address. - (i.e. the address visible to other network elements). - :param dict bind: - (addr_type, address) of the element's bind address, if any - (i.e. the address the element uses to identify itself to the local - operating system, if it differs from the public address due to NAT). - :param str name: element name or id - """ - public, bind = self._get_pub_bind(addrs) - self.public = self._parse_addrs(public) - self.bind = self._parse_addrs(bind) - self.name = None - if name is not None: - self.name = str(name) - - def _get_pub_bind(self, addrs): - if addrs is None: - return None, None - pub_bind = addrs.get('IPv6') - if pub_bind is not None: - return pub_bind['PublicOverlay'], pub_bind.get('BindOverlay') - pub_bind = addrs.get('IPv4') - if pub_bind is not None: - return pub_bind['PublicOverlay'], pub_bind.get('BindOverlay') - return None, None - - def _parse_addrs(self, value): - if not value: - return None - return (haddr_parse_interface(value['Addr']), value['OverlayPort']) - - -class InterfaceElement(RouterAddrElement): - """ - The InterfaceElement class represents one of the interfaces of a border - router. - - :ivar int if_id: the interface ID. - :ivar int isd_as: the ISD-AS identifier of the neighbor AS. - :ivar str link_type: the type of relationship to the neighbor AS. - :ivar int to_udp_port: - the port number receiving UDP traffic on the other end of the link. - :ivar int udp_port: the port number used to send UDP traffic. - """ - def __init__(self, if_id, interface_dict, name=None): - """ - :pacam int if_id: interface id - :param dict interface_dict: contains information about the interface. - """ - self.if_id = int(if_id) - self.isd_as = ISD_AS(interface_dict['ISD_AS']) - self.link_type = interface_dict['LinkTo'].lower() - self.bandwidth = interface_dict['Bandwidth'] - self.mtu = interface_dict['MTU'] - self.overlay = interface_dict.get('Overlay') - self.to_if_id = 0 # Filled in later by IFID packets - self.remote = self._parse_addrs(interface_dict.get('RemoteOverlay')) - super().__init__(self._new_addrs(interface_dict), name) - - def _new_addrs(self, interface_dict): - addrs = {} - if not self.overlay: - return None - if 'IPv4' in self.overlay: - addrType = 'IPv4' - else: # Assume IPv6 - addrType = 'IPv6' - addrs[addrType] = {} - addrs[addrType]['PublicOverlay'] = interface_dict['PublicOverlay'] - bind = interface_dict.get('BindOverlay') - if bind is not None: - addrs[addrType]['BindOverlay'] = bind - return addrs - - def __lt__(self, other): # pragma: no cover - return self.if_id < other.if_id - - -class RouterElement(object): - """ - The RouterElement class represents one of the border routers. - """ - def __init__(self, router_dict, name=None): # pragma: no cover - """ - :param dict router_dict: contains information about an border router. - :param str name: router element name or id - """ - self.name = name - self.ctrl_addrs = Element(router_dict['CtrlAddr']) - self.int_addrs = RouterAddrElement(router_dict['InternalAddrs']) - self.interfaces = {} - for if_id, intf in router_dict['Interfaces'].items(): - if_id = int(if_id) - self.interfaces[if_id] = InterfaceElement(if_id, intf) - - def __lt__(self, other): # pragma: no cover - return self.name < other.name - - -class Topology(object): - """ - The Topology class parses the topology file of an AS and stores such - information for further use. - - :ivar ISD_AS isd_is: the ISD-AS identifier. - :ivar list control_servers: control servers in the AS. - :ivar list sigs: SIGs in the as. - :ivar list discovery_servers: discovery servers in the AS. - :ivar list border_routers: border routers in the AS. - :ivar list parent_interfaces: BR interfaces linking to upstream ASes. - :ivar list child_interfaces: BR interfaces linking to downstream ASes. - :ivar list peer_interfaces: BR interfaces linking to peer ASes. - :ivar list core_interfaces: BR interfaces linking to core ASes. - """ - def __init__(self): # pragma: no cover - self.isd_as = None - self.mtu = None - self.control_servers = [] - self.sigs = [] - self.discovery_servers = [] - self.border_routers = [] - self.parent_interfaces = [] - self.child_interfaces = [] - self.peer_interfaces = [] - self.core_interfaces = [] - - @classmethod - def from_file(cls, topology_file): # pragma: no cover - """ - Create a Topology instance from the file. - - :param str topology_file: path to the topology file - """ - return cls.from_dict(load_yaml_file(topology_file)) - - @classmethod - def from_dict(cls, topology_dict): # pragma: no cover - """ - Create a Topology instance from the dictionary. - - :param dict topology_dict: dictionary representation of a topology - :returns: the newly created Topology instance - :rtype: :class:`Topology` - """ - topology = cls() - topology.parse_dict(topology_dict) - return topology - - def parse_dict(self, topology): - """ - Parse a topology dictionary and populate the instance's attributes. - - :param dict topology: dictionary representation of a topology - """ - self.isd_as = ISD_AS(topology['ISD_AS']) - self.mtu = topology['MTU'] - self.overlay = topology['Overlay'] - self._parse_srv_dicts(topology) - self._parse_router_dicts(topology) - - def _parse_srv_dicts(self, topology): - for type_, list_ in ( - ("ControlService", self.control_servers), - ("SIG", self.sigs), - ("DiscoveryService", self.discovery_servers), - ): - for k, v in topology.get(type_, {}).items(): - list_.append(ServerElement(v, k)) - - def _parse_router_dicts(self, topology): - for k, v in topology.get('BorderRouters', {}).items(): - router = RouterElement(v, k) - self.border_routers.append(router) - for intf in router.interfaces.values(): - ntype_map = { - LinkType.PARENT: self.parent_interfaces, - LinkType.CHILD: self.child_interfaces, - LinkType.PEER: self.peer_interfaces, - LinkType.CORE: self.core_interfaces, - } - ntype_map[intf.link_type].append(intf) - - def get_all_interfaces(self): - """ - Return all border router interfaces associated to the AS. - - :returns: all border router interfaces associated to the AS. - :rtype: list - """ - all_interfaces = [] - all_interfaces.extend(self.parent_interfaces) - all_interfaces.extend(self.child_interfaces) - all_interfaces.extend(self.peer_interfaces) - all_interfaces.extend(self.core_interfaces) - return all_interfaces - - def get_own_config(self, server_type, server_id): - type_map = { - ServiceType.CS: self.control_servers, - ServiceType.SIG: self.sigs, - } - try: - target = type_map[server_type] - except KeyError: - logging.critical("Unknown server type: \"%s\"", server_type) - raise SCIONKeyError from None - - for i in target: - if i.name == server_id: - return i - else: - logging.critical("Could not find server: %s", server_id) - raise SCIONKeyError from None diff --git a/python/lib/types.py b/python/lib/types.py index caafa777ff..46021dc16a 100644 --- a/python/lib/types.py +++ b/python/lib/types.py @@ -37,31 +37,6 @@ def all(cls): not callable(getattr(cls, attr))] -############################ -# Basic types -############################ -class AddrType(TypeBase): - NONE = 0 - IPV4 = 1 - IPV6 = 2 - SVC = 3 - - -############################ -# Service types -############################ -class ServiceType(TypeBase): - # these values must be kept in sync with the common.capnp ServiceType enum - #: Unset - UNSET = "unset" - #: Certificate service - CS = "cs" - #: Border router - BR = "br" - #: SCION-IP gateway - SIG = "sig" - - ############################ # Link types ############################ diff --git a/python/lib/util.py b/python/lib/util.py index 002083a5e6..eb8429fd3e 100644 --- a/python/lib/util.py +++ b/python/lib/util.py @@ -19,7 +19,6 @@ """ # Stdlib import os -from datetime import datetime, timezone # External packages import json @@ -28,32 +27,9 @@ # SCION from lib.errors import ( SCIONIOError, - SCIONIndexError, - SCIONParseError, - SCIONTypeError, SCIONYAMLError, ) -TRACE_DIR = 'traces' - - -def read_file(file_path): - """ - Read and return contents of a file. - - :param str file_path: the path to the file. - :returns: the file's contents. - :rtype: str - :raises: - lib.errors.SCIONIOError: error opening/reading from file. - """ - try: - with open(file_path) as file_handler: - return file_handler.read() - except OSError as e: - raise SCIONIOError("Unable to open '%s': %s" % ( - file_path, e.strerror)) from None - def write_file(file_path, text): """ @@ -117,108 +93,3 @@ def load_sciond_file(file_path): """ with open(file_path) as f: return json.load(f) - - -def iso_timestamp(ts): # pragma: no cover - """ - Format a unix timestamp as a UTC ISO 8601 format string - (YYYY-MM-DD HH:MM:SS.mmmmmm+00:00) - - :param float ts: Seconds since the UNIX epoch. - """ - return str(datetime.fromtimestamp(ts, tz=timezone.utc)) - - -class Raw(object): - """A class to wrap raw bytes objects.""" - def __init__(self, data, desc="", len_=None, - min_=False): # pragma: no cover - self._data = data - self._desc = desc - self._len = len_ - self._min = min_ - self._offset = 0 - self.check_type() - self.check_len() - - def check_type(self): - """ - Check that the data is a `bytes` instance. If not, raise an exception. - - :raises: - lib.errors.SCIONTypeError: data is the wrong type - """ - if not isinstance(self._data, bytes): - raise SCIONTypeError( - "Error parsing raw %s: Expected %s, got %s" % - (self._desc, bytes, type(self._data))) - - def check_len(self): - """ - Check that the data is of the expected length. If not, raise an - exception. - - :raises: - lib.errors.SCIONTypeError: data is the wrong length - """ - if self._len is None: - return - if self._min: - if len(self._data) >= self._len: - return - else: - op = ">=" - elif len(self._data) == self._len: - return - else: - op = "==" - raise SCIONParseError( - "Error parsing raw %s: Expected len %s %s, got %s" % - (self._desc, op, self._len, len(self._data))) - - def get(self, n=None, bounds=True): - """ - Return next elements from data. - - If `n` is not specified, return all remaining elements of data. - If `n` is 1, return the next element of data (as an int). - If `n` is > 1, return the next `n` elements of data (as bytes). - - :param n: How many elements to return (see above) - :param bool bounds: Perform bounds checking on access if True - """ - dlen = len(self._data) - if n and bounds and (self._offset + n) > dlen: - raise SCIONIndexError("%s: Attempted to access beyond end of raw " - "data (len=%d, offset=%d, request=%d)" % - (self._desc, dlen, self._offset, n)) - if n is None: - return self._data[self._offset:] - elif n == 1: - return self._data[self._offset] - else: - return self._data[self._offset:self._offset + n] - - def pop(self, n=None, bounds=True): - """ - Return next elements from data, and advance the internal offset. - - Arguments have the same meaning as for Raw.get - """ - ret = self.get(n, bounds) - dlen = len(self._data) - if n is None: - self._offset = dlen - elif n == 1: - self._offset += 1 - else: - self._offset += n - if self._offset > dlen: - self._offset = dlen - return ret - - def offset(self): # pragma: no cover - return self._offset - - def __len__(self): - return max(0, len(self._data) - self._offset) diff --git a/python/test/lib/log_test.py b/python/test/lib/log_test.py deleted file mode 100644 index 69c2a71b56..0000000000 --- a/python/test/lib/log_test.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2015 ETH Zurich -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`lib_log_test` --- lib.log unit tests -========================================== -""" -# Stdlib -import logging -from unittest.mock import patch, MagicMock, call - -# External packages -import nose -import nose.tools as ntools - -# SCION -from lib.log import ( - LOG_BACKUP_COUNT, - LOG_MAX_SIZE, - _handleError, - init_logging, - log_exception, -) -from test.testcommon import SCIONTestError, assert_these_calls, create_mock - - -class TestHandleError(object): - """ - Unit tests for lib.log._handleError - """ - @patch("lib.log.traceback.format_exc", autospec=True) - def test(self, format_exc): - # Setup - handler = MagicMock(spec_set=["stream", "flush"]) - handler.stream = MagicMock(spec_set=['write']) - handler.stream.write = MagicMock(spec_set=[]) - handler.flush = MagicMock(spec_set=[]) - format_exc.return_value = MagicMock(spec_set=['split']) - format_exc.return_value.split.return_value = ['line0', 'line1'] - # Call - try: - raise SCIONTestError - except SCIONTestError: - ntools.assert_raises(SCIONTestError, _handleError, handler, "hi") - # Tests - ntools.eq_(handler.stream.write.call_count, 3) - handler.flush.assert_called_once_with() - - -class TestInitLogging(object): - """ - Unit tests for lib.log.init_logging - """ - @patch("lib.log.logging.basicConfig", autospec=True) - @patch("lib.log._ConsoleErrorHandler", autospec=True) - @patch("lib.log._RotatingErrorHandler", autospec=True) - @patch("lib.log.DispatchFormatter", autospec=True) - def test_full(self, formatter, rotate, console, basic_config): - levels = "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" - file_handlers = [ - create_mock(["setLevel", "setFormatter"]), - create_mock(["setLevel", "setFormatter"]), - create_mock(["setLevel", "setFormatter"]), - create_mock(["setLevel", "setFormatter"]), - create_mock(["setLevel", "setFormatter"]), - ] - console_handler = console.return_value - rotate.side_effect = file_handlers - # Call - init_logging("logbase", file_level=logging.DEBUG, - console_level=logging.CRITICAL) - # Tests - rotate_calls = [] - for lvl in levels: - rotate_calls.append(call( - "logbase.%s" % lvl, maxBytes=LOG_MAX_SIZE, - backupCount=LOG_BACKUP_COUNT, encoding="utf-8")) - assert_these_calls(rotate, rotate_calls) - for lvl, f_h in zip(levels, file_handlers): - f_h.setLevel.assert_called_once_with(logging._nameToLevel[lvl]) - f_h.setFormatter.assert_called_once_with(formatter.return_value) - console_handler.setLevel.assert_called_once_with(logging.CRITICAL) - console_handler.setFormatter.assert_called_once_with( - formatter.return_value) - basic_config.assert_called_once_with( - level=logging.DEBUG, handlers=file_handlers + [console_handler] - ) - - @patch("lib.log._RotatingErrorHandler", autospec=True) - @patch("lib.log.logging.basicConfig", autospec=True) - def test_file(self, basic_config, rotate): - # Call - init_logging("logfile", file_level=logging.CRITICAL) - # Tests - basic_config.assert_called_once_with( - level=logging.DEBUG, handlers=[rotate.return_value], - ) - - @patch("lib.log._ConsoleErrorHandler", autospec=True) - @patch("lib.log.logging.basicConfig", autospec=True) - def test_console(self, basic_config, console): - # Call - init_logging(console_level=logging.DEBUG) - # Tests - basic_config.assert_called_once_with( - level=logging.DEBUG, handlers=[console.return_value], - ) - - -class TestLogException(object): - """ - Unit tests for lib.log.log_exception - """ - @patch("lib.log.traceback.format_exc", autospec=True) - @patch("lib.log.logging.log", autospec=True) - def test(self, log, format_exc): - format_exc.return_value = MagicMock(spec_set=['split']) - format_exc.return_value.split.return_value = ['line0', 'line1'] - log_exception('msg', 'arg0', level=123, arg1='arg1') - log.assert_has_calls([call(123, 'msg', 'arg0', arg1='arg1'), - call(123, 'line0'), call(123, 'line1')]) - - @patch("lib.log.traceback.format_exc", autospec=True) - @patch("lib.log.logging.log", autospec=True) - def test_less_arg(self, log, format_exc): - format_exc.return_value = MagicMock(spec_set=['split']) - format_exc.return_value.split.return_value = ['line0', 'line1'] - log_exception('msg', 'arg0', arg1='arg1') - calls = [call(logging.CRITICAL, 'msg', 'arg0', arg1='arg1'), - call(logging.CRITICAL, 'line0'), - call(logging.CRITICAL, 'line1')] - log.assert_has_calls(calls) - - -if __name__ == "__main__": - nose.run(defaultTest=__name__) diff --git a/python/test/lib/topology_test.py b/python/test/lib/topology_test.py deleted file mode 100644 index b7029af9f6..0000000000 --- a/python/test/lib/topology_test.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2015 ETH Zurich -# Copyright 2018 ETH Zurich, Anapaya Systems -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -:mod:`lib_topology_test` --- SCION topology tests -================================================= -""" -# Stdlib -from collections import defaultdict -from unittest.mock import call, patch - -# External packages -import nose -import nose.tools as ntools - -# SCION -from lib.errors import SCIONKeyError -from lib.topology import ( - Element, - InterfaceElement, - Topology -) -from test.testcommon import assert_these_calls, create_mock - - -class TestElementInit(object): - """ - Unit tests for lib.topology.Element.__init__ - """ - def test_basic(self): - inst = Element() - ntools.assert_equal(inst.public, None) - ntools.assert_is_none(inst.name) - - @patch("lib.topology.haddr_parse_interface", autospec=True) - def test_public(self, parse): - public = {'IPv4': {'Public': {'Addr': 'addr', 'L4Port': 'port'}}} - inst = Element(public) - parse.assert_called_with("addr") - ntools.eq_(inst.public[0], parse.return_value) - ntools.eq_(inst.public[1], 'port') - - @patch("lib.topology.haddr_parse_interface", autospec=True) - def test_bind(self, parse): - bind = {'IPv4': {'Bind': {'Addr': 'pub_addr', 'L4Port': 'port'}, - 'Public': {'Addr': 'bind_addr', 'L4Port': 'port'}}} - inst = Element(bind) - parse.assert_has_calls([call('bind_addr'), call('pub_addr')]) - ntools.eq_(inst.bind[0], parse.return_value) - ntools.eq_(inst.public[1], 'port') - - def test_name(self): - name = create_mock(["__str__"]) - name.__str__.return_value = "hostname" - # Call - inst = Element(name=name) - # Tests - ntools.assert_equal(inst.name, "hostname") - - -class TestInterfaceElementInit(object): - """ - Unit tests for lib.topology.InterfaceElement.__init__ - """ - - @patch("lib.topology.haddr_parse_interface", autospec=True) - @patch("lib.topology.ISD_AS", autospec=True) - @patch("lib.topology.RouterAddrElement.__init__", autospec=True) - def test_full(self, super_init, isd_as, parse): - intf_dict = { - 'Overlay': 'UDP/IPv4', - 'PublicOverlay': { - 'Addr': 'addr', - 'OverlayPort': 6 - }, - 'RemoteOverlay': { - 'Addr': 'toaddr', - 'OverlayPort': 5 - }, - 'Bandwidth': 1001, - 'ISD_AS': '3-ff00:0:301', - 'LinkTo': 'PARENT', - 'MTU': 4242 - } - if_id = 1 - addrs = {'IPv4': {'PublicOverlay': {'Addr': 'addr', 'OverlayPort': 6}}} - # Call - inst = InterfaceElement(if_id, intf_dict, 'name') - # Tests - super_init.assert_called_once_with(inst, addrs, 'name') - ntools.eq_(inst.if_id, 1) - ntools.eq_(inst.isd_as, isd_as.return_value) - ntools.eq_(inst.link_type, "parent") - ntools.eq_(inst.bandwidth, 1001) - ntools.eq_(inst.mtu, 4242) - ntools.eq_(inst.overlay, "UDP/IPv4") - parse.assert_called_once_with("toaddr") - ntools.eq_(inst.remote, (parse.return_value, 5)) - - @patch("lib.topology.haddr_parse_interface", autospec=True) - @patch("lib.topology.ISD_AS", autospec=True) - @patch("lib.topology.RouterAddrElement.__init__", autospec=True) - def test_stripped(self, super_init, isd_as, parse): - intf_dict = { - 'Bandwidth': 1001, - 'ISD_AS': '3-ff00:0:301', - 'LinkTo': 'PARENT', - 'MTU': 4242 - } - if_id = 1 - # Call - inst = InterfaceElement(if_id, intf_dict, 'name') - # Tests - super_init.assert_called_once_with(inst, None, 'name') - ntools.eq_(inst.if_id, 1) - ntools.eq_(inst.isd_as, isd_as.return_value) - ntools.eq_(inst.link_type, "parent") - ntools.eq_(inst.bandwidth, 1001) - ntools.eq_(inst.mtu, 4242) - ntools.eq_(inst.overlay, None) - assert parse.call_count == 0 - ntools.eq_(inst.remote, None) - - -class TestTopologyParseDict(object): - """ - Unit tests for lib.topology.Topology.parse_dict - """ - @patch("lib.topology.ISD_AS", autospec=True) - def test(self, isd_as): - topo_dict = {'Core': True, 'ISD_AS': '1-ff00:0:312', 'MTU': 440, 'Overlay': 'UDP/IPv4'} - inst = Topology() - inst._parse_srv_dicts = create_mock() - inst._parse_router_dicts = create_mock() - # Call - inst.parse_dict(topo_dict) - # Tests - ntools.eq_(inst.isd_as, isd_as.return_value) - ntools.eq_(inst.mtu, 440) - inst._parse_srv_dicts.assert_called_once_with(topo_dict) - inst._parse_router_dicts.assert_called_once_with(topo_dict) - - -class TestTopologyParseSrvDicts(object): - """ - Unit tests for lib.topology.Topology.parse_srv_dicts - """ - @patch("lib.topology.ServerElement", autospec=True) - def test(self, server): - topo_dict = { - 'ControlService': {"cs1": "cs1 val"}, - 'SIG': {"sig1": "sig1 val"}, - } - inst = Topology() - server.side_effect = lambda v, k: "%s-%s" % (k, v) - # Call - inst._parse_srv_dicts(topo_dict) - # Tests - assert_these_calls(server, [ - call("cs1 val", "cs1"), - call("sig1 val", "sig1"), - ], any_order=True) - ntools.eq_(inst.control_servers, ["cs1-cs1 val"]) - - -class TestTopologyParseRouterDicts(object): - """ - Unit tests for lib.topology.Topology.parse_router_dicts - """ - @patch("lib.topology.RouterElement", autospec=True) - def test(self, router): - def _mk_router(type_): - m = create_mock(["interfaces"]) - m.interfaces = {0: create_mock(["link_type"])} - m.interfaces[0].link_type = type_ - routers[type_].append(m) - return m - routers = defaultdict(list) - router_dict = {"br-parent": "parent"} - inst = Topology() - router.side_effect = lambda v, k: _mk_router(v) - # Call - inst._parse_router_dicts({"BorderRouters": router_dict}) - # Tests - ntools.assert_count_equal(inst.border_routers, routers["parent"]) - - -class TestTopologyGetAllInterfaces(object): - """ - Unit tests for lib.topology.Topology.get_all_border_routers - """ - def test(self): - topology = Topology() - topology.parent_interfaces = [0, 1] - topology.child_interfaces = [2] - topology.peer_interfaces = [3, 4, 5] - topology.core_interfaces = [6, 7] - ntools.eq_(topology.get_all_interfaces(), list(range(8))) - - -class TestTopologyGetOwnConfig(object): - """ - Unit tests for lib.topology.Topology.get_own_config - """ - def test_basic(self): - inst = Topology() - for i in range(4): - bs = create_mock(["name"]) - bs.name = "cs%d" % i - inst.control_servers.append(bs) - # Call - ntools.eq_(inst.get_own_config("cs", "cs3"), - inst.control_servers[3]) - - def test_unknown_type(self): - inst = Topology() - # Call - ntools.assert_raises(SCIONKeyError, inst.get_own_config, "asdf", 1) - - def test_unknown_server(self): - inst = Topology() - # Call - ntools.assert_raises(SCIONKeyError, inst.get_own_config, "bs", "name") - - -if __name__ == "__main__": - nose.run(defaultTest=__name__) diff --git a/python/test/lib/util_test.py b/python/test/lib/util_test.py index 3d6f329da7..630956a67e 100644 --- a/python/test/lib/util_test.py +++ b/python/test/lib/util_test.py @@ -17,7 +17,7 @@ """ # Stdlib import builtins -from unittest.mock import patch, mock_open, MagicMock +from unittest.mock import patch, mock_open # External packages import nose @@ -27,36 +27,14 @@ # SCION from lib.errors import ( SCIONIOError, - SCIONIndexError, SCIONYAMLError, - SCIONParseError, - SCIONTypeError, ) from lib.util import ( - Raw, load_yaml_file, - read_file, write_file, ) -class TestReadFile(object): - """ - Unit tests for lib.util.read_file - """ - @patch.object(builtins, 'open', - mock_open(read_data="file contents")) - def test_basic(self): - ntools.eq_(read_file("File_Path"), "file contents") - builtins.open.assert_called_once_with("File_Path") - builtins.open.return_value.read.assert_called_once_with() - - @patch.object(builtins, 'open', mock_open()) - def test_error(self): - builtins.open.side_effect = IsADirectoryError - ntools.assert_raises(SCIONIOError, read_file, "File_Path") - - class TestWriteFile(object): """ Unit tests for lib.util.write_file @@ -129,138 +107,5 @@ def test_json_error(self): ) -class TestRawCheckType(object): - """ - Unit tests for lib.util.Raw.check_type - """ - def test_bytes(self): - inst = MagicMock(spec_set=["_data"]) - inst._data = b"asdf" - Raw.check_type(inst) - - def test_error(self): - inst = MagicMock(spec_set=["_data", "_desc"]) - inst._data = "asdf" - ntools.assert_raises(SCIONTypeError, Raw.check_type, inst) - - -class TestRawCheckLen(object): - """ - Unit tests for lib.util.Raw.check_len - """ - def test_no_len(self): - inst = MagicMock(spec_set=["_len"]) - inst._len = None - Raw.check_len(inst) - - def test_min(self): - inst = MagicMock(spec_set=["_data", "_len", "_min"]) - inst._len = 4 - inst._data = "abcde" - Raw.check_len(inst) - - def test_basic(self): - inst = MagicMock(spec_set=["_data", "_len", "_min"]) - inst._min = False - inst._len = 4 - inst._data = "abcd" - Raw.check_len(inst) - - def test_min_error(self): - inst = MagicMock(spec_set=["_data", "_desc", "_len", "_min"]) - inst._len = 4 - inst._data = "abc" - ntools.assert_raises(SCIONParseError, Raw.check_len, inst) - - def test_basic_error(self): - inst = MagicMock(spec_set=["_data", "_desc", "_len", "_min"]) - inst._min = False - inst._len = 4 - inst._data = "abc" - ntools.assert_raises(SCIONParseError, Raw.check_len, inst) - - -class TestRawGet(object): - """ - Unit tests for lib.util.Raw.get - """ - def _check(self, count, start_off, expected): - # Setup - r = Raw(b"data") - r._offset = start_off - # Call - data = r.get(count) - # Tests - ntools.eq_(data, expected) - - def test(self): - for count, start_off, expected in ( - (None, 0, b"data"), - (None, 2, b"ta"), - (1, 0, 0x64), # "d" - (1, 2, 0x74), # "t" - (2, 0, b"da"), - (2, 2, b"ta"), - ): - yield self._check, count, start_off, expected - - def test_bounds_true(self): - # Setup - r = Raw(b"data") - # Call - ntools.assert_raises(SCIONIndexError, r.get, 100) - - def test_bounds_false(self): - # Setup - r = Raw(b"data") - # Call - r.get(100, bounds=False) - - -class TestRawPop(object): - """ - Unit tests for lib.util.Raw.pop - """ - @patch("lib.util.Raw.get", autospec=True) - def _check(self, pop, start_off, end_off, get): - # Setup - r = Raw(b"data") - r._offset = start_off - # Call - r.pop(pop) - # Tests - get.assert_called_once_with(r, pop, True) - ntools.eq_(r._offset, end_off) - - def test(self): - for pop, start_off, end_off in ( - (None, 0, 4), - (None, 2, 4), - (1, 0, 1), - (1, 2, 3), - (2, 0, 2), - (3, 2, 4), - ): - yield self._check, pop, start_off, end_off - - -class TestRawLen(object): - """ - Unit tests for lib.util.Raw.__len__ - """ - def _check(self, start_off, expected): - # Setup - r = Raw(b"data") - r._offset = start_off - # Check - ntools.eq_(len(r), expected) - - def test(self): - for start_off, expected in ( - (0, 4), (1, 3), (3, 1), (4, 0), (10, 0), - ): - yield self._check, start_off, expected - - if __name__ == "__main__": nose.run(defaultTest=__name__) diff --git a/python/topology/common.py b/python/topology/common.py index abc1203cf2..5e053eca02 100644 --- a/python/topology/common.py +++ b/python/topology/common.py @@ -18,7 +18,7 @@ import sys # SCION -from lib.packet.scion_addr import ISD_AS +from lib.scion_addr import ISD_AS from topology.net import AddressProxy COMMON_DIR = 'endhost' diff --git a/python/topology/config.py b/python/topology/config.py index 16fcf16efe..9a77f534a0 100755 --- a/python/topology/config.py +++ b/python/topology/config.py @@ -30,7 +30,7 @@ DEFAULT6_NETWORK, NETWORKS_FILE, ) -from lib.packet.scion_addr import ISD_AS +from lib.scion_addr import ISD_AS from lib.util import ( load_yaml_file, write_file, @@ -99,10 +99,10 @@ def _ensure_uniq_ases(self): seen = set() for asStr in self.topo_config["ASes"]: ia = ISD_AS(asStr) - if ia[1] in seen: - logging.critical("Non-unique AS Id '%s'", ia[1]) + if ia.as_str() in seen: + logging.critical("Non-unique AS Id '%s'", ia.as_str()) sys.exit(1) - seen.add(ia[1]) + seen.add(ia.as_str()) def _generate_with_topo(self, topo_dicts): self._generate_go(topo_dicts) diff --git a/python/topology/generator.py b/python/topology/generator.py index a954c92c50..746520f567 100755 --- a/python/topology/generator.py +++ b/python/topology/generator.py @@ -22,7 +22,6 @@ # SCION from lib.defines import ( - DEFAULT_SEGMENT_TTL, GEN_PATH, ) from topology.config import ( @@ -43,8 +42,6 @@ def add_arguments(parser): help='Output directory') parser.add_argument('-t', '--trace', action='store_true', help='Enable TRACE level file logging in Go services') - parser.add_argument('--pseg-ttl', type=int, default=DEFAULT_SEGMENT_TTL, - help='Path segment TTL (in seconds)') parser.add_argument('-f', '--svcfrac', type=float, default=0.4, help='Attempt SVC resolution in RPC calls for a fraction of\ available timeout') diff --git a/python/topology/net.py b/python/topology/net.py index 30e0d94224..f6f53590af 100755 --- a/python/topology/net.py +++ b/python/topology/net.py @@ -21,10 +21,10 @@ import math import sys from collections import defaultdict +from ipaddress import ip_interface, ip_network # External packages import yaml -from external.ipaddress import ip_interface, ip_network # SCION from lib.defines import DEFAULT6_NETWORK_ADDR @@ -98,6 +98,7 @@ def alloc_subnets(self): alloc = self._allocations[prefix].pop() # Carve out subnet of the required size new_net = next(alloc.subnets(new_prefix=req_prefix)) + new_net = _workaround_ip_network_hosts_py35(new_net) logging.debug("Allocating %s from %s for subnet size %d" % (new_net, alloc, len(subnet))) networks[new_net] = subnet.alloc_addrs(new_net) @@ -173,3 +174,15 @@ def socket_address_str(ip, port): if ip.version == 4: return "%s:%d" % (ip, port) return "[%s]:%d" % (ip, port) + + +def _workaround_ip_network_hosts_py35(net): + """ + Returns an _identical_ ipaddress.ip_network for which hosts() which will work as it should. + + This works around a regression in python 3.5, where the behaviour of hosts was broken + when using a certain form of the ip_network constructor. + This regression is fixed in python 3.6.6 / 3.7.0. + See https://bugs.python.org/issue27683 + """ + return ip_network('%s/%i' % (net.network_address, net.prefixlen)) diff --git a/python/topology/topo.py b/python/topology/topo.py index 47aec568a2..b653c40978 100755 --- a/python/topology/topo.py +++ b/python/topology/topo.py @@ -36,7 +36,6 @@ SCION_ROUTER_PORT, TOPO_FILE, ) -from lib.topology import Topology from lib.types import LinkType from lib.util import write_file from topology.common import ( @@ -348,8 +347,6 @@ def _write_as_topos(self): contents_json = json.dumps(self.topo_dicts[topo_id], default=json_default, indent=2) write_file(path, contents_json + '\n') - # Test if topo file parses cleanly - Topology.from_file(path) def _write_as_list(self): list_path = os.path.join(self.args.output_dir, AS_LIST_FILE)