Skip to content

Commit

Permalink
PERF: improve efficiency of BaseMaskedArray.__setitem__
Browse files Browse the repository at this point in the history
This partially addresses #44172, though that issue won't be fully resolved until 2D `ExtensionArray`s are supported (per the comments there).
  • Loading branch information
alexreg committed Oct 28, 2021
1 parent 7c00e0c commit d468b36
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 15 deletions.
6 changes: 3 additions & 3 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def frame_apply(
args=None,
kwargs=None,
) -> FrameApply:
"""construct and return a row or column based frame apply object"""
"""Construct and return a row- or column-based frame apply object."""
axis = obj._get_axis_number(axis)
klass: type[FrameApply]
if axis == 0:
Expand Down Expand Up @@ -693,7 +693,7 @@ def dtypes(self) -> Series:
return self.obj.dtypes

def apply(self) -> DataFrame | Series:
"""compute the results"""
"""Compute the results."""
# dispatch to agg
if is_list_like(self.f):
return self.apply_multiple()
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def result_columns(self) -> Index:
def wrap_results_for_axis(
self, results: ResType, res_index: Index
) -> DataFrame | Series:
"""return the results for the columns"""
"""Return the results for the columns."""
result: DataFrame | Series

# we have requested to expand
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
BaseMaskedArray,
BaseMaskedDtype,
)
from pandas.core.indexers import check_array_indexer

if TYPE_CHECKING:
import pyarrow
Expand Down Expand Up @@ -367,6 +368,9 @@ def map_string(s):
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
return coerce_to_array(value)

def _validate_setitem_value(self, value):
return lib.is_bool(value)

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/floating.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
NumericArray,
NumericDtype,
)
from pandas.core.indexers import check_array_indexer
from pandas.core.ops import invalid_comparison
from pandas.core.tools.numeric import to_numeric

Expand Down Expand Up @@ -278,6 +279,9 @@ def _from_sequence_of_strings(
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
return coerce_to_array(value, dtype=self.dtype)

def _validate_setitem_value(self, value):
return lib.is_float(value)

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
NumericArray,
NumericDtype,
)
from pandas.core.indexers import check_array_indexer
from pandas.core.ops import invalid_comparison
from pandas.core.tools.numeric import to_numeric

Expand Down Expand Up @@ -345,6 +346,9 @@ def _from_sequence_of_strings(
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
return coerce_to_array(value, dtype=self.dtype)

def _validate_setitem_value(self, value):
return lib.is_integer(value)

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...
Expand Down
28 changes: 16 additions & 12 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@

class BaseMaskedDtype(ExtensionDtype):
"""
Base class for dtypes for BasedMaskedArray subclasses.
Base class for dtypes for BaseMaskedArray subclasses.
"""

name: str
Expand Down Expand Up @@ -208,19 +208,23 @@ def fillna(
def _coerce_to_array(self, values) -> tuple[np.ndarray, np.ndarray]:
raise AbstractMethodError(self)

def __setitem__(self, key, value) -> None:
_is_scalar = is_scalar(value)
if _is_scalar:
value = [value]
value, mask = self._coerce_to_array(value)

if _is_scalar:
value = value[0]
mask = mask[0]
def _validate_setitem_value(self, value) -> bool:
    """
    Hook reporting whether a scalar ``value`` can be stored as-is.

    This base implementation always raises ``AbstractMethodError``;
    subclasses are expected to override it and return a bool.
    """
    raise AbstractMethodError(self)

def __setitem__(self, key, value) -> None:
key = check_array_indexer(self, key)
self._data[key] = value
self._mask[key] = mask
if is_scalar(value):
if self._validate_setitem_value(value):
self._data[key] = value
self._mask[key] = False
elif value is libmissing.NA:
self._mask[key] = True
else:
raise TypeError(f"Invalid value '{value}'")
else:
value, mask = self._coerce_to_array(value)
self._data[key] = value
self._mask[key] = mask

def __iter__(self):
if self.ndim == 1:
Expand Down

0 comments on commit d468b36

Please sign in to comment.