Skip to content

Commit

Permalink
REF: Back IntervalArray by array instead of Index (#36310)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Oct 2, 2020
1 parent 456fcb9 commit 089fad9
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 80 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ Performance improvements
- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:`36240`)
- ``Styler`` uuid method altered to compress data transmission over web whilst maintaining reasonably low table collision probability (:issue:`36345`)
- Performance improvement in :meth:`pd.to_datetime` with non-ns time unit for ``float`` ``dtype`` columns (:issue:`20445`)
- Performance improvement in setting values on a :class:`IntervalArray` (:issue:`36310`)

.. ---------------------------------------------------------------------------
Expand Down
20 changes: 14 additions & 6 deletions pandas/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,8 +977,14 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
"""
_check_isinstance(left, right, IntervalArray)

assert_index_equal(left.left, right.left, exact=exact, obj=f"{obj}.left")
assert_index_equal(left.right, right.right, exact=exact, obj=f"{obj}.left")
kwargs = {}
if left._left.dtype.kind in ["m", "M"]:
# We have a DatetimeArray or TimedeltaArray
kwargs["check_freq"] = False

assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs)
assert_equal(left._right, right._right, obj=f"{obj}.left", **kwargs)

assert_attr_equal("closed", left, right, obj=obj)


Expand All @@ -989,20 +995,22 @@ def assert_period_array_equal(left, right, obj="PeriodArray"):
assert_attr_equal("freq", left, right, obj=obj)


def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
def assert_datetime_array_equal(left, right, obj="DatetimeArray", check_freq=True):
__tracebackhide__ = True
_check_isinstance(left, right, DatetimeArray)

assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
assert_attr_equal("freq", left, right, obj=obj)
if check_freq:
assert_attr_equal("freq", left, right, obj=obj)
assert_attr_equal("tz", left, right, obj=obj)


def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray", check_freq=True):
__tracebackhide__ = True
_check_isinstance(left, right, TimedeltaArray)
assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
assert_attr_equal("freq", left, right, obj=obj)
if check_freq:
assert_attr_equal("freq", left, right, obj=obj)


def raise_assert_detail(obj, message, left, right, diff=None, index_values=None):
Expand Down
126 changes: 71 additions & 55 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.generic import (
ABCDatetimeIndex,
ABCIndexClass,
ABCIntervalIndex,
ABCPeriodIndex,
ABCSeries,
Expand All @@ -42,7 +41,7 @@
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
from pandas.core.arrays.categorical import Categorical
import pandas.core.common as com
from pandas.core.construction import array
from pandas.core.construction import array, extract_array
from pandas.core.indexers import check_array_indexer
from pandas.core.indexes.base import ensure_index

Expand Down Expand Up @@ -161,12 +160,14 @@ def __new__(
verify_integrity: bool = True,
):

if isinstance(data, ABCSeries) and is_interval_dtype(data.dtype):
data = data._values
if isinstance(data, (ABCSeries, ABCIntervalIndex)) and is_interval_dtype(
data.dtype
):
data = data._values # TODO: extract_array?

if isinstance(data, (cls, ABCIntervalIndex)):
left = data.left
right = data.right
if isinstance(data, cls):
left = data._left
right = data._right
closed = closed or data.closed
else:

Expand Down Expand Up @@ -243,6 +244,20 @@ def _simple_new(
)
raise ValueError(msg)

# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array

left = maybe_upcast_datetimelike_array(left)
left = extract_array(left, extract_numpy=True)
right = maybe_upcast_datetimelike_array(right)
right = extract_array(right, extract_numpy=True)

lbase = getattr(left, "_ndarray", left).base
rbase = getattr(right, "_ndarray", right).base
if lbase is not None and lbase is rbase:
# If these share data, then setitem could corrupt our IA
right = right.copy()

result._left = left
result._right = right
result._closed = closed
Expand Down Expand Up @@ -476,18 +491,18 @@ def _validate(self):
if self.closed not in VALID_CLOSED:
msg = f"invalid option for 'closed': {self.closed}"
raise ValueError(msg)
if len(self.left) != len(self.right):
if len(self._left) != len(self._right):
msg = "left and right must have the same length"
raise ValueError(msg)
left_mask = notna(self.left)
right_mask = notna(self.right)
left_mask = notna(self._left)
right_mask = notna(self._right)
if not (left_mask == right_mask).all():
msg = (
"missing values must be missing in the same "
"location both left and right sides"
)
raise ValueError(msg)
if not (self.left[left_mask] <= self.right[left_mask]).all():
if not (self._left[left_mask] <= self._right[left_mask]).all():
msg = "left side of interval must be <= right side"
raise ValueError(msg)

Expand Down Expand Up @@ -527,37 +542,29 @@ def __iter__(self):
return iter(np.asarray(self))

def __len__(self) -> int:
return len(self.left)
return len(self._left)

def __getitem__(self, value):
value = check_array_indexer(self, value)
left = self.left[value]
right = self.right[value]
left = self._left[value]
right = self._right[value]

# scalar
if not isinstance(left, ABCIndexClass):
if not isinstance(left, (np.ndarray, ExtensionArray)):
# scalar
if is_scalar(left) and isna(left):
return self._fill_value
if np.ndim(left) > 1:
# GH#30588 multi-dimensional indexer disallowed
raise ValueError("multi-dimensional indexing not allowed")
return Interval(left, right, self.closed)

if np.ndim(left) > 1:
# GH#30588 multi-dimensional indexer disallowed
raise ValueError("multi-dimensional indexing not allowed")
return self._shallow_copy(left, right)

def __setitem__(self, key, value):
value_left, value_right = self._validate_setitem_value(value)
key = check_array_indexer(self, key)

# Need to ensure that left and right are updated atomically, so we're
# forced to copy, update the copy, and swap in the new values.
left = self.left.copy(deep=True)
left._values[key] = value_left
self._left = left

right = self.right.copy(deep=True)
right._values[key] = value_right
self._right = right
self._left[key] = value_left
self._right[key] = value_right

def __eq__(self, other):
# ensure pandas array for list-like and eliminate non-interval scalars
Expand Down Expand Up @@ -588,7 +595,7 @@ def __eq__(self, other):
if is_interval_dtype(other_dtype):
if self.closed != other.closed:
return np.zeros(len(self), dtype=bool)
return (self.left == other.left) & (self.right == other.right)
return (self._left == other.left) & (self._right == other.right)

# non-interval/non-object dtype -> no matches
if not is_object_dtype(other_dtype):
Expand All @@ -601,8 +608,8 @@ def __eq__(self, other):
if (
isinstance(obj, Interval)
and self.closed == obj.closed
and self.left[i] == obj.left
and self.right[i] == obj.right
and self._left[i] == obj.left
and self._right[i] == obj.right
):
result[i] = True

Expand Down Expand Up @@ -665,6 +672,7 @@ def astype(self, dtype, copy=True):
array : ExtensionArray or ndarray
ExtensionArray or NumPy ndarray with 'dtype' for its dtype.
"""
from pandas import Index
from pandas.core.arrays.string_ import StringDtype

