diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 6d1196b783f74..1236c672a1fa1 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -282,6 +282,7 @@ Performance improvements - Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:`36240`) - ``Styler`` uuid method altered to compress data transmission over web whilst maintaining reasonably low table collision probability (:issue:`36345`) - Performance improvement in :meth:`pd.to_datetime` with non-ns time unit for ``float`` ``dtype`` columns (:issue:`20445`) +- Performance improvement in setting values on a :class:`IntervalArray` (:issue:`36310`) .. --------------------------------------------------------------------------- diff --git a/pandas/_testing.py b/pandas/_testing.py index 78b6b3c4f9072..cf6272edc4c05 100644 --- a/pandas/_testing.py +++ b/pandas/_testing.py @@ -977,8 +977,14 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray") """ _check_isinstance(left, right, IntervalArray) - assert_index_equal(left.left, right.left, exact=exact, obj=f"{obj}.left") - assert_index_equal(left.right, right.right, exact=exact, obj=f"{obj}.left") + kwargs = {} + if left._left.dtype.kind in ["m", "M"]: + # We have a DatetimeArray or TimedeltaArray + kwargs["check_freq"] = False + + assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs) + assert_equal(left._right, right._right, obj=f"{obj}.left", **kwargs) + assert_attr_equal("closed", left, right, obj=obj) @@ -989,20 +995,22 @@ def assert_period_array_equal(left, right, obj="PeriodArray"): assert_attr_equal("freq", left, right, obj=obj) -def assert_datetime_array_equal(left, right, obj="DatetimeArray"): +def assert_datetime_array_equal(left, right, obj="DatetimeArray", check_freq=True): __tracebackhide__ = True _check_isinstance(left, right, DatetimeArray) assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data") - assert_attr_equal("freq", left, right, obj=obj) + if check_freq: + assert_attr_equal("freq", left, right, obj=obj) assert_attr_equal("tz", left, right, obj=obj) -def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"): +def assert_timedelta_array_equal(left, right, obj="TimedeltaArray", check_freq=True): __tracebackhide__ = True _check_isinstance(left, right, TimedeltaArray) assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data") - assert_attr_equal("freq", left, right, obj=obj) + if check_freq: + assert_attr_equal("freq", left, right, obj=obj) def raise_assert_detail(obj, message, left, right, diff=None, index_values=None): diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 5105b5b9cc57b..413430942575d 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -31,7 +31,6 @@ from pandas.core.dtypes.dtypes import IntervalDtype from pandas.core.dtypes.generic import ( ABCDatetimeIndex, - ABCIndexClass, ABCIntervalIndex, ABCPeriodIndex, ABCSeries, @@ -42,7 +41,7 @@ from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs from pandas.core.arrays.categorical import Categorical import pandas.core.common as com -from pandas.core.construction import array +from pandas.core.construction import array, extract_array from pandas.core.indexers import check_array_indexer from pandas.core.indexes.base import ensure_index @@ -161,12 +160,14 @@ def __new__( verify_integrity: bool = True, ): - if isinstance(data, ABCSeries) and is_interval_dtype(data.dtype): - data = data._values + if isinstance(data, (ABCSeries, ABCIntervalIndex)) and is_interval_dtype( + data.dtype + ): + data = data._values # TODO: extract_array? - if isinstance(data, (cls, ABCIntervalIndex)): - left = data.left - right = data.right + if isinstance(data, cls): + left = data._left + right = data._right closed = closed or data.closed else: @@ -243,6 +244,20 @@ def _simple_new( ) raise ValueError(msg) + # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray + from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array + + left = maybe_upcast_datetimelike_array(left) + left = extract_array(left, extract_numpy=True) + right = maybe_upcast_datetimelike_array(right) + right = extract_array(right, extract_numpy=True) + + lbase = getattr(left, "_ndarray", left).base + rbase = getattr(right, "_ndarray", right).base + if lbase is not None and lbase is rbase: + # If these share data, then setitem could corrupt our IA + right = right.copy() + result._left = left result._right = right result._closed = closed @@ -476,18 +491,18 @@ def _validate(self): if self.closed not in VALID_CLOSED: msg = f"invalid option for 'closed': {self.closed}" raise ValueError(msg) - if len(self.left) != len(self.right): + if len(self._left) != len(self._right): msg = "left and right must have the same length" raise ValueError(msg) - left_mask = notna(self.left) - right_mask = notna(self.right) + left_mask = notna(self._left) + right_mask = notna(self._right) if not (left_mask == right_mask).all(): msg = ( "missing values must be missing in the same " "location both left and right sides" ) raise ValueError(msg) - if not (self.left[left_mask] <= self.right[left_mask]).all(): + if not (self._left[left_mask] <= self._right[left_mask]).all(): msg = "left side of interval must be <= right side" raise ValueError(msg) @@ -527,37 +542,29 @@ def __iter__(self): return iter(np.asarray(self)) def __len__(self) -> int: - return len(self.left) + return len(self._left) def __getitem__(self, value): value = check_array_indexer(self, value) - left = self.left[value] - right = self.right[value] + left = self._left[value] + right = self._right[value] - # scalar - if not isinstance(left, ABCIndexClass): + if not isinstance(left, (np.ndarray, ExtensionArray)): + # scalar if is_scalar(left) and isna(left): return self._fill_value - if np.ndim(left) > 1: - # GH#30588 multi-dimensional indexer disallowed - raise ValueError("multi-dimensional indexing not allowed") return Interval(left, right, self.closed) - + if np.ndim(left) > 1: + # GH#30588 multi-dimensional indexer disallowed + raise ValueError("multi-dimensional indexing not allowed") return self._shallow_copy(left, right) def __setitem__(self, key, value): value_left, value_right = self._validate_setitem_value(value) key = check_array_indexer(self, key) - # Need to ensure that left and right are updated atomically, so we're - # forced to copy, update the copy, and swap in the new values. - left = self.left.copy(deep=True) - left._values[key] = value_left - self._left = left - - right = self.right.copy(deep=True) - right._values[key] = value_right - self._right = right + self._left[key] = value_left + self._right[key] = value_right def __eq__(self, other): # ensure pandas array for list-like and eliminate non-interval scalars @@ -588,7 +595,7 @@ def __eq__(self, other): if is_interval_dtype(other_dtype): if self.closed != other.closed: return np.zeros(len(self), dtype=bool) - return (self.left == other.left) & (self.right == other.right) + return (self._left == other.left) & (self._right == other.right) # non-interval/non-object dtype -> no matches if not is_object_dtype(other_dtype): @@ -601,8 +608,8 @@ def __eq__(self, other): if ( isinstance(obj, Interval) and self.closed == obj.closed - and self.left[i] == obj.left - and self.right[i] == obj.right + and self._left[i] == obj.left + and self._right[i] == obj.right ): result[i] = True @@ -665,6 +672,7 @@ def astype(self, dtype, copy=True): array : ExtensionArray or ndarray ExtensionArray or NumPy ndarray with 'dtype' for its dtype. """ + from pandas import Index from pandas.core.arrays.string_ import StringDtype if dtype is not None: @@ -676,8 +684,10 @@ def astype(self, dtype, copy=True): # need to cast to different subtype try: - new_left = self.left.astype(dtype.subtype) - new_right = self.right.astype(dtype.subtype) + # We need to use Index rules for astype to prevent casting + # np.nan entries to int subtypes + new_left = Index(self._left, copy=False).astype(dtype.subtype) + new_right = Index(self._right, copy=False).astype(dtype.subtype) except TypeError as err: msg = ( f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible" @@ -726,14 +736,14 @@ def copy(self): ------- IntervalArray """ - left = self.left.copy(deep=True) - right = self.right.copy(deep=True) + left = self._left.copy() + right = self._right.copy() closed = self.closed # TODO: Could skip verify_integrity here. return type(self).from_arrays(left, right, closed=closed) - def isna(self): - return isna(self.left) + def isna(self) -> np.ndarray: + return isna(self._left) def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray": if not len(self) or periods == 0: @@ -749,7 +759,9 @@ def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray": empty_len = min(abs(periods), len(self)) if isna(fill_value): - fill_value = self.left._na_value + from pandas import Index + + fill_value = Index(self._left, copy=False)._na_value empty = IntervalArray.from_breaks([fill_value] * (empty_len + 1)) else: empty = self._from_sequence([fill_value] * empty_len) @@ -815,10 +827,10 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs): fill_left, fill_right = self._validate_fill_value(fill_value) left_take = take( - self.left, indices, allow_fill=allow_fill, fill_value=fill_left + self._left, indices, allow_fill=allow_fill, fill_value=fill_left ) right_take = take( - self.right, indices, allow_fill=allow_fill, fill_value=fill_right + self._right, indices, allow_fill=allow_fill, fill_value=fill_right ) return self._shallow_copy(left_take, right_take) @@ -977,7 +989,9 @@ def left(self): Return the left endpoints of each Interval in the IntervalArray as an Index. """ - return self._left + from pandas import Index + + return Index(self._left, copy=False) @property def right(self): @@ -985,7 +999,9 @@ def right(self): Return the right endpoints of each Interval in the IntervalArray as an Index. """ - return self._right + from pandas import Index + + return Index(self._right, copy=False) @property def length(self): @@ -1146,7 +1162,7 @@ def set_closed(self, closed): raise ValueError(msg) return type(self)._simple_new( - left=self.left, right=self.right, closed=closed, verify_integrity=False + left=self._left, right=self._right, closed=closed, verify_integrity=False ) _interval_shared_docs[ @@ -1172,15 +1188,15 @@ def is_non_overlapping_monotonic(self): # at a point when both sides of intervals are included if self.closed == "both": return bool( - (self.right[:-1] < self.left[1:]).all() - or (self.left[:-1] > self.right[1:]).all() + (self._right[:-1] < self._left[1:]).all() + or (self._left[:-1] > self._right[1:]).all() ) # non-strict inequality when closed != 'both'; at least one side is # not included in the intervals, so equality does not imply overlapping return bool( - (self.right[:-1] <= self.left[1:]).all() - or (self.left[:-1] >= self.right[1:]).all() + (self._right[:-1] <= self._left[1:]).all() + or (self._left[:-1] >= self._right[1:]).all() ) # --------------------------------------------------------------------- @@ -1191,8 +1207,8 @@ def __array__(self, dtype=None) -> np.ndarray: Return the IntervalArray's data as a numpy array of Interval objects (with dtype='object') """ - left = self.left - right = self.right + left = self._left + right = self._right mask = self.isna() closed = self._closed @@ -1222,8 +1238,8 @@ def __arrow_array__(self, type=None): interval_type = ArrowIntervalType(subtype, self.closed) storage_array = pyarrow.StructArray.from_arrays( [ - pyarrow.array(self.left, type=subtype, from_pandas=True), - pyarrow.array(self.right, type=subtype, from_pandas=True), + pyarrow.array(self._left, type=subtype, from_pandas=True), + pyarrow.array(self._right, type=subtype, from_pandas=True), ], names=["left", "right"], ) @@ -1277,7 +1293,7 @@ def __arrow_array__(self, type=None): _interval_shared_docs["to_tuples"] % dict(return_type="ndarray", examples="") ) def to_tuples(self, na_tuple=True): - tuples = com.asarray_tuplesafe(zip(self.left, self.right)) + tuples = com.asarray_tuplesafe(zip(self._left, self._right)) if not na_tuple: # GH 18756 tuples = np.where(~self.isna(), tuples, np.nan) @@ -1343,8 +1359,8 @@ def contains(self, other): if isinstance(other, Interval): raise NotImplementedError("contains not implemented for two intervals") - return (self.left < other if self.open_left else self.left <= other) & ( - other < self.right if self.open_right else other <= self.right + return (self._left < other if self.open_left else self._left <= other) & ( + other < self._right if self.open_right else other <= self._right ) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 8855d987af745..a56f6a5bb0340 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -183,12 +183,8 @@ def func(intvidx_self, other, sort=False): ) ) @inherit_names(["set_closed", "to_tuples"], IntervalArray, wrap=True) -@inherit_names( - ["__array__", "overlaps", "contains", "left", "right", "length"], IntervalArray -) -@inherit_names( - ["is_non_overlapping_monotonic", "mid", "closed"], IntervalArray, cache=True -) +@inherit_names(["__array__", "overlaps", "contains"], IntervalArray) +@inherit_names(["is_non_overlapping_monotonic", "closed"], IntervalArray, cache=True) class IntervalIndex(IntervalMixin, ExtensionIndex): _typ = "intervalindex" _comparables = ["name"] @@ -201,6 +197,8 @@ class IntervalIndex(IntervalMixin, ExtensionIndex): _mask = None _data: IntervalArray + _values: IntervalArray + # -------------------------------------------------------------------- # Constructors @@ -409,7 +407,7 @@ def __reduce__(self): return _new_IntervalIndex, (type(self), d), None @Appender(Index.astype.__doc__) - def astype(self, dtype, copy=True): + def astype(self, dtype, copy: bool = True): with rewrite_exception("IntervalArray", type(self).__name__): new_values = self._values.astype(dtype, copy=copy) if is_interval_dtype(new_values.dtype): @@ -438,7 +436,7 @@ def is_monotonic_decreasing(self) -> bool: return self[::-1].is_monotonic_increasing @cache_readonly - def is_unique(self): + def is_unique(self) -> bool: """ Return True if the IntervalIndex contains unique elements, else False. """ @@ -865,6 +863,22 @@ def _convert_list_indexer(self, keyarr): # -------------------------------------------------------------------- + @cache_readonly + def left(self) -> Index: + return Index(self._data.left, copy=False) + + @cache_readonly + def right(self) -> Index: + return Index(self._data.right, copy=False) + + @cache_readonly + def mid(self): + return Index(self._data.mid, copy=False) + + @property + def length(self): + return Index(self._data.length, copy=False) + @Appender(Index.where.__doc__) def where(self, cond, other=None): if other is None: diff --git a/pandas/tests/extension/test_interval.py b/pandas/tests/extension/test_interval.py index 2411f6cfbd936..4fdcf930d224f 100644 --- a/pandas/tests/extension/test_interval.py +++ b/pandas/tests/extension/test_interval.py @@ -147,9 +147,7 @@ class TestReshaping(BaseInterval, base.BaseReshapingTests): class TestSetitem(BaseInterval, base.BaseSetitemTests): - @pytest.mark.xfail(reason="GH#27147 setitem changes underlying index") - def test_setitem_preserves_views(self, data): - super().test_setitem_preserves_views(data) + pass class TestPrinting(BaseInterval, base.BasePrintingTests): diff --git a/pandas/tests/indexes/interval/test_constructors.py b/pandas/tests/indexes/interval/test_constructors.py index fa881df8139c6..aec7de549744f 100644 --- a/pandas/tests/indexes/interval/test_constructors.py +++ b/pandas/tests/indexes/interval/test_constructors.py @@ -262,6 +262,12 @@ def test_length_one(self): expected = IntervalIndex.from_breaks([]) tm.assert_index_equal(result, expected) + def test_left_right_dont_share_data(self): + # GH#36310 + breaks = np.arange(5) + result = IntervalIndex.from_breaks(breaks)._data + assert result._left.base is None or result._left.base is not result._right.base + class TestFromTuples(Base): """Tests specific to IntervalIndex.from_tuples""" diff --git a/pandas/tests/series/indexing/test_getitem.py b/pandas/tests/series/indexing/test_getitem.py index 6b7cda89a4714..5b585e8802752 100644 --- a/pandas/tests/series/indexing/test_getitem.py +++ b/pandas/tests/series/indexing/test_getitem.py @@ -101,7 +101,7 @@ def test_getitem_intlist_intindex_periodvalues(self): @pytest.mark.parametrize("box", [list, np.array, pd.Index]) def test_getitem_intlist_intervalindex_non_int(self, box): # GH#33404 fall back to positional since ints are unambiguous - dti = date_range("2000-01-03", periods=3) + dti = date_range("2000-01-03", periods=3)._with_freq(None) ii = pd.IntervalIndex.from_breaks(dti) ser = Series(range(len(ii)), index=ii) diff --git a/pandas/tests/util/test_assert_interval_array_equal.py b/pandas/tests/util/test_assert_interval_array_equal.py index 96f2973a1528c..2e8699536c72a 100644 --- a/pandas/tests/util/test_assert_interval_array_equal.py +++ b/pandas/tests/util/test_assert_interval_array_equal.py @@ -41,9 +41,9 @@ def test_interval_array_equal_periods_mismatch(): msg = """\ IntervalArray.left are different -IntervalArray.left length are different -\\[left\\]: 5, Int64Index\\(\\[0, 1, 2, 3, 4\\], dtype='int64'\\) -\\[right\\]: 6, Int64Index\\(\\[0, 1, 2, 3, 4, 5\\], dtype='int64'\\)""" +IntervalArray.left shapes are different +\\[left\\]: \\(5,\\) +\\[right\\]: \\(6,\\)""" with pytest.raises(AssertionError, match=msg): tm.assert_interval_array_equal(arr1, arr2) @@ -58,8 +58,8 @@ def test_interval_array_equal_end_mismatch(): IntervalArray.left are different IntervalArray.left values are different \\(80.0 %\\) -\\[left\\]: Int64Index\\(\\[0, 2, 4, 6, 8\\], dtype='int64'\\) -\\[right\\]: Int64Index\\(\\[0, 4, 8, 12, 16\\], dtype='int64'\\)""" +\\[left\\]: \\[0, 2, 4, 6, 8\\] +\\[right\\]: \\[0, 4, 8, 12, 16\\]""" with pytest.raises(AssertionError, match=msg): tm.assert_interval_array_equal(arr1, arr2) @@ -74,8 +74,8 @@ def test_interval_array_equal_start_mismatch(): IntervalArray.left are different IntervalArray.left values are different \\(100.0 %\\) -\\[left\\]: Int64Index\\(\\[0, 1, 2, 3\\], dtype='int64'\\) -\\[right\\]: Int64Index\\(\\[1, 2, 3, 4\\], dtype='int64'\\)""" +\\[left\\]: \\[0, 1, 2, 3\\] +\\[right\\]: \\[1, 2, 3, 4\\]""" with pytest.raises(AssertionError, match=msg): tm.assert_interval_array_equal(arr1, arr2)