From 08f24790daa8bbade46dca86dd2fb9ff0353118c Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 25 Apr 2018 05:59:17 -0500 Subject: [PATCH] Fixup JSON take --- pandas/tests/extension/json/array.py | 38 ++++++++++++++++++---------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 95f868e89ac39..be0c904891a2c 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -14,6 +14,7 @@ class JSONDtype(ExtensionDtype): type = collections.Mapping name = 'json' + na_value = {} @classmethod def construct_from_string(cls, string): @@ -91,15 +92,30 @@ def nbytes(self): return sys.getsizeof(self.data) def isna(self): - return np.array([x == self._na_value for x in self.data]) - - def take(self, indexer, allow_fill=True, fill_value=None): - try: - output = [self.data[loc] if loc != -1 else self._na_value - for loc in indexer] - except IndexError: - raise IndexError("Index is out of bounds or cannot do a " - "non-empty take from an empty array.") + return np.array([x == self.dtype.na_value for x in self.data]) + + def take(self, indexer, fill_value=None): + # re-implement here, since NumPy has trouble setting + # sized objects like UserDicts into scalar slots of + # an ndarary. + indexer = np.asarray(indexer) + msg = ("Index is out of bounds or cannot do a " + "non-empty take from an empty array.") + + if fill_value is None: + try: + output = [self.data[loc] for loc in indexer] + except IndexError: + raise IndexError(msg) + else: + # bounds check + if (indexer < -1).any(): + raise ValueError + try: + output = [self.data[loc] if loc != -1 else fill_value + for loc in indexer] + except IndexError: + raise msg return self._from_sequence(output) def copy(self, deep=False): @@ -112,10 +128,6 @@ def unique(self): dict(x) for x in list(set(tuple(d.items()) for d in self.data)) ]) - @property - def _na_value(self): - return {} - @classmethod def _concat_same_type(cls, to_concat): data = list(itertools.chain.from_iterable([x.data for x in to_concat]))