Skip to content

Commit

Permalink
Fix indexing with datetime64[ns] with pandas=1.1
Browse files Browse the repository at this point in the history
Fixes pydata#4283

The underlying issue is that calling `.item()` on a NumPy array with
`dtype=datetime64[ns]` returns an _integer_, rather than an `np.datetime64
scalar. This is somewhat baffling but works this way because `.item()`
returns native Python types, but `datetime.datetime` doesn't support
nanosecond precision.

`pandas.Index.get_loc` used to support these integers, but now is more strict.
Hence we get errors.

We can fix this by using `array[()]` to convert 0d arrays into NumPy scalars
instead of calling `array.item()`.

I've added a crude regression test. There may well be a better way to test this
but I haven't figured it out yet.
  • Loading branch information
shoyer committed Jul 31, 2020
1 parent 1be777f commit 3c9da37
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
9 changes: 4 additions & 5 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No
else _asarray_tuplesafe(label)
)
if label.ndim == 0:
label_value = label[()]
if isinstance(index, pd.MultiIndex):
indexer, new_index = index.get_loc_level(label.item(), level=0)
indexer, new_index = index.get_loc_level(label_value, level=0)
elif isinstance(index, pd.CategoricalIndex):
if method is not None:
raise ValueError(
Expand All @@ -184,11 +185,9 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No
raise ValueError(
"'tolerance' is not a valid kwarg when indexing using a CategoricalIndex."
)
indexer = index.get_loc(label.item())
indexer = index.get_loc(label_value)
else:
indexer = index.get_loc(
label.item(), method=method, tolerance=tolerance
)
indexer = index.get_loc(label_value, method=method, tolerance=tolerance)
elif label.dtype.kind == "b":
indexer = label
else:
Expand Down
8 changes: 7 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ def test_sel_invalid_slice(self):
with raises_regex(ValueError, "cannot use non-scalar arrays"):
array.sel(x=slice(array.x))

def test_sel_dataarray_datetime(self):
def test_sel_dataarray_datetime_slice(self):
# regression test for GH1240
times = pd.date_range("2000-01-01", freq="D", periods=365)
array = DataArray(np.arange(365), [("time", times)])
Expand Down Expand Up @@ -1078,6 +1078,12 @@ def test_loc(self):
assert_identical(da[:3, :4], da.loc[["a", "b", "c"], np.arange(4)])
assert_identical(da[:, :4], da.loc[:, self.ds["y"] < 4])

def test_loc_datetime64_value(self):
# regression test for https://github.com/pydata/xarray/issues/4283
t = np.array(["2017-09-05T12", "2017-09-05T15"], dtype="datetime64[ns]")
array = DataArray(np.ones(t.shape), dims=("time",), coords=(t,))
assert_identical(array.loc[{"time": t[0]}], array[0])

def test_loc_assign(self):
self.ds["x"] = ("x", np.array(list("abcdefghij")))
da = self.ds["foo"]
Expand Down
8 changes: 7 additions & 1 deletion xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from xarray import DataArray, Dataset, Variable
from xarray.core import indexing, nputils

from . import IndexerMaker, ReturnItem, assert_array_equal, raises_regex
from . import (
IndexerMaker,
ReturnItem,
assert_array_equal,
assert_identical,
raises_regex,
)

B = IndexerMaker(indexing.BasicIndexer)

Expand Down

0 comments on commit 3c9da37

Please sign in to comment.