Skip to content

Commit

Permalink
COMPAT: ensure proper extension dtype's don't pickle the cache (#16207)
Browse files Browse the repository at this point in the history
xref #16201
  • Loading branch information
jreback authored May 3, 2017
1 parent 39cc1d0 commit 154a647
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 18 deletions.
28 changes: 25 additions & 3 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ExtensionDtype(object):
isbuiltin = 0
isnative = 0
_metadata = []
_cache = {}

def __unicode__(self):
return self.name
Expand Down Expand Up @@ -71,6 +72,15 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

def __getstate__(self):
# pickle support; we don't want to pickle the cache
return {k: getattr(self, k, None) for k in self._metadata}

@classmethod
def reset_cache(cls):
""" clear the cache """
cls._cache = {}

@classmethod
def is_dtype(cls, dtype):
""" Return a boolean if the passed type is an actual dtype that
Expand Down Expand Up @@ -110,6 +120,7 @@ class CategoricalDtype(ExtensionDtype):
kind = 'O'
str = '|O08'
base = np.dtype('O')
_metadata = []
_cache = {}

def __new__(cls):
Expand Down Expand Up @@ -408,9 +419,15 @@ def __new__(cls, subtype=None):

if isinstance(subtype, IntervalDtype):
return subtype
elif subtype is None or (isinstance(subtype, compat.string_types) and
subtype == 'interval'):
subtype = None
elif subtype is None:
# we are called as an empty constructor
# generally for pickle compat
u = object.__new__(cls)
u.subtype = None
return u
elif (isinstance(subtype, compat.string_types) and
subtype == 'interval'):
subtype = ''
else:
if isinstance(subtype, compat.string_types):
m = cls._match.search(subtype)
Expand All @@ -423,6 +440,11 @@ def __new__(cls, subtype=None):
except TypeError:
raise ValueError("could not construct IntervalDtype")

if subtype is None:
u = object.__new__(cls)
u.subtype = None
return u

try:
return cls._cache[str(subtype)]
except KeyError:
Expand Down
109 changes: 94 additions & 15 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

class Base(object):

def setup_method(self, method):
self.dtype = self.create()

def test_hash(self):
hash(self.dtype)

Expand All @@ -37,14 +40,38 @@ def test_numpy_informed(self):
assert not np.str_ == self.dtype

def test_pickle(self):
# make sure our cache is NOT pickled

# clear the cache
type(self.dtype).reset_cache()
assert not len(self.dtype._cache)

# force back to the cache
result = tm.round_trip_pickle(self.dtype)
assert not len(self.dtype._cache)
assert result == self.dtype


class TestCategoricalDtype(Base, tm.TestCase):
class TestCategoricalDtype(Base):

def create(self):
return CategoricalDtype()

def test_pickle(self):
# make sure our cache is NOT pickled

# clear the cache
type(self.dtype).reset_cache()
assert not len(self.dtype._cache)

def setUp(self):
self.dtype = CategoricalDtype()
# force back to the cache
result = tm.round_trip_pickle(self.dtype)

# we are a singular object so we are added
# back to the cache upon unpickling
# this is to ensure object identity
assert len(self.dtype._cache) == 1
assert result == self.dtype

def test_hash_vs_equality(self):
# make sure that we satisfy is semantics
Expand Down Expand Up @@ -93,10 +120,10 @@ def test_basic(self):
assert not is_categorical(1.0)


class TestDatetimeTZDtype(Base, tm.TestCase):
class TestDatetimeTZDtype(Base):

def setUp(self):
self.dtype = DatetimeTZDtype('ns', 'US/Eastern')
def create(self):
return DatetimeTZDtype('ns', 'US/Eastern')

def test_hash_vs_equality(self):
# make sure that we satisfy is semantics
Expand Down Expand Up @@ -209,10 +236,24 @@ def test_empty(self):
str(dt)


class TestPeriodDtype(Base, tm.TestCase):
class TestPeriodDtype(Base):

def setUp(self):
self.dtype = PeriodDtype('D')
def create(self):
return PeriodDtype('D')

def test_hash_vs_equality(self):
# make sure that we satisfy is semantics
dtype = self.dtype
dtype2 = PeriodDtype('D')
dtype3 = PeriodDtype(dtype2)
assert dtype == dtype2
assert dtype2 == dtype
assert dtype3 == dtype
assert dtype is dtype2
assert dtype2 is dtype
assert dtype3 is dtype
assert hash(dtype) == hash(dtype2)
assert hash(dtype) == hash(dtype3)

def test_construction(self):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -338,11 +379,37 @@ def test_not_string(self):
assert not is_string_dtype(PeriodDtype('D'))


class TestIntervalDtype(Base, tm.TestCase):
class TestIntervalDtype(Base):

def create(self):
return IntervalDtype('int64')

def test_hash_vs_equality(self):
# make sure that we satisfy is semantics
dtype = self.dtype
dtype2 = IntervalDtype('int64')
dtype3 = IntervalDtype(dtype2)
assert dtype == dtype2
assert dtype2 == dtype
assert dtype3 == dtype
assert dtype is dtype2
assert dtype2 is dtype
assert dtype3 is dtype
assert hash(dtype) == hash(dtype2)
assert hash(dtype) == hash(dtype3)

# TODO: placeholder
def setUp(self):
self.dtype = IntervalDtype('int64')
dtype1 = IntervalDtype('interval')
dtype2 = IntervalDtype(dtype1)
dtype3 = IntervalDtype('interval')
assert dtype2 == dtype1
assert dtype2 == dtype2
assert dtype2 == dtype3
assert dtype2 is dtype1
assert dtype2 is dtype2
assert dtype2 is dtype3
assert hash(dtype2) == hash(dtype1)
assert hash(dtype2) == hash(dtype2)
assert hash(dtype2) == hash(dtype3)

def test_construction(self):
with pytest.raises(ValueError):
Expand All @@ -356,9 +423,9 @@ def test_construction(self):
def test_construction_generic(self):
# generic
i = IntervalDtype('interval')
assert i.subtype is None
assert i.subtype == ''
assert is_interval_dtype(i)
assert str(i) == 'interval'
assert str(i) == 'interval[]'

i = IntervalDtype()
assert i.subtype is None
Expand Down Expand Up @@ -445,3 +512,15 @@ def test_basic_dtype(self):
assert not is_interval_dtype(np.object_)
assert not is_interval_dtype(np.int64)
assert not is_interval_dtype(np.float64)

def test_caching(self):
IntervalDtype.reset_cache()
dtype = IntervalDtype("int64")
assert len(IntervalDtype._cache) == 1

IntervalDtype("interval")
assert len(IntervalDtype._cache) == 2

IntervalDtype.reset_cache()
tm.round_trip_pickle(dtype)
assert len(IntervalDtype._cache) == 0

0 comments on commit 154a647

Please sign in to comment.