diff --git a/partd/compatibility.py b/partd/compatibility.py index 20a217a..c74cd91 100644 --- a/partd/compatibility.py +++ b/partd/compatibility.py @@ -12,3 +12,4 @@ unicode = unicode import cPickle as pickle from Queue import Queue, Empty + diff --git a/partd/pandas.py b/partd/pandas.py index db3a208..6a6da92 100644 --- a/partd/pandas.py +++ b/partd/pandas.py @@ -12,7 +12,6 @@ from .encode import Encode from .utils import extend, framesplit, frame - try: # pandas >= 0.24.0 from pandas.api.types import is_extension_array_dtype @@ -20,6 +19,18 @@ def is_extension_array_dtype(dtype): return False +try: + # Some `ExtensionArray`s can have a `.dtype` which is not a `ExtensionDtype` + # (e.g. they can be backed by a NumPy dtype). For these cases we check + # whether the instance is a `ExtensionArray`. + # https://github.com/dask/partd/issues/48 + from pandas.api.extensions import ExtensionArray + def is_extension_array(x): + return isinstance(x, ExtensionArray) +except ImportError: + def is_extension_array(x): + return False + dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL) @@ -129,7 +140,7 @@ def block_to_header_bytes(block): elif is_datetime64tz_dtype(block): extension = ('datetime64_tz_type', (block.values.tzinfo,)) values = values.view('i8') - elif is_extension_array_dtype(block.dtype): + elif is_extension_array_dtype(block.dtype) or is_extension_array(values): extension = ("other", ()) else: extension = ('numpy_type', ())