Skip to content

Commit

Permalink
BUG: Array.__setitem__ failing with nullable boolean mask (#31484) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesdong1991 authored Feb 2, 2020
1 parent 2181824 commit 4082c46
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ Indexing

-
-
- Bug where assigning to a :class:`Series` using a IntegerArray / BooleanArray as a mask would raise ``TypeError`` (:issue:`31446`)

Missing
^^^^^^^
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pandas.core.dtypes.missing import isna, notna

from pandas.core import nanops, ops
from pandas.core.indexers import check_array_indexer

from .masked import BaseMaskedArray

Expand Down Expand Up @@ -369,6 +370,7 @@ def __setitem__(self, key, value):
value = value[0]
mask = mask[0]

key = check_array_indexer(self, key)
self._data[key] = value
self._mask[key] = mask

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2073,6 +2073,8 @@ def __setitem__(self, key, value):

lindexer = self.categories.get_indexer(rvalue)
lindexer = self._maybe_coerce_indexer(lindexer)

key = check_array_indexer(self, key)
self._codes[key] = lindexer

def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]:
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ def __setitem__(
f"or array of those. Got '{type(value).__name__}' instead."
)
raise TypeError(msg)

key = check_array_indexer(self, key)
self._data[key] = value
self._maybe_clear_freq()

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pandas.core.dtypes.missing import isna

from pandas.core import nanops, ops
from pandas.core.indexers import check_array_indexer
from pandas.core.ops import invalid_comparison
from pandas.core.ops.common import unpack_zerodim_and_defer
from pandas.core.tools.numeric import to_numeric
Expand Down Expand Up @@ -414,6 +415,7 @@ def __setitem__(self, key, value):
value = value[0]
mask = mask[0]

key = check_array_indexer(self, key)
self._data[key] = value
self._mask[key] = mask

Expand Down
1 change: 1 addition & 0 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def __setitem__(self, key, value):
msg = f"'value' should be an interval type, got {type(value)} instead."
raise TypeError(msg)

key = check_array_indexer(self, key)
# Need to ensure that left and right are updated atomically, so we're
# forced to copy, update the copy, and swap in the new values.
left = self.left.copy(deep=True)
Expand Down
1 change: 1 addition & 0 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __getitem__(self, item):
def __setitem__(self, key, value):
value = extract_array(value, extract_numpy=True)

key = check_array_indexer(self, key)
scalar_key = lib.is_scalar(key)
scalar_value = lib.is_scalar(value)

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pandas.core import ops
from pandas.core.arrays import PandasArray
from pandas.core.construction import extract_array
from pandas.core.indexers import check_array_indexer
from pandas.core.missing import isna


Expand Down Expand Up @@ -224,6 +225,7 @@ def __setitem__(self, key, value):
# extract_array doesn't extract PandasArray subclasses
value = value._ndarray

key = check_array_indexer(self, key)
scalar_key = lib.is_scalar(key)
scalar_value = lib.is_scalar(value)
if scalar_key and not scalar_value:
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/arrays/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,23 @@ def test_cut(bins, right, include_lowest):
tm.assert_categorical_equal(result, expected)


def test_array_setitem_nullable_boolean_mask():
# GH 31446
ser = pd.Series([1, 2], dtype="Int64")
result = ser.where(ser > 1)
expected = pd.Series([pd.NA, 2], dtype="Int64")
tm.assert_series_equal(result, expected)


def test_array_setitem():
# GH 31446
arr = pd.Series([1, 2], dtype="Int64").array
arr[arr > 1] = 1

expected = pd.array([1, 1], dtype="Int64")
tm.assert_extension_array_equal(arr, expected)


# TODO(jreback) - these need testing / are broken

# shift
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/extension/base/setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import pandas as pd
from pandas.core.arrays.numpy_ import PandasDtype

from .base import BaseExtensionTests

Expand Down Expand Up @@ -195,3 +196,14 @@ def test_setitem_preserves_views(self, data):
data[0] = data[1]
assert view1[0] == data[1]
assert view2[0] == data[1]

def test_setitem_nullable_mask(self, data):
# GH 31446
# TODO: there is some issue with PandasArray, therefore,
# TODO: skip the setitem test for now, and fix it later
if data.dtype != PandasDtype("object"):
arr = data[:5]
expected = data.take([0, 0, 0, 3, 4])
mask = pd.array([True, True, True, False, False])
arr[mask] = data[0]
self.assert_extension_array_equal(expected, arr)
3 changes: 3 additions & 0 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
from pandas.api.extensions import no_default, register_extension_dtype
from pandas.core.arrays import ExtensionArray, ExtensionScalarOpsMixin
from pandas.core.indexers import check_array_indexer


@register_extension_dtype
Expand Down Expand Up @@ -144,6 +145,8 @@ def __setitem__(self, key, value):
value = [decimal.Decimal(v) for v in value]
else:
value = decimal.Decimal(value)

key = check_array_indexer(self, key)
self._data[key] = value

def __len__(self) -> int:
Expand Down

0 comments on commit 4082c46

Please sign in to comment.