Skip to content
forked from pydata/xarray

Commit

Permalink
Add optimization to DaskIndexingAdapter
Browse files — browse the repository at this point in the history
  • Loading branch information
dcherian committed Dec 11, 2019
1 parent 28074b9 commit 1972fde
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
20 changes: 19 additions & 1 deletion xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from contextlib import suppress
from datetime import timedelta
from typing import Any, Callable, Sequence, Tuple, Union
from typing import Any, Callable, Iterable, Sequence, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1291,6 +1291,24 @@ def __init__(self, array):
self.array = array

def __getitem__(self, key):

if not isinstance(key, VectorizedIndexer):
# if possible, short-circuit when keys are effectively slice(None)
# This preserves dask name and passes lazy array equivalence checks
# (see duck_array_ops.lazy_array_equiv)
rewritten_indexer = False
new_indexer = []
for idim, k in enumerate(key.tuple):
if isinstance(k, Iterable) and duck_array_ops.array_equiv(
k, np.arange(self.array.shape[idim])
):
new_indexer.append(slice(None))
rewritten_indexer = True
else:
new_indexer.append(k)
if rewritten_indexer:
key = type(key)(tuple(new_indexer))

if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, VectorizedIndexer):
Expand Down
16 changes: 1 addition & 15 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
from datetime import timedelta
from distutils.version import LooseVersion
from typing import Any, Dict, Hashable, Iterable, Mapping, TypeVar, Union
from typing import Any, Dict, Hashable, Mapping, TypeVar, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -540,20 +540,6 @@ def _broadcast_indexes(self, key):
k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key
)

key_dict = dict(zip(self.dims, key))
for dim, k in key_dict.items():
if isinstance(k, Iterable):
# let da.sel(x=da.x) pass but skip if Variable has different dimensions
# e.g. da.sel(x=Variable(("points",), [0, 1, 2]))
if isinstance(k, Variable) and k.dims != (dim,):
continue
if duck_array_ops.array_equiv(k, np.arange(self.sizes[dim])):
# short-circuit when keys are effectively slice(None)
# This preserves dask name and passes lazy array equivalence checks
# (see duck_array_ops.lazy_array_equiv)
key_dict[dim] = slice(None)
key = tuple(key_dict.values())

if all(isinstance(k, BASIC_INDEXING_TYPES) for k in key):
return self._broadcast_indexes_basic(key)

Expand Down

0 comments on commit 1972fde

Please sign in to comment.