Skip to content

Commit

Permalink
[WIP]: ExtensionArray.take default implementation
Browse files Browse the repository at this point in the history
Implements a take interface that's compatible with NumPy and optionally pandas'
NA semantics.

```python
In [1]: import pandas as pd

In [2]: from pandas.tests.extension.decimal.array import *

In [3]: arr = DecimalArray(['1.1', '1.2', '1.3'])

In [4]: arr.take([0, 1, -1])
Out[4]: DecimalArray(array(['1.1', '1.2', '1.3'], dtype=object))

In [5]: arr.take([0, 1, -1], fill_value=float('nan'))
Out[5]: DecimalArray(array(['1.1', '1.2', Decimal('NaN')], dtype=object))
```

Closes pandas-dev#20640
  • Loading branch information
TomAugspurger committed Apr 24, 2018
1 parent 8bee97a commit fb3c234
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 57 deletions.
2 changes: 2 additions & 0 deletions pandas/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,5 @@ def is_platform_mac():

def is_platform_32bit():
    """Return True when the native pointer size is smaller than 64 bits.

    The size is derived from ``struct.calcsize("P")`` (bytes in a C
    ``void *``) converted to bits.
    """
    pointer_bits = struct.calcsize("P") * 8
    return pointer_bits < 64

# Unique sentinel used as the default for ``fill_value`` arguments so that
# "argument not passed" can be told apart (by identity) from any real value,
# including None and NaN.
_default_fill_value = object()
78 changes: 77 additions & 1 deletion pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
_ensure_platform_int, _ensure_object,
_ensure_float64, _ensure_uint64,
_ensure_int64)
from pandas.compat import _default_fill_value
from pandas.compat.numpy import _np_version_under1p10
from pandas.core.dtypes.missing import isna, na_value_for_dtype

