Array indexing with dask arrays #2511

Closed
ulijh opened this issue Oct 25, 2018 · 20 comments · Fixed by #5873

Comments

@ulijh
Contributor

ulijh commented Oct 25, 2018

Code example

import numpy as np
import xarray as xr

da = xr.DataArray(np.ones((10, 10))).chunk(2)
indc = xr.DataArray(np.random.randint(0, 9, 10)).chunk(2)

# This fails:
da[{'dim_1' : indc}].values

Problem description

Indexing with chunked arrays fails, whereas it works fine with "normal" (NumPy-backed) arrays. When the indices are the result of a lazy calculation, I would like to continue lazily.

Expected Output

I would expect an output just like in the "un-chunked" case:

da[{'dim_1' : indc.compute()}].values
# Returns: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

Output of xr.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.7.0.final.0
python-bits: 64
OS: Linux
OS-release: 4.18.14-arch1-1-ARCH
machine: x86_64
processor:
byteorder: little
LC_ALL: None
LANG: de_DE.utf8
LOCALE: de_DE.UTF-8

xarray: 0.10.9
pandas: 0.23.4
numpy: 1.15.2
scipy: 1.1.0
netCDF4: None
h5netcdf: 0.6.2
h5py: 2.8.0
Nio: None
zarr: None
cftime: None
PseudonetCDF: None
rasterio: None
iris: None
bottleneck: 1.2.1
cyordereddict: None
dask: 0.19.4
distributed: None
matplotlib: 2.2.3
cartopy: 0.16.0
seaborn: None
setuptools: 40.4.3
pip: 18.0
conda: None
pytest: 3.8.2
IPython: 6.5.0
sphinx: 1.8.0

@shoyer
Member

shoyer commented Oct 25, 2018

For reference, here's the current stacktrace/error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-74fe4ba70f9d> in <module>()
----> 1 da[{'dim_1' : indc}]

/usr/local/lib/python3.6/dist-packages/xarray/core/dataarray.py in __getitem__(self, key)
    472         else:
    473             # xarray-style array indexing
--> 474             return self.isel(indexers=self._item_key_to_dict(key))
    475 
    476     def __setitem__(self, key, value):

/usr/local/lib/python3.6/dist-packages/xarray/core/dataarray.py in isel(self, indexers, drop, **indexers_kwargs)
    817         """
    818         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'isel')
--> 819         ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers)
    820         return self._from_temp_dataset(ds)
    821 

/usr/local/lib/python3.6/dist-packages/xarray/core/dataset.py in isel(self, indexers, drop, **indexers_kwargs)
   1537         for name, var in iteritems(self._variables):
   1538             var_indexers = {k: v for k, v in indexers_list if k in var.dims}
-> 1539             new_var = var.isel(indexers=var_indexers)
   1540             if not (drop and name in var_indexers):
   1541                 variables[name] = new_var

/usr/local/lib/python3.6/dist-packages/xarray/core/variable.py in isel(self, indexers, drop, **indexers_kwargs)
    905             if dim in indexers:
    906                 key[i] = indexers[dim]
--> 907         return self[tuple(key)]
    908 
    909     def squeeze(self, dim=None):

/usr/local/lib/python3.6/dist-packages/xarray/core/variable.py in __getitem__(self, key)
    614         array `x.values` directly.
    615         """
--> 616         dims, indexer, new_order = self._broadcast_indexes(key)
    617         data = as_indexable(self._data)[indexer]
    618         if new_order:

/usr/local/lib/python3.6/dist-packages/xarray/core/variable.py in _broadcast_indexes(self, key)
    487             return self._broadcast_indexes_outer(key)
    488 
--> 489         return self._broadcast_indexes_vectorized(key)
    490 
    491     def _broadcast_indexes_basic(self, key):

/usr/local/lib/python3.6/dist-packages/xarray/core/variable.py in _broadcast_indexes_vectorized(self, key)
    599             new_order = None
    600 
--> 601         return out_dims, VectorizedIndexer(tuple(out_key)), new_order
    602 
    603     def __getitem__(self, key):

/usr/local/lib/python3.6/dist-packages/xarray/core/indexing.py in __init__(self, key)
    423             else:
    424                 raise TypeError('unexpected indexer type for {}: {!r}'
--> 425                                 .format(type(self).__name__, k))
    426             new_key.append(k)
    427 

TypeError: unexpected indexer type for VectorizedIndexer: dask.array<xarray-<this-array>, shape=(10,), dtype=int64, chunksize=(2,)>

It looks like we could support this relatively easily since dask.array supports indexing with dask arrays now. This would be a welcome enhancement!
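For context, here is a minimal sketch of the dask-level feature referred to here. It assumes a dask version that accepts a 1-D integer dask array as an indexer along a single axis; `data` and `positions` are illustrative names, not anything from xarray:

```python
import dask.array as dsa
import numpy as np

# A lazy 1-D array split into chunks of 3.
data = dsa.arange(10, chunks=3)

# A lazy integer indexer, e.g. the output of an earlier dask computation.
positions = dsa.from_array(np.array([9, 0, 4]), chunks=2)

# dask only builds a new graph here; nothing is computed yet.
picked = data[positions]

print(type(picked).__name__)  # still a dask Array
print(picked.compute())       # values are only materialised at this point
```

In the traceback above, the dask array is rejected one layer up, in xarray's `VectorizedIndexer`, before dask ever sees it.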

@ulijh
Contributor Author

ulijh commented Oct 26, 2018

It seems to work fine with the following change, but it has a lot of duplicated code...

diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index d51da471..9fe93581 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -7,6 +7,7 @@ from datetime import timedelta
 
 import numpy as np
 import pandas as pd
+import dask.array as da
 
 from . import duck_array_ops, nputils, utils
 from .pycompat import (
@@ -420,6 +421,19 @@ class VectorizedIndexer(ExplicitIndexer):
                                      'have different numbers of dimensions: {}'
                                      .format(ndims))
                 k = np.asarray(k, dtype=np.int64)
+            elif isinstance(k, dask_array_type):
+                if not np.issubdtype(k.dtype, np.integer):
+                    raise TypeError('invalid indexer array, does not have '
+                                    'integer dtype: {!r}'.format(k))
+                if ndim is None:
+                    ndim = k.ndim
+                elif ndim != k.ndim:
+                    ndims = [k.ndim for k in key
+                             if isinstance(k, (np.ndarray,) + dask_array_type)]
+                    raise ValueError('invalid indexer key: ndarray arguments '
+                                     'have different numbers of dimensions: {}'
+                                     .format(ndims))
+                k = da.array(k, dtype=np.int64)
             else:
                 raise TypeError('unexpected indexer type for {}: {!r}'
                                 .format(type(self).__name__, k))

@ulijh
Contributor Author

ulijh commented Jun 3, 2019

As of version 0.12, indexing with dask arrays works out of the box... I think this can be closed now.

ulijh closed this as completed Jun 3, 2019
ulijh reopened this Aug 20, 2019
@ulijh
Contributor Author

ulijh commented Aug 20, 2019

Even though the example from above does work, sadly, the following does not:

import xarray as xr
import dask.array as da
import numpy as np
da = xr.DataArray(np.random.rand(3*4*5).reshape((3,4,5))).chunk(dict(dim_0=1))
idcs = da.argmax('dim_2')
da[dict(dim_2=idcs)]

results in

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-3542cdd6d61c> in <module>
----> 1 da[dict(dim_2=idcs)]

~/src/xarray/xarray/core/dataarray.py in __getitem__(self, key)
    604         else:
    605             # xarray-style array indexing
--> 606             return self.isel(indexers=self._item_key_to_dict(key))
    607 
    608     def __setitem__(self, key: Any, value: Any) -> None:

~/src/xarray/xarray/core/dataarray.py in isel(self, indexers, drop, **indexers_kwargs)
    986         """
    987         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
--> 988         ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers)
    989         return self._from_temp_dataset(ds)
    990 

~/src/xarray/xarray/core/dataset.py in isel(self, indexers, drop, **indexers_kwargs)
   1901                     indexes[name] = new_index
   1902             else:
-> 1903                 new_var = var.isel(indexers=var_indexers)
   1904 
   1905             variables[name] = new_var

~/src/xarray/xarray/core/variable.py in isel(self, indexers, drop, **indexers_kwargs)
    984             if dim in indexers:
    985                 key[i] = indexers[dim]
--> 986         return self[tuple(key)]
    987 
    988     def squeeze(self, dim=None):

~/src/xarray/xarray/core/variable.py in __getitem__(self, key)
    675         array `x.values` directly.
    676         """
--> 677         dims, indexer, new_order = self._broadcast_indexes(key)
    678         data = as_indexable(self._data)[indexer]
    679         if new_order:

~/src/xarray/xarray/core/variable.py in _broadcast_indexes(self, key)
    532             if isinstance(k, Variable):
    533                 if len(k.dims) > 1:
--> 534                     return self._broadcast_indexes_vectorized(key)
    535                 dims.append(k.dims[0])
    536             elif not isinstance(k, integer_types):

~/src/xarray/xarray/core/variable.py in _broadcast_indexes_vectorized(self, key)
    660             new_order = None
    661 
--> 662         return out_dims, VectorizedIndexer(tuple(out_key)), new_order
    663 
    664     def __getitem__(self, key):

~/src/xarray/xarray/core/indexing.py in __init__(self, key)
    460                 raise TypeError(
    461                     "unexpected indexer type for {}: {!r}".format(
--> 462                         type(self).__name__, k
    463                     )
    464                 )

TypeError: unexpected indexer type for VectorizedIndexer: dask.array<arg_agg-aggregate, shape=(3, 4), dtype=int64, chunksize=(1, 4)>

@shoyer
Member

shoyer commented Aug 20, 2019

Yes, something seems to be going wrong here...

@ulijh
Contributor Author

ulijh commented Aug 28, 2019

I think the problem is somewhere here:

def safe_cast_to_index(array: Any) -> pd.Index:
    """Given an array, safely cast it to a pandas.Index.
    If it is already a pandas.Index, return it unchanged.
    Unlike pandas.Index, if the array has dtype=object or dtype=timedelta64,
    this function will not attempt to do automatic type conversion but will
    always return an index with dtype=object.
    """
    if isinstance(array, pd.Index):
        index = array
    elif hasattr(array, "to_index"):
        index = array.to_index()
    else:
        kwargs = {}
        if hasattr(array, "dtype") and array.dtype.kind == "O":
            kwargs["dtype"] = object
        index = pd.Index(np.asarray(array), **kwargs)
    return _maybe_cast_to_cftimeindex(index)

I don't think pandas.Index can hold lazy arrays. Could there be a way around this by exploiting dask.dataframe indexing methods?
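Whatever the exact failure point, the fallback branch above cannot stay lazy. The sketch below uses only public NumPy, pandas, and dask APIs; `counting_scheduler` is a hypothetical helper in the spirit of the compute-counting trick used later in this thread. It shows that `np.asarray` on a dask array forces a computation before `pd.Index` ever sees the data:

```python
import dask
import dask.array as dsa
import numpy as np
import pandas as pd
from dask.local import get_sync

compute_calls = 0


def counting_scheduler(dsk, keys, **kwargs):
    """Hypothetical helper: a synchronous scheduler that counts compute requests."""
    global compute_calls
    compute_calls += 1
    return get_sync(dsk, keys, **kwargs)


lazy_values = dsa.arange(6, chunks=2)  # lazy: nothing computed yet

with dask.config.set(scheduler=counting_scheduler):
    # Same conversion as the fallback branch of safe_cast_to_index:
    # np.asarray() on a dask array triggers a full compute.
    index = pd.Index(np.asarray(lazy_values))

print(compute_calls, list(index))
```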

@rafa-guedes
Contributor

I'm having a similar issue; here is an example:

import numpy as np
import dask.array as da
import xarray as xr

darr = xr.DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",))
good_indexer = xr.DataArray(
    data=np.random.randint(0, 3, 8).reshape(4, 2).astype(int),
    coords={"y": range(4), "x": range(2)},
    dims=("y", "x")
)
bad_indexer = xr.DataArray(
    data=da.random.randint(0, 3, 8).reshape(4, 2).astype(int),
    coords={"y": range(4), "x": range(2)},
    dims=("y", "x")
)

In [5]: darr                                                                                                                                                            
Out[5]: 
<xarray.DataArray (z: 3)>
array([0.2, 0.4, 0.6])
Coordinates:
  * z        (z) int64 0 1 2

In [6]: good_indexer                                                     
Out[6]: 
<xarray.DataArray (y: 4, x: 2)>
array([[0, 1],
       [2, 2],
       [1, 2],
       [1, 0]])
Coordinates:
  * y        (y) int64 0 1 2 3
  * x        (x) int64 0 1

In [7]: bad_indexer                                                      
Out[7]: 
<xarray.DataArray 'reshape-417766b2035dcb1227ddde8505297039' (y: 4, x: 2)>
dask.array<reshape, shape=(4, 2), dtype=int64, chunksize=(4, 2), chunktype=numpy.ndarray>
Coordinates:
  * y        (y) int64 0 1 2 3
  * x        (x) int64 0 1

In [8]: darr[good_indexer]                                                                                                                                              
Out[8]: 
<xarray.DataArray (y: 4, x: 2)>
array([[0.2, 0.4],
       [0.6, 0.6],
       [0.4, 0.6],
       [0.4, 0.2]])
Coordinates:
    z        (y, x) int64 0 1 2 2 1 2 1 0
  * y        (y) int64 0 1 2 3
  * x        (x) int64 0 1

In [9]: darr[bad_indexer]                                                                                                                                               
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-2a57c1a2eade> in <module>
----> 1 darr[bad_indexer]

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/dataarray.py in __getitem__(self, key)
    638         else:
    639             # xarray-style array indexing
--> 640             return self.isel(indexers=self._item_key_to_dict(key))
    641 
    642     def __setitem__(self, key: Any, value: Any) -> None:

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/dataarray.py in isel(self, indexers, drop, **indexers_kwargs)
   1012         """
   1013         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
-> 1014         ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers)
   1015         return self._from_temp_dataset(ds)
   1016 

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/dataset.py in isel(self, indexers, drop, **indexers_kwargs)
   1920             if name in self.indexes:
   1921                 new_var, new_index = isel_variable_and_index(
-> 1922                     name, var, self.indexes[name], var_indexers
   1923                 )
   1924                 if new_index is not None:

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/indexes.py in isel_variable_and_index(name, variable, index, indexers)
     79         )
     80 
---> 81     new_variable = variable.isel(indexers)
     82 
     83     if new_variable.dims != (name,):

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/variable.py in isel(self, indexers, **indexers_kwargs)
   1052 
   1053         key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
-> 1054         return self[key]
   1055 
   1056     def squeeze(self, dim=None):

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/variable.py in __getitem__(self, key)
    700         array `x.values` directly.
    701         """
--> 702         dims, indexer, new_order = self._broadcast_indexes(key)
    703         data = as_indexable(self._data)[indexer]
    704         if new_order:

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/variable.py in _broadcast_indexes(self, key)
    557             if isinstance(k, Variable):
    558                 if len(k.dims) > 1:
--> 559                     return self._broadcast_indexes_vectorized(key)
    560                 dims.append(k.dims[0])
    561             elif not isinstance(k, integer_types):

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/variable.py in _broadcast_indexes_vectorized(self, key)
    685             new_order = None
    686 
--> 687         return out_dims, VectorizedIndexer(tuple(out_key)), new_order
    688 
    689     def __getitem__(self: VariableType, key) -> VariableType:

~/.virtualenvs/py3/local/lib/python3.7/site-packages/xarray/core/indexing.py in __init__(self, key)
    447             else:
    448                 raise TypeError(
--> 449                     f"unexpected indexer type for {type(self).__name__}: {k!r}"
    450                 )
    451             new_key.append(k)

TypeError: unexpected indexer type for VectorizedIndexer: dask.array<reshape, shape=(4, 2), dtype=int64, chunksize=(4, 2), chunktype=numpy.ndarray>

In [10]: xr.__version__                                                   
Out[10]: '0.14.1'

In [11]: import dask; dask.__version__                                    
Out[11]: '2.9.0'

@roxyboy

roxyboy commented Dec 20, 2019

I'm just curious if there's been any progress on this issue. I'm also getting the same error: TypeError: unexpected indexer type for VectorizedIndexer and I would greatly benefit from lazy vectorized indexing.

@dcherian
Contributor

I don't think anyone is working on it. We would appreciate it if you could try to fix it.

@bzah
Contributor

bzah commented Sep 20, 2021

I wrote a very naive fix. It works, but it seems to perform really slowly, so I would appreciate some feedback (I'm a beginner with Dask).
Basically, I added k = dask.array.asarray(k, dtype=np.int64) to do exactly the same thing as with NumPy.
I can create a PR if that makes it easier to review.

The patch:

class VectorizedIndexer(ExplicitIndexer):
    """Tuple for vectorized indexing.

    All elements should be slice or N-dimensional np.ndarray objects with an
    integer dtype and the same number of dimensions. Indexing follows proposed
    rules for np.ndarray.vindex, which matches NumPy's advanced indexing rules
    (including broadcasting) except sliced axes are always moved to the end:
    https://github.com/numpy/numpy/pull/6256
    """

    __slots__ = ()

    def __init__(self, key):
        if not isinstance(key, tuple):
            raise TypeError(f"key must be a tuple: {key!r}")

        new_key = []
        ndim = None
        for k in key:
            if isinstance(k, slice):
                k = as_integer_slice(k)
            elif isinstance(k, np.ndarray) or isinstance(k, dask.array.Array):
                if not np.issubdtype(k.dtype, np.integer):
                    raise TypeError(
                        f"invalid indexer array, does not have integer dtype: {k!r}"
                    )
                if ndim is None:
                    ndim = k.ndim
                elif ndim != k.ndim:
                    ndims = [k.ndim for k in key if isinstance(k, (np.ndarray, dask.array.Array))]
                    raise ValueError(
                        "invalid indexer key: ndarray arguments "
                        f"have different numbers of dimensions: {ndims}"
                    )
                if isinstance(k, dask.array.Array):
                    k = dask.array.asarray(k, dtype=np.int64)
                else:
                    k = np.asarray(k, dtype=np.int64)
            else:
                raise TypeError(
                    f"unexpected indexer type for {type(self).__name__}: {k!r}"
                )
            new_key.append(k)

        super().__init__(new_key)

@pl-marasco

@bzah I've been testing your solution and it doesn't seem as slow as you mentioned. Do you have a specific test we could run so that we can make a more robust comparison?

@pl-marasco

@pl-marasco Ok that's strange. I should have saved my use case :/ I will try to reproduce it and will provide a gist of it soon.

What I noticed in my use case is that it triggers a computation. Is that the reason you consider it slow? Could it be related to #3237?

@pl-marasco

@bzah I tested your patch with the following code:

import numpy as np
import xarray as xr
from distributed import Client
client = Client()

da = xr.DataArray(np.random.rand(20*3500*3500).reshape((20,3500,3500)), dims=('time', 'x', 'y')).chunk(dict(time=-1, x=100, y=100))

idx = da.argmax('time').compute()
da.isel(time=idx)

In my case it seems to take the same time with or without the patch, but I would like to know if it's the same for you.

L.

@bzah
Contributor

bzah commented Oct 1, 2021

@pl-marasco Thanks for the example!
With it I get the same result as you: it takes the same time with the patch or with compute.

However, I could construct an example giving very different results. It is quite close to my original code:

import time

import numpy as np
import pandas as pd
import xarray as xr

time_start = time.perf_counter()
COORDS = dict(
    time=pd.date_range("2042-01-01", periods=200,
                       freq=pd.DateOffset(days=1)),
)
da = xr.DataArray(
    np.random.rand(200 * 3500 * 350).reshape((200, 3500, 350)),
    dims=('time', 'x', 'y'),
    coords=COORDS
).chunk(dict(time=-1, x=100, y=100))

resampled = da.resample(time="MS")

for label, sample in resampled:
    # sample = sample.compute()
    idx = sample.argmax('time')
    sample.isel(time=idx)

time_elapsed = time.perf_counter() - time_start
print(time_elapsed, " secs")

(Basically, for each month I want the first event occurring in it.)

Without the patch, and with sample = sample.compute() uncommented, it takes 5.7 seconds.
With the patch, it takes 53.9 seconds.

@cerodell

cerodell commented Oct 1, 2021

Hello! First off, thank you for all the hard work on xarray! I use it every day and love it :)

I am also having issues indexing with dask arrays and get the following error.

 Traceback (most recent call last):
  File "~/phd-comps/scripts/sfire-pbl.py", line 64, in <module>
    PBLH = height.isel(gradT2.argmax(dim=['interp_level']))
  File "~/miniconda3/envs/cr/lib/python3.7/site-packages/xarray/core/dataarray.py", line 1184, in isel
    indexers, drop=drop, missing_dims=missing_dims
  File "~/miniconda3/envs/cr/lib/python3.7/site-packages/xarray/core/dataset.py", line 2389, in _isel_fancy
    new_var = var.isel(indexers=var_indexers)
  File "~/miniconda3/envs/cr/lib/python3.7/site-packages/xarray/core/variable.py", line 1156, in isel
    return self[key]
  File "~/miniconda3/envs/cr/lib/python3.7/site-packages/xarray/core/variable.py", line 776, in __getitem__
    dims, indexer, new_order = self._broadcast_indexes(key)
  File "~/miniconda3/envs/cr/lib/python3.7/site-packages/xarray/core/variable.py", line 632, in _broadcast_indexes
    return self._broadcast_indexes_vectorized(key)
  File "~/miniconda3/envs/cr/lib/python3.7/site-packages/xarray/core/variable.py", line 761, in _broadcast_indexes_vectorized
    return out_dims, VectorizedIndexer(tuple(out_key)), new_order
  File "~/miniconda3/envs/cr/lib/python3.7/site-packages/xarray/core/indexing.py", line 323, in __init__
    f"unexpected indexer type for {type(self).__name__}: {k!r}"
TypeError: unexpected indexer type for VectorizedIndexer: dask.array<getitem, shape=(240, 399, 159), dtype=int64, chunksize=(60, 133, 53), chunktype=numpy.ndarray>

dask                      2021.9.1           pyhd8ed1ab_0    conda-forge
xarray                    0.19.0             pyhd8ed1ab_0    conda-forge

To get it to work, I first need to call compute manually to load a NumPy array before using argmax with isel. I'm not sure what information would help solve the issue; please let me know and I'll send whatever I can.
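Until dask-backed indexers are accepted, a narrower form of this workaround is to compute only the index arrays and leave the large array lazy. The sketch assumes an xarray version in which `argmax(dim=[...])` returns a dict of indexers, as in the `height.isel(gradT2.argmax(dim=['interp_level']))` call above. `grad` and `height` are hypothetical stand-ins for the fields in that traceback:

```python
import numpy as np
import xarray as xr

# Hypothetical stand-ins for `gradT2` and `height` from the traceback above.
rng = np.random.default_rng(0)
dims = ("interp_level", "y", "x")
grad = xr.DataArray(rng.random((4, 6, 5)), dims=dims).chunk({"y": 3})
height = xr.DataArray(
    np.arange(4 * 6 * 5, dtype=float).reshape(4, 6, 5), dims=dims
).chunk({"y": 3})

# argmax(dim=[...]) returns a dict {dim_name: DataArray of integer positions}.
# Computing these indexers reads `grad`, but `height` itself stays a dask array.
indexers = {
    dim: positions.compute()
    for dim, positions in grad.argmax(dim=["interp_level"]).items()
}
pblh = height.isel(indexers)  # NumPy indexer on dask data: result is still lazy

print(pblh.dims, pblh.shape)
```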

@pl-marasco

@bzah I've been testing your code and I can confirm the increase in run time when .compute() isn't used.
I've noticed that with your modification, the dask array seems to be computed more than once per sample.
I've run some tests using a modified version of the code from #3237, and here are my observations:

Assuming we have only one sample object after the resample, the expected result is 1 compute, and that's what we get if we call the computation before .argmax().
If .compute() is removed, I get 3 computations in total.
As a confirmation, if you increase the number of samples, the number of computes is a multiple of 3.

I still don't know the reason, or whether this is correct, but it seems weird to me; it could explain the increase in run time, though.

@dcherian @shoyer do you know if all this makes sense? Should .isel() automatically trigger the computation, or should it return a lazy array?

Here is the code I've been using (it only works with the modification proposed by @bzah applied):

import numpy as np
import pandas as pd
import dask
import xarray as xr

class Scheduler:
    """ From: https://stackoverflow.com/questions/53289286/ """

    def __init__(self, max_computes=20):
        self.max_computes = max_computes
        self.total_computes = 0

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                "Too many dask computations were scheduled: {}".format(
                    self.total_computes
                )
            )
        return dask.get(dsk, keys, **kwargs)

scheduler = Scheduler()

with dask.config.set(scheduler=scheduler):

    COORDS = dict(dim_0=pd.date_range("2042-01-01", periods=31, freq='D'),
                  dim_1= range(0,500),
                  dim_2= range(0,500))

    da = xr.DataArray(np.random.rand(31 * 500 * 500).reshape((31, 500, 500)),
                      coords=COORDS).chunk(dict(dim_0=-1, dim_1=100, dim_2=100))

    print(da)

    resampled = da.resample(dim_0="MS")

    for label, sample in resampled:

        #sample = sample.compute()
        idx = sample.argmax('dim_0')
        sampled = sample.isel(dim_0=idx)

    print("Total number of computes: %d" % scheduler.total_computes)

@bzah
Contributor

bzah commented Oct 15, 2021

I'll open a PR; it might be easier to try out and play with than a piece of code lost in an issue.

@dcherian
Contributor

IIUC this cannot work lazily in most cases if you have dimension coordinate variables. When xarray constructs the output after indexing, it will try to index those coordinate variables so that it can associate the right timestamp (for example) with the output.
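The following is a small eager illustration of the coordinate step described here. It is NumPy-backed, so it runs on any recent xarray, and the array sizes are arbitrary. Indexing along `time` with a 2-D integer indexer turns the `time` dimension coordinate into a 2-D coordinate of selected timestamps, so xarray has to index that coordinate with the same indexer values:

```python
import numpy as np
import pandas as pd
import xarray as xr

arr = xr.DataArray(
    np.random.default_rng(0).random((5, 3, 2)),
    dims=("time", "x", "y"),
    coords={"time": pd.date_range("2042-01-01", periods=5, freq="D")},
)
idx = arr.argmax("time")     # integer positions with dims ("x", "y")
picked = arr.isel(time=idx)  # vectorized indexing along "time"

# "time" is no longer a dimension, but it survives as a (x, y) coordinate:
# each cell carries the timestamp that was selected for it.
print(picked["time"].dims)
```

With a lazy `idx`, those timestamps are not known until the indexer is computed.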

The example from @ulijh should work, though (it has no dimension coordinates or indexed variables):

import xarray as xr
import dask.array as da
import numpy as np
da = xr.DataArray(np.random.rand(3*4*5).reshape((3,4,5))).chunk(dict(dim_0=1))
idcs = da.argmax('dim_2')
da[dict(dim_2=idcs)]
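For this coordinate-free case, one lazy workaround sidesteps `isel` entirely. The sketch below swaps in `xarray.apply_ufunc` with `dask="parallelized"` plus NumPy's `take_along_axis`. This technique is not proposed in this thread. The sketch assumes the indexed dimension (`dim_2`) sits in a single chunk, and `take_along_last_axis` is a hypothetical helper. The array is named `arr` to avoid shadowing `dask.array as da`:

```python
import numpy as np
import xarray as xr


def take_along_last_axis(values, indices):
    """Hypothetical helper: pick values[..., indices[...]] on NumPy blocks."""
    return np.take_along_axis(values, indices[..., np.newaxis], axis=-1)[..., 0]


arr = xr.DataArray(np.random.rand(3, 4, 5)).chunk(dict(dim_0=1))
idcs = arr.argmax("dim_2")  # lazy integer positions, dims (dim_0, dim_1)

picked = xr.apply_ufunc(
    take_along_last_axis,
    arr,
    idcs,
    input_core_dims=[["dim_2"], []],  # dim_2 is moved to the last axis of `arr`
    dask="parallelized",              # apply the function blockwise, lazily
    output_dtypes=[arr.dtype],
)

print(picked.dims, picked.chunks)
```

Because argmax picks the maximum along `dim_2`, the computed result should match `arr.max("dim_2")`.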

The example by @rafa-guedes (thanks for that one!) could be made to work I think.

import numpy as np
import dask.array as da
import xarray as xr

darr = xr.DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",))
good_indexer = xr.DataArray(
    data=np.random.randint(0, 3, 8).reshape(4, 2).astype(int),
    coords={"y": range(4), "x": range(2)},
    dims=("y", "x")
)
bad_indexer = xr.DataArray(
    data=da.random.randint(0, 3, 8).reshape(4, 2).astype(int),
    coords={"y": range(4), "x": range(2)},
    dims=("y", "x")
)

In [5]: darr                                                                                                                                                            
Out[5]: 
<xarray.DataArray (z: 3)>
array([0.2, 0.4, 0.6])
Coordinates:
  * z        (z) int64 0 1 2

In [6]: good_indexer                                                     
Out[6]: 
<xarray.DataArray (y: 4, x: 2)>
array([[0, 1],
       [2, 2],
       [1, 2],
       [1, 0]])
Coordinates:
  * y        (y) int64 0 1 2 3
  * x        (x) int64 0 1

In [7]: bad_indexer                                                      
Out[7]: 
<xarray.DataArray 'reshape-417766b2035dcb1227ddde8505297039' (y: 4, x: 2)>
dask.array<reshape, shape=(4, 2), dtype=int64, chunksize=(4, 2), chunktype=numpy.ndarray>
Coordinates:
  * y        (y) int64 0 1 2 3
  * x        (x) int64 0 1

In [8]: darr[good_indexer]                                                                                                                                              
Out[8]: 
<xarray.DataArray (y: 4, x: 2)>
array([[0.2, 0.4],
       [0.6, 0.6],
       [0.4, 0.6],
       [0.4, 0.2]])
Coordinates:
    z        (y, x) int64 0 1 2 2 1 2 1 0
  * y        (y) int64 0 1 2 3
  * x        (x) int64 0 1

We can copy the dimension coordinates of the output (x, y) directly from the indexer. The dimension coordinate on the input (z) would then become a dask array in the output; since z is no longer a dimension coordinate there, this should be fine.
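The following rough sketch does this construction by hand, outside xarray's indexing machinery, to rebuild @rafa-guedes's `darr[bad_indexer]` lazily. `lazy_take` is a hypothetical helper. The sketch assumes dask accepts a 1-D integer dask array as an indexer, so the 2-D indexer is flattened first and the result is reshaped:

```python
import dask.array as dsa
import numpy as np
import xarray as xr

darr = xr.DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",))
bad_indexer = xr.DataArray(
    data=dsa.random.randint(0, 3, 8).reshape(4, 2).astype(int),
    coords={"y": range(4), "x": range(2)},
    dims=("y", "x"),
)


def lazy_take(values, indexer):
    """Hypothetical helper: index a 1-D array with an N-D dask integer array, lazily."""
    flat_positions = indexer.data.ravel()         # 1-D dask integer array
    picked = dsa.asarray(values)[flat_positions]  # dask-on-dask indexing stays lazy
    return picked.reshape(indexer.shape)          # restore the indexer's shape


result = xr.DataArray(
    lazy_take(darr.values, bad_indexer),
    coords=bad_indexer.coords,  # y and x copied directly from the indexer
    dims=bad_indexer.dims,
).assign_coords(
    # z is no longer a dimension, so it can be a lazy (y, x) coordinate
    z=(bad_indexer.dims, lazy_take(darr["z"].values, bad_indexer)),
)
print(result)
```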
