diff --git a/dns/name.py b/dns/name.py index ef812cfa6..5288e1eab 100644 --- a/dns/name.py +++ b/dns/name.py @@ -488,9 +488,6 @@ def __len__(self): def __getitem__(self, index): return self.labels[index] - def __getslice__(self, start, stop): - return self.labels[start:stop] - def __add__(self, other): return self.concatenate(other) diff --git a/dns/resolver.py b/dns/resolver.py index 5bd1e8d86..c9a7c78d4 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -282,9 +282,6 @@ def __getitem__(self, i): def __delitem__(self, i): del self.rrset[i] - def __getslice__(self, i, j): - return self.rrset[i:j] - class Cache(object): diff --git a/dns/set.py b/dns/set.py index 0efc7d9b9..ef7fd2955 100644 --- a/dns/set.py +++ b/dns/set.py @@ -232,9 +232,6 @@ def __getitem__(self, i): def __delitem__(self, i): del self.items[i] - def __getslice__(self, i, j): - return self.items[i:j] - def issubset(self, other): """Is I{self} a subset of I{other}? diff --git a/dns/wiredata.py b/dns/wiredata.py index b381f7b92..ccef59545 100644 --- a/dns/wiredata.py +++ b/dns/wiredata.py @@ -15,6 +15,7 @@ """DNS Wire Data Helper""" +import sys import dns.exception from ._compat import binary_type, string_types @@ -26,12 +27,16 @@ # out what constant Python will use. -class _SliceUnspecifiedBound(str): +class _SliceUnspecifiedBound(binary_type): - def __getslice__(self, i, j): - return j + def __getitem__(self, key): + return key.stop + + if sys.version_info < (3,): + def __getslice__(self, i, j): # pylint: disable=getslice-method + return self.__getitem__(slice(i, j)) -_unspecified_bound = _SliceUnspecifiedBound('')[1:] +_unspecified_bound = _SliceUnspecifiedBound()[1:] class WireData(binary_type): @@ -40,26 +45,40 @@ class WireData(binary_type): def __getitem__(self, key): try: if isinstance(key, slice): - return WireData(super(WireData, self).__getitem__(key)) + # make sure we are not going outside of valid ranges, + # do stricter control of boundaries than python does + # by default + start = key.start + stop = key.stop + + if sys.version_info < (3,): + if stop == _unspecified_bound: + # handle the case where the right bound is unspecified + stop = len(self) + + if start < 0 or stop < 0: + raise dns.exception.FormError + # If it's not an empty slice, access left and right bounds + # to make sure they're valid + if start != stop: + super(WireData, self).__getitem__(start) + super(WireData, self).__getitem__(stop - 1) + else: + for index in (start, stop): + if index is None: + continue + elif abs(index) > len(self): + raise dns.exception.FormError + + return WireData(super(WireData, self).__getitem__( + slice(start, stop))) return bytearray(self.unwrap())[key] except IndexError: raise dns.exception.FormError - def __getslice__(self, i, j): - try: - if j == _unspecified_bound: - # handle the case where the right bound is unspecified - j = len(self) - if i < 0 or j < 0: - raise dns.exception.FormError - # If it's not an empty slice, access left and right bounds - # to make sure they're valid - if i != j: - super(WireData, self).__getitem__(i) - super(WireData, self).__getitem__(j - 1) - return WireData(super(WireData, self).__getslice__(i, j)) - except IndexError: - raise dns.exception.FormError + if sys.version_info < (3,): + def __getslice__(self, i, j): # pylint: disable=getslice-method + return self.__getitem__(slice(i, j)) def __iter__(self): i = 0 diff --git a/pylintrc b/pylintrc index c37ac1ea3..3f16509d0 100644 --- a/pylintrc +++ b/pylintrc @@ -23,7 +23,6 @@ disable= bare-except, deprecated-method, fixme, - getslice-method, global-statement, invalid-name, missing-docstring, diff --git a/tests/test_wiredata.py b/tests/test_wiredata.py new file mode 100644 index 000000000..eccc3e243 --- /dev/null +++ b/tests/test_wiredata.py @@ -0,0 +1,126 @@ +# Copyright (C) 2016 +# Author: Martin Basti +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. + +try: + import unittest2 as unittest +except ImportError: + import unittest + +from dns.exception import FormError +from dns.wiredata import WireData + + +class WireDataSlicingTestCase(unittest.TestCase): + + def testSliceAll(self): + """Get all data""" + inst = WireData(b'0123456789') + self.assertEqual(inst[:], WireData(b'0123456789')) + + def testSliceAllExplicitlyDefined(self): + """Get all data""" + inst = WireData(b'0123456789') + self.assertEqual(inst[0:10], WireData(b'0123456789')) + + def testSliceLowerHalf(self): + """Get lower half of data""" + inst = WireData(b'0123456789') + self.assertEqual(inst[:5], WireData(b'01234')) + + def testSliceLowerHalfWithNegativeIndex(self): + """Get lower half of data""" + inst = WireData(b'0123456789') + self.assertEqual(inst[:-5], WireData(b'01234')) + + def testSliceUpperHalf(self): + """Get upper half of data""" + inst = WireData(b'0123456789') + self.assertEqual(inst[5:], WireData(b'56789')) + + def testSliceMiddle(self): + """Get data from middle""" + inst = WireData(b'0123456789') + self.assertEqual(inst[3:6], WireData(b'345')) + + def testSliceMiddleWithNegativeIndex(self): + """Get data from middle""" + inst = WireData(b'0123456789') + self.assertEqual(inst[-6:-3], WireData(b'456')) + + def testSliceMiddleWithMixedIndex(self): + """Get data from middle""" + inst = WireData(b'0123456789') + self.assertEqual(inst[-8:3], WireData(b'2')) + self.assertEqual(inst[5:-3], WireData(b'56')) + + def testGetOne(self): + """Get data one by one item""" + data = b'0123456789' + inst = WireData(data) + for i, byte in enumerate(bytearray(data)): + self.assertEqual(inst[i], byte) + for i in range(-1, len(data) * -1, -1): + self.assertEqual(inst[i], bytearray(data)[i]) + + def testEmptySlice(self): + """Test empty slice""" + data = b'0123456789' + inst = WireData(data) + for i, byte in enumerate(data): + self.assertEqual(inst[i:i], b'') + for i in range(-1, len(data) * -1, -1): + self.assertEqual(inst[i:i], b'') + self.assertEqual(inst[-3:-6], b'') + + def testSliceStartOutOfLowerBorder(self): + """Get data from out of lower border""" + inst = WireData(b'0123456789') + with self.assertRaises(FormError): + inst[-11:] # pylint: disable=pointless-statement + + def testSliceStopOutOfLowerBorder(self): + """Get data from out of lower border""" + inst = WireData(b'0123456789') + with self.assertRaises(FormError): + inst[:-11] # pylint: disable=pointless-statement + + def testSliceBothOutOfLowerBorder(self): + """Get data from out of lower border""" + inst = WireData(b'0123456789') + with self.assertRaises(FormError): + inst[-12:-11] # pylint: disable=pointless-statement + + def testSliceStartOutOfUpperBorder(self): + """Get data from out of upper border""" + inst = WireData(b'0123456789') + with self.assertRaises(FormError): + inst[11:] # pylint: disable=pointless-statement + + def testSliceStopOutOfUpperBorder(self): + """Get data from out of upper border""" + inst = WireData(b'0123456789') + with self.assertRaises(FormError): + inst[:11] # pylint: disable=pointless-statement + + def testSliceBothOutOfUpperBorder(self): + """Get data from out of lower border""" + inst = WireData(b'0123456789') + with self.assertRaises(FormError): + inst[10:20] # pylint: disable=pointless-statement + + def testGetOneOutOfLowerBorder(self): + """Get item outside of range""" + inst = WireData(b'0123456789') + with self.assertRaises(FormError): + inst[-11] # pylint: disable=pointless-statement + + def testGetOneOutOfUpperBorder(self): + """Get item outside of range""" + inst = WireData(b'0123456789') + with self.assertRaises(FormError): + inst[10] # pylint: disable=pointless-statement