Expand Down Expand Up @@ -1482,7 +1483,7 @@ def take_nd(arr, indexer, axis=0, out=None, fill_value=np.nan, mask_info=None,
# TODO(EA): Remove these if / elifs as datetimeTZ, interval, become EAs
# dispatch to internal type takes
if is_extension_array_dtype(arr):
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
return arr.take(indexer, fill_value=fill_value)
elif is_datetimetz(arr):
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
elif is_interval_dtype(arr):
Expand Down Expand Up @@ -1558,6 +1559,81 @@ def take_nd(arr, indexer, axis=0, out=None, fill_value=np.nan, mask_info=None,
take_1d = take_nd


def take_ea(arr, indexer, fill_value=_default_fill_value):
    """Extension-array compatible take.

    Parameters
    ----------
    arr : array-like
        Must satisfy NumPy's take semantics (``arr.take(indexer)``) and
        support boolean-mask assignment on the result.
    indexer : sequence of integers
        Indices to be taken. See Notes for how negative indices
        are handled.
    fill_value : any, optional
        Fill value to use for NA-indices. This has two behaviors.

        * fill_value is not specified : triggers NumPy's semantics
          where negative values in `indexer` mean slices from the end.
        * fill_value is specified : positions where `indexer` is ``-1``
          are set to `fill_value` exactly as given. No NA normalisation
          is performed here; callers wanting a dtype-specific NA value
          must substitute it before calling.

    Returns
    -------
    array-like
        The result of ``arr.take(indexer)`` with the ``-1`` positions
        filled. When `arr` is empty and every index is ``-1`` the result
        is built with ``arr._from_sequence`` if `arr` provides it, and
        otherwise is an object-dtype ndarray of `fill_value`.

    Raises
    ------
    IndexError
        When the indexer is out of bounds for the array.
    ValueError
        When the indexer contains negative values other than ``-1``
        and `fill_value` is specified.

    Notes
    -----
    The meaning of negative values in `indexer` depends on the
    `fill_value` argument. By default, we follow the behavior of
    :meth:`numpy.take` where negative indices indicate slices
    from the end.

    When `fill_value` is specified, we follow pandas semantics of ``-1``
    indicating a missing value. In this case, positions where `indexer`
    is ``-1`` will be filled with `fill_value`.

    See Also
    --------
    numpy.take
    """
    indexer = np.asarray(indexer)
    if indexer.size == 0:
        # An empty list becomes a float64 array, which ``take`` rejects;
        # an empty indexer carries no values so the cast is lossless.
        indexer = indexer.astype(np.intp)

    if fill_value is _default_fill_value:
        # NumPy style: negative values index from the end.
        return arr.take(indexer)

    mask = indexer == -1
    if (indexer < -1).any():
        raise ValueError("Invalid value in 'indexer'. All values "
                         "must be non-negative or -1 when "
                         "'fill_value' is specified.")

    # take on empty array not handled as desired by numpy
    # in case of -1 (all missing take)
    if not len(arr) and mask.all():
        from_sequence = getattr(arr, '_from_sequence', None)
        if from_sequence is not None:
            return from_sequence([fill_value] * len(indexer))
        # Plain ndarrays have no ``_from_sequence``; build the all-missing
        # result directly. Object dtype can hold any fill value.
        filled = np.empty(len(indexer), dtype=object)
        filled.fill(fill_value)
        return filled

    # ``take`` returns a new array, so filling does not mutate ``arr``.
    result = arr.take(indexer)
    result[mask] = fill_value
    return result


def take_2d_multi(arr, indexer, out=None, fill_value=np.nan, mask_info=None,
allow_fill=True):
"""
Expand Down
91 changes: 52 additions & 39 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

from pandas.errors import AbstractMethodError
from pandas.compat import _default_fill_value
from pandas.compat.numpy import function as nv

_not_implemented_message = "{} does not implement {}."
Expand Down Expand Up @@ -53,6 +54,7 @@ class ExtensionArray(object):
* unique
* factorize / _values_for_factorize
* argsort / _values_for_argsort
* take / _values_for_take
This class does not inherit from 'abc.ABCMeta' for performance reasons.
Methods and properties required by the interface raise
Expand Down Expand Up @@ -462,22 +464,38 @@ def factorize(self, na_sentinel=-1):
# ------------------------------------------------------------------------
# Indexing methods
# ------------------------------------------------------------------------
def take(self, indexer, allow_fill=True, fill_value=None):
# type: (Sequence[int], bool, Optional[Any]) -> ExtensionArray
def _values_for_take(self):
    """Values to use for `take`.

    Coerces to object dtype by default.

    Returns
    -------
    array-like
        Must satisfy NumPy's `take` semantics.
    """
    return self.astype(object)

def take(self, indexer, fill_value=_default_fill_value):
    # type: (Sequence[int], Optional[Any]) -> ExtensionArray
    """Take elements from an array.

    Parameters
    ----------
    indexer : sequence of integers
        Indices to be taken. See Notes for how negative indices
        are handled.
    fill_value : any, optional
        Fill value to use for NA-indices. This has a few behaviors.

        * fill_value is not specified : triggers NumPy's semantics
          where negative values in `indexer` mean slices from the end.
        * fill_value is NA : Fill positions where `indexer` is ``-1``
          with ``self.dtype.na_value``. Anything considered NA by
          :func:`pandas.isna` will result in ``self.dtype.na_value``
          being used to fill.
        * fill_value is not NA : Fill positions where `indexer` is ``-1``
          with `fill_value`.

    Returns
    -------
    ExtensionArray

    Raises
    ------
    IndexError
        When the indexer is out of bounds for the array.
    ValueError
        When the indexer contains negative values other than ``-1``
        and `fill_value` is specified.

    Notes
    -----
    The meaning of negative values in `indexer` depends on the
    `fill_value` argument. By default, we follow the behavior of
    :meth:`numpy.take` where negative indices indicate slices
    from the end.

    When `fill_value` is specified, we follow pandas semantics of ``-1``
    indicating a missing value. In this case, positions where `indexer`
    is ``-1`` will be filled with `fill_value` or the default NA value
    for this type.

    ExtensionArray.take is called by ``Series.__getitem__``, ``.loc``,
    ``iloc``, when the indexer is a sequence of values. Additionally,
    it's called by :meth:`Series.reindex` with a `fill_value`.

    See Also
    --------
    numpy.take
    """
    # Imported locally, as in the original, rather than at module level.
    from pandas.core.algorithms import take_ea
    from pandas.core.dtypes.missing import isna

    # Only normalise a fill value the caller actually supplied; the
    # sentinel must reach ``take_ea`` untouched to select NumPy semantics.
    if fill_value is not _default_fill_value and isna(fill_value):
        fill_value = self.dtype.na_value

    data = self._values_for_take()
    result = take_ea(data, indexer, fill_value=fill_value)
    return self._from_sequence(result)

def copy(self, deep=False):
# type: (bool) -> ExtensionArray
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class _DtypeOpsMixin(object):
# classes will inherit from this Mixin. Once everything is compatible, this
# class's methods can be moved to ExtensionDtype and removed.

# na_value is the default NA value to use for this type. This is used in
# e.g. ExtensionArray.take.
na_value = np.nan

def __eq__(self, other):
"""Check whether 'other' is equal to self.
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ def na_value_for_dtype(dtype, compat=True):
"""
dtype = pandas_dtype(dtype)

if is_extension_array_dtype(dtype):
return dtype.na_value
if (is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype) or
is_timedelta64_dtype(dtype) or is_period_dtype(dtype)):
return NaT
Expand Down
21 changes: 19 additions & 2 deletions pandas/tests/extension/base/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,29 @@ def test_take(self, data, na_value, na_cmp):

def test_take_empty(self, data, na_value, na_cmp):
empty = data[:0]
result = empty.take([-1])
na_cmp(result[0], na_value)
# result = empty.take([-1])
# na_cmp(result[0], na_value)

with tm.assert_raises_regex(IndexError, "cannot do a non-empty take"):
empty.take([0, 1])

def test_take_negative(self, data):
    """Without a fill_value, negative indices count from the end."""
    # https://github.com/pandas-dev/pandas/issues/20640
    size = len(data)
    last = size - 1
    result = data.take([0, -size, last, -1])
    expected = data.take([0, 0, last, last])
    self.assert_extension_array_equal(result, expected)

def test_take_non_na_fill_value(self, data_missing):
    """A non-NA fill_value is written into the ``-1`` positions."""
    # Position 1 of ``data_missing`` holds a valid (non-missing) element.
    valid_element = data_missing[1]
    indexer = [-1, 1]
    result = data_missing.take(indexer, fill_value=valid_element)
    expected = data_missing.take([1, 1])
    self.assert_extension_array_equal(result, expected)

def test_take_pandas_style_negative_raises(self, data, na_value):
    """With a fill_value, negatives other than ``-1`` raise ValueError."""
    bad_indexer = [0, -2]
    with pytest.raises(ValueError):
        data.take(bad_indexer, fill_value=na_value)

@pytest.mark.xfail(reason="Series.take with extension array buggy for -1")
def test_take_series(self, data):
s = pd.Series(data)
Expand Down
16 changes: 15 additions & 1 deletion pandas/tests/extension/category/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,32 @@ def test_merge(self, data, na_value):


class TestGetitem(base.BaseGetitemTests):
skip_take = pytest.mark.skip(reason="GH-20664.")

@pytest.mark.skip(reason="Backwards compatibility")
def test_getitem_scalar(self):
# CategoricalDtype.type isn't "correct" since it should
# be a parent of the elements (object). But don't want
# to break things by changing.
pass

@pytest.mark.xfail(reason="Categorical.take buggy")
@skip_take
def test_take(self):
# TODO remove this once Categorical.take is fixed
pass

@skip_take
def test_take_negative(self):
pass

@skip_take
def test_take_pandas_style_negative_raises(self):
pass

@skip_take
def test_take_non_na_fill_value(self):
pass

@pytest.mark.xfail(reason="Categorical.take buggy")
def test_take_empty(self):
pass
Expand Down
17 changes: 3 additions & 14 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import pandas as pd
from pandas.core.arrays import ExtensionArray
from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.common import _ensure_platform_int


class DecimalDtype(ExtensionDtype):
type = decimal.Decimal
name = 'decimal'
na_value = decimal.Decimal('NaN')

@classmethod
def construct_from_string(cls, string):
Expand Down Expand Up @@ -80,19 +80,8 @@ def nbytes(self):
def isna(self):
    """Return a boolean mask marking which stored decimals are NaN.

    Returns
    -------
    numpy.ndarray
        Boolean array with one entry per element of ``self._data``.
    """
    # An explicit dtype keeps the mask boolean for an empty array, where
    # NumPy would otherwise infer float64 from the empty list.
    return np.array([x.is_nan() for x in self._data], dtype=bool)

def take(self, indexer, allow_fill=True, fill_value=None):
indexer = np.asarray(indexer)
mask = indexer == -1

# take on empty array not handled as desired by numpy in case of -1
if not len(self) and mask.all():
return type(self)([self._na_value] * len(indexer))

indexer = _ensure_platform_int(indexer)
out = self._data.take(indexer)
out[mask] = self._na_value

return type(self)(out)
def _values_for_take(self):
    """Return the backing array of decimals for use by ``take``.

    Reads ``self._data``, the same attribute the other methods of this
    class use.
    """
    return self._data

@property
def _na_value(self):
Expand Down

0 comments on commit fb3c234

Please sign in to comment.