if dtype is not None:
Expand All @@ -676,8 +684,10 @@ def astype(self, dtype, copy=True):

# need to cast to different subtype
try:
new_left = self.left.astype(dtype.subtype)
new_right = self.right.astype(dtype.subtype)
# We need to use Index rules for astype to prevent casting
# np.nan entries to int subtypes
new_left = Index(self._left, copy=False).astype(dtype.subtype)
new_right = Index(self._right, copy=False).astype(dtype.subtype)
except TypeError as err:
msg = (
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible"
Expand Down Expand Up @@ -726,14 +736,14 @@ def copy(self):
-------
IntervalArray
"""
left = self.left.copy(deep=True)
right = self.right.copy(deep=True)
left = self._left.copy()
right = self._right.copy()
closed = self.closed
# TODO: Could skip verify_integrity here.
return type(self).from_arrays(left, right, closed=closed)

def isna(self):
return isna(self.left)
def isna(self) -> np.ndarray:
return isna(self._left)

def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray":
if not len(self) or periods == 0:
Expand All @@ -749,7 +759,9 @@ def shift(self, periods: int = 1, fill_value: object = None) -> "IntervalArray":

empty_len = min(abs(periods), len(self))
if isna(fill_value):
fill_value = self.left._na_value
from pandas import Index

fill_value = Index(self._left, copy=False)._na_value
empty = IntervalArray.from_breaks([fill_value] * (empty_len + 1))
else:
empty = self._from_sequence([fill_value] * empty_len)
Expand Down Expand Up @@ -815,10 +827,10 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
fill_left, fill_right = self._validate_fill_value(fill_value)

left_take = take(
self.left, indices, allow_fill=allow_fill, fill_value=fill_left
self._left, indices, allow_fill=allow_fill, fill_value=fill_left
)
right_take = take(
self.right, indices, allow_fill=allow_fill, fill_value=fill_right
self._right, indices, allow_fill=allow_fill, fill_value=fill_right
)

return self._shallow_copy(left_take, right_take)
Expand Down Expand Up @@ -977,15 +989,19 @@ def left(self):
Return the left endpoints of each Interval in the IntervalArray as
an Index.
"""
return self._left
from pandas import Index

return Index(self._left, copy=False)

@property
def right(self):
"""
Return the right endpoints of each Interval in the IntervalArray as
an Index.
"""
return self._right
from pandas import Index

return Index(self._right, copy=False)

@property
def length(self):
Expand Down Expand Up @@ -1146,7 +1162,7 @@ def set_closed(self, closed):
raise ValueError(msg)

return type(self)._simple_new(
left=self.left, right=self.right, closed=closed, verify_integrity=False
left=self._left, right=self._right, closed=closed, verify_integrity=False
)

_interval_shared_docs[
Expand All @@ -1172,15 +1188,15 @@ def is_non_overlapping_monotonic(self):
# at a point when both sides of intervals are included
if self.closed == "both":
return bool(
(self.right[:-1] < self.left[1:]).all()
or (self.left[:-1] > self.right[1:]).all()
(self._right[:-1] < self._left[1:]).all()
or (self._left[:-1] > self._right[1:]).all()
)

# non-strict inequality when closed != 'both'; at least one side is
# not included in the intervals, so equality does not imply overlapping
return bool(
(self.right[:-1] <= self.left[1:]).all()
or (self.left[:-1] >= self.right[1:]).all()
(self._right[:-1] <= self._left[1:]).all()
or (self._left[:-1] >= self._right[1:]).all()
)

# ---------------------------------------------------------------------
Expand All @@ -1191,8 +1207,8 @@ def __array__(self, dtype=None) -> np.ndarray:
Return the IntervalArray's data as a numpy array of Interval
objects (with dtype='object')
"""
left = self.left
right = self.right
left = self._left
right = self._right
mask = self.isna()
closed = self._closed

Expand Down Expand Up @@ -1222,8 +1238,8 @@ def __arrow_array__(self, type=None):
interval_type = ArrowIntervalType(subtype, self.closed)
storage_array = pyarrow.StructArray.from_arrays(
[
pyarrow.array(self.left, type=subtype, from_pandas=True),
pyarrow.array(self.right, type=subtype, from_pandas=True),
pyarrow.array(self._left, type=subtype, from_pandas=True),
pyarrow.array(self._right, type=subtype, from_pandas=True),
],
names=["left", "right"],
)
Expand Down Expand Up @@ -1277,7 +1293,7 @@ def __arrow_array__(self, type=None):
_interval_shared_docs["to_tuples"] % dict(return_type="ndarray", examples="")
)
def to_tuples(self, na_tuple=True):
tuples = com.asarray_tuplesafe(zip(self.left, self.right))
tuples = com.asarray_tuplesafe(zip(self._left, self._right))
if not na_tuple:
# GH 18756
tuples = np.where(~self.isna(), tuples, np.nan)
Expand Down Expand Up @@ -1343,8 +1359,8 @@ def contains(self, other):
if isinstance(other, Interval):
raise NotImplementedError("contains not implemented for two intervals")

return (self.left < other if self.open_left else self.left <= other) & (
other < self.right if self.open_right else other <= self.right
return (self._left < other if self.open_left else self._left <= other) & (
other < self._right if self.open_right else other <= self._right
)


Expand Down
Loading

0 comments on commit 089fad9

Please sign in to comment.