Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: ExtensionArray support for objects with _can_hold_na=False and relational operators #20801

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,21 @@ class ExtensionArray(object):
* copy
* _concat_same_type

Some additional methods are available to satisfy pandas' internal, private
block API:
An additional method and attribute is available to satisfy pandas'
internal, private block API.

* _can_hold_na
* _formatting_values
* _can_hold_na

Some methods require casting the ExtensionArray to an ndarray of Python
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge snafu?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I reordered _formatting_values and _can_hold_na, since one is a method and the other is an attribute.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant this section specifically. This block is repeated starting on line 57. Or perhaps I'm missing something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really weird. The copy on my machine is clean and fine, but the copy on GitHub has the repeat. Maybe because I resolved conflicts with master using GitHub interface. UGH.

objects with ``self.astype(object)``, which may be expensive. When
performance is a concern, we highly recommend overriding the following
methods:

* fillna
* unique
* factorize / _values_for_factorize
* argsort / _values_for_argsort

Some methods require casting the ExtensionArray to an ndarray of Python
objects with ``self.astype(object)``, which may be expensive. When
Expand Down Expand Up @@ -393,7 +403,8 @@ def _values_for_factorize(self):
Returns
-------
values : ndarray
An array suitable for factoraization. This should maintain order

An array suitable for factorization. This should maintain order
and be a supported dtype (Float64, Int64, UInt64, String, Object).
By default, the extension array is cast to object dtype.
na_value : object
Expand All @@ -416,7 +427,7 @@ def factorize(self, na_sentinel=-1):
Returns
-------
labels : ndarray
An interger NumPy array that's an indexer into the original
An integer NumPy array that's an indexer into the original
ExtensionArray.
uniques : ExtensionArray
An ExtensionArray containing the unique values of `self`.
Expand Down Expand Up @@ -560,16 +571,13 @@ def _concat_same_type(cls, to_concat):
"""
raise AbstractMethodError(cls)

@property
def _can_hold_na(self):
# type: () -> bool
"""Whether your array can hold missing values. True by default.
_can_hold_na = True
"""Whether your array can hold missing values. True by default.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what our recommended style is here. You can probably just change it to a comment.


Notes
-----
Setting this to false will optimize some operations like fillna.
"""
return True
Notes
-----
Setting this to False will optimize some operations like fillna.
"""

@property
def _ndarray_values(self):
Expand Down
74 changes: 74 additions & 0 deletions pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# necessary to enforce truediv in Python 2.X
from __future__ import division
import operator
import inspect

import numpy as np
import pandas as pd
Expand All @@ -30,6 +31,7 @@
is_bool_dtype,
is_list_like,
is_scalar,
is_extension_array_dtype,
_ensure_object)
from pandas.core.dtypes.cast import (
maybe_upcast_putmask, find_common_type,
Expand Down Expand Up @@ -1208,6 +1210,78 @@ def wrapper(self, other, axis=None):
return self._constructor(res_values, index=self.index,
name=res_name)

elif is_extension_array_dtype(self):
values = self.values

# GH 20659
# The idea here is as follows. First we see if the op is
# defined in the ExtensionArray subclass, and returns a
# result that is not NotImplemented. If so, we use that
# result. If that fails, then we try to see if the underlying
# dtype has implemented the op, and if not, then we try to
# put the values into a numpy array to see if the optimized
# comparisons will work. Either way, then we try an
# element by element comparison

method = getattr(values, op_name, None)
# First see if the extension array object supports the op
res = NotImplemented
if inspect.ismethod(method):
try:
res = method(other)
except TypeError:
pass
except Exception as e:
raise e

isfunc = inspect.isfunction(getattr(self.dtype.type,
op_name, None))

if res is NotImplemented and not isfunc:
# Try the fast implementation, i.e., see if passing a
# numpy array will allow the optimized comparison to
# work
nvalues = self.get_values()
if isinstance(other, list):
other = np.asarray(other)

try:
with np.errstate(all='ignore'):
res = na_op(nvalues, other)
except TypeError:
pass
except Exception as e:
raise e

if res is NotImplemented:
# Try it on each element. Support comparing to another
# ExtensionArray, or something that is list like, or
# a single object. This allows a result of a comparison
# to be an object as opposed to a boolean
if is_extension_array_dtype(other):
ovalues = other.values
elif is_list_like(other):
ovalues = other
else: # Assume its an object
ovalues = [other] * len(self)

# Get the method for each object.
res = [getattr(a, op_name, None)(b)
for (a, b) in zip(values, ovalues)]

# We can't use (NotImplemented in res) because the
# results might be objects that have overridden __eq__
if any([isinstance(r, type(NotImplemented)) for r in res]):
msg = "invalid type comparison between {one} and {two}"
raise TypeError(msg.format(one=type(values),
two=type(other)))

# At this point we have the result
# always return a full value series here
res_values = com._values_from_object(res)
return self._constructor(res_values, index=self.index,
name=res_name)

elif isinstance(other, ABCSeries):
# By this point we have checked that self._indexed_same(other)
res_values = na_op(self.values, other.values)
Expand Down
10 changes: 6 additions & 4 deletions pandas/tests/extension/base/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def test_getitem_scalar(self, data):
assert isinstance(result, data.dtype.type)

def test_getitem_scalar_na(self, data_missing, na_cmp, na_value):
result = data_missing[0]
assert na_cmp(result, na_value)
if data_missing._can_hold_na:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these changes are a bit unfortunate...

Could we instead have the data_missing fixture raise pytest.skip when necessary? I suspect this would have to be done by the EA author in their code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could make that change to avoid any tests where data_missing is used, and the EA author would have to make that change if they wanted _can_hold_na=False.

result = data_missing[0]
assert na_cmp(result, na_value)

def test_getitem_mask(self, data):
# Empty mask, raw array
Expand Down Expand Up @@ -134,8 +135,9 @@ def test_take(self, data, na_value, na_cmp):

def test_take_empty(self, data, na_value, na_cmp):
empty = data[:0]
result = empty.take([-1])
na_cmp(result[0], na_value)
if data._can_hold_na:
result = empty.take([-1])
na_cmp(result[0], na_value)

with tm.assert_raises_regex(IndexError, "cannot do a non-empty take"):
empty.take([0, 1])
Expand Down
21 changes: 17 additions & 4 deletions pandas/tests/extension/base/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping):
_, index = pd.factorize(data_for_grouping, sort=True)
# TODO(ExtensionIndex): remove astype
index = pd.Index(index.astype(object), name="B")
expected = pd.Series([3, 1, 4], index=index, name="A")
if data_for_grouping._can_hold_na:
expected = pd.Series([3, 1, 4], index=index, name="A")
else:
expected = pd.Series([2, 3, 1, 4], index=index, name="A")
if as_index:
self.assert_series_equal(result, expected)
else:
Expand All @@ -41,16 +44,26 @@ def test_groupby_extension_no_sort(self, data_for_grouping):
_, index = pd.factorize(data_for_grouping, sort=False)
# TODO(ExtensionIndex): remove astype
index = pd.Index(index.astype(object), name="B")
expected = pd.Series([1, 3, 4], index=index, name="A")
if data_for_grouping._can_hold_na:
expected = pd.Series([1, 3, 4], index=index, name="A")
else:
expected = pd.Series([1, 2, 3, 4], index=index, name="A")

self.assert_series_equal(result, expected)

def test_groupby_extension_transform(self, data_for_grouping):
valid = data_for_grouping[~data_for_grouping.isna()]
df = pd.DataFrame({"A": [1, 1, 3, 3, 1, 4],
if data_for_grouping._can_hold_na:
dfval = [1, 1, 3, 3, 1, 4]
exres = [3, 3, 2, 2, 3, 1]
else:
dfval = [1, 1, 2, 2, 3, 3, 1, 4]
exres = [3, 3, 2, 2, 2, 2, 3, 1]
df = pd.DataFrame({"A": dfval,
"B": valid})

result = df.groupby("B").A.transform(len)
expected = pd.Series([3, 3, 2, 2, 3, 1], name="A")
expected = pd.Series(exres, name="A")

self.assert_series_equal(result, expected)

Expand Down
19 changes: 13 additions & 6 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ def test_value_counts(self, all_data, dropna):
self.assert_series_equal(result, expected)

def test_count(self, data_missing):
df = pd.DataFrame({"A": data_missing})
result = df.count(axis='columns')
expected = pd.Series([0, 1])
self.assert_series_equal(result, expected)
if data_missing._can_hold_na:
df = pd.DataFrame({"A": data_missing})
result = df.count(axis='columns')
expected = pd.Series([0, 1])
self.assert_series_equal(result, expected)

def test_apply_simple_series(self, data):
result = pd.Series(data).apply(id)
Expand All @@ -40,7 +41,10 @@ def test_argsort(self, data_for_sorting):

def test_argsort_missing(self, data_missing_for_sorting):
result = pd.Series(data_missing_for_sorting).argsort()
expected = pd.Series(np.array([1, -1, 0], dtype=np.int64))
if data_missing_for_sorting._can_hold_na:
expected = pd.Series(np.array([1, -1, 0], dtype=np.int64))
else:
expected = pd.Series(np.array([1, 2, 0], dtype=np.int64))
self.assert_series_equal(result, expected)

@pytest.mark.parametrize('ascending', [True, False])
Expand All @@ -58,7 +62,10 @@ def test_sort_values_missing(self, data_missing_for_sorting, ascending):
ser = pd.Series(data_missing_for_sorting)
result = ser.sort_values(ascending=ascending)
if ascending:
expected = ser.iloc[[2, 0, 1]]
if data_missing_for_sorting._can_hold_na:
expected = ser.iloc[[2, 0, 1]]
else:
expected = ser.iloc[[1, 2, 0]]
else:
expected = ser.iloc[[0, 2, 1]]
self.assert_series_equal(result, expected)
Expand Down
52 changes: 41 additions & 11 deletions pandas/tests/extension/base/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,39 @@ def test_isna(self, data_missing):
def test_dropna_series(self, data_missing):
ser = pd.Series(data_missing)
result = ser.dropna()
expected = ser.iloc[[1]]
if data_missing._can_hold_na:
expected = ser.iloc[[1]]
else:
expected = ser
self.assert_series_equal(result, expected)

def test_dropna_frame(self, data_missing):
df = pd.DataFrame({"A": data_missing})

# defaults
result = df.dropna()
expected = df.iloc[[1]]
if data_missing._can_hold_na:
expected = df.iloc[[1]]
else:
expected = df
self.assert_frame_equal(result, expected)

# axis = 1
result = df.dropna(axis='columns')
expected = pd.DataFrame(index=[0, 1])
if data_missing._can_hold_na:
expected = pd.DataFrame(index=[0, 1])
else:
expected = df
self.assert_frame_equal(result, expected)

# multiple
df = pd.DataFrame({"A": data_missing,
"B": [1, np.nan]})
result = df.dropna()
expected = df.iloc[:0]
if data_missing._can_hold_na:
expected = df.iloc[:0]
else:
expected = df.iloc[:1]
self.assert_frame_equal(result, expected)

def test_fillna_scalar(self, data_missing):
Expand All @@ -56,22 +68,32 @@ def test_fillna_scalar(self, data_missing):
def test_fillna_limit_pad(self, data_missing):
arr = data_missing.take([1, 0, 0, 0, 1])
result = pd.Series(arr).fillna(method='ffill', limit=2)
expected = pd.Series(data_missing.take([1, 1, 1, 0, 1]))
if data_missing._can_hold_na:
expected = pd.Series(data_missing.take([1, 1, 1, 0, 1]))
else:
expected = pd.Series(arr)
self.assert_series_equal(result, expected)

def test_fillna_limit_backfill(self, data_missing):
arr = data_missing.take([1, 0, 0, 0, 1])
result = pd.Series(arr).fillna(method='backfill', limit=2)
expected = pd.Series(data_missing.take([1, 0, 1, 1, 1]))
if data_missing._can_hold_na:
expected = pd.Series(data_missing.take([1, 0, 1, 1, 1]))
else:
expected = pd.Series(arr)
self.assert_series_equal(result, expected)

def test_fillna_series(self, data_missing):

fill_value = data_missing[1]
ser = pd.Series(data_missing)

result = ser.fillna(fill_value)
expected = pd.Series(
data_missing._from_sequence([fill_value, fill_value]))
if data_missing._can_hold_na:
expected = pd.Series(
data_missing._from_sequence([fill_value, fill_value]))
else:
expected = ser
self.assert_series_equal(result, expected)

# Fill with a series
Expand All @@ -90,8 +112,11 @@ def test_fillna_series_method(self, data_missing, method):
data_missing = type(data_missing)(data_missing[::-1])

result = pd.Series(data_missing).fillna(method=method)
expected = pd.Series(
data_missing._from_sequence([fill_value, fill_value]))
if data_missing._can_hold_na:
expected = pd.Series(
data_missing._from_sequence([fill_value, fill_value]))
else:
expected = pd.Series(data_missing)

self.assert_series_equal(result, expected)

Expand All @@ -103,8 +128,13 @@ def test_fillna_frame(self, data_missing):
"B": [1, 2]
}).fillna(fill_value)

if data_missing._can_hold_na:
a = data_missing._from_sequence([fill_value, fill_value])
else:
a = data_missing

expected = pd.DataFrame({
"A": data_missing._from_sequence([fill_value, fill_value]),
"A": a,
"B": [1, 2],
})

Expand Down
Loading