
argmax() causes dask to compute #3237

Closed
ulijh opened this issue Aug 21, 2019 · 4 comments · Fixed by #3244

@ulijh
Contributor

ulijh commented Aug 21, 2019

Problem Description

While digging into #2511, I found that da.argmax() triggers a compute on a dask array in nanargmax(a, axis=None), at this line:

if mask.any():

I don't think this should happen: da.max() and da.data.argmax() don't compute, and an eager compute here makes the laziness pointless.
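The last frames of the traceback below show the mechanism: "if mask.any():" truth-tests a dask array, and dask's Array.__bool__ answers by calling self.compute(). A minimal pure-Python sketch of that behaviour, using a hypothetical ToyLazyArray class (not part of dask or xarray) that only counts how often it is evaluated:

```python
class ToyLazyArray:
    """Hypothetical stand-in for a lazy array: it stores a thunk and
    counts how many times that thunk is actually evaluated."""

    def __init__(self, thunk):
        self._thunk = thunk
        self.computes = 0

    def any(self):
        # Building a reduction stays lazy: wrap a new thunk, evaluate nothing.
        return ToyLazyArray(lambda: any(self._thunk()))

    def compute(self):
        self.computes += 1
        return self._thunk()

    def __bool__(self):
        # Mirrors dask's Array.__bool__ in the traceback: truth-testing
        # needs a concrete value, so it forces a compute.
        return bool(self.compute())


mask = ToyLazyArray(lambda: [False, False, True])
lazy_any = mask.any()      # still lazy, nothing evaluated yet
if lazy_any:               # the "if mask.any():" pattern forces evaluation
    pass
print(lazy_any.computes)   # → 1
```

Building mask.any() is free; only the if statement, which needs a real True or False, pays for a compute.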

MCVE Code Sample

In [1]: import numpy as np                                                                                                             
   ...: import dask                                                                                                                    
   ...: import xarray as xr                                                                                                            
                                                                                                                                       
In [2]: class Scheduler:                                                                                                               
   ...:     """ From: https://stackoverflow.com/questions/53289286/ """                                                                
   ...:                                                                                                                                
   ...:     def __init__(self, max_computes=0):                                                                                        
   ...:         self.max_computes = max_computes                                                                                       
   ...:         self.total_computes = 0                                                                                                
   ...:                                                                                                                                
   ...:     def __call__(self, dsk, keys, **kwargs):                                                                                   
   ...:         self.total_computes += 1                                                                                               
   ...:         if self.total_computes > self.max_computes:                                                                            
   ...:             raise RuntimeError(                                                                                                
   ...:                 "Too many dask computations were scheduled: {}".format(                                                        
   ...:                     self.total_computes                                                                                        
   ...:                 )                                                                                                              
   ...:             )                                                                                                                  
   ...:         return dask.get(dsk, keys, **kwargs)                                                                                   
   ...:                                                                                                                                
                                                                                                                                       
In [3]: scheduler = Scheduler()                                                                                                        
                                                                                                                                       
In [4]: with dask.config.set(scheduler=scheduler):                                                                                     
   ...:     da = xr.DataArray(                                                                                                         
   ...:         np.random.rand(2*3*4).reshape((2, 3, 4)),                                                                              
   ...:     ).chunk(dict(dim_0=1))                                                                                                     
   ...:                                                                                                                                
   ...:     dim = da.dims[-1]                                                                                                          
   ...:                                                                                                                                
   ...:     dask_idcs = da.data.argmax(axis=-1)  # Computes 0 times                                                                    
   ...:     print("Total number of computes: %d" % scheduler.total_computes)                                                           
   ...:                                                                                                                                
   ...:     da.max(dim)  # Computes 0 times                                                                                            
   ...:     print("Total number of computes: %d" % scheduler.total_computes)                                                           
   ...:                                                                                                                                
   ...:     da.argmax(dim)  # Does compute                                                                                             
   ...:     print("Total number of computes: %d" % scheduler.total_computes)                                                           
   ...:               
Total number of computes: 0                                                                                                                                                                                                                                                       
Total number of computes: 0                                                                                                                                                                                                                                                       
---------------------------------------------------------------------------                                                              
RuntimeError                              Traceback (most recent call last)                                                              
<ipython-input-4-f95c8753dbe6> in <module>                          
     12     print("Total number of computes: %d" % scheduler.total_computes)                                                             
     13                                                             
---> 14     da.argmax(dim)  # Does compute                          
     15     print("Total number of computes: %d" % scheduler.total_computes)                                                             
     16                                                             

~/src/xarray/xarray/core/common.py in wrapped_func(self, dim, axis, skipna, **kwargs)                                                    
     42             def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):                                                  
     43                 return self.reduce(                         
---> 44                     func, dim, axis, skipna=skipna, allow_lazy=True, **kwargs                                                    
     45                 )                                           
     46                                                             

~/src/xarray/xarray/core/dataarray.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)                                   
   2120         """                                                 
   2121                                                             
-> 2122         var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)                                              
   2123         return self._replace_maybe_drop_dims(var)           
   2124                                                             

~/src/xarray/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, allow_lazy, **kwargs)                        
   1456         input_data = self.data if allow_lazy else self.values                                                                    
   1457         if axis is not None:                                
-> 1458             data = func(input_data, axis=axis, **kwargs)    
   1459         else:                                               
   1460             data = func(input_data, **kwargs)               

~/src/xarray/xarray/core/duck_array_ops.py in f(values, axis, skipna, **kwargs)                                                          
    279                                                             
    280         try:                                                
--> 281             return func(values, axis=axis, **kwargs)        
    282         except AttributeError:                              
    283             if isinstance(values, dask_array_type):         

~/src/xarray/xarray/core/nanops.py in nanargmax(a, axis)            
    118     if mask is not None:                                    
    119         mask = mask.all(axis=axis)                          
--> 120         if mask.any():                                      
    121             raise ValueError("All-NaN slice encountered")   
    122     return res                                              

/usr/lib/python3.7/site-packages/dask/array/core.py in __bool__(self)                                                                    
   1370             )                                               
   1371         else:                                               
-> 1372             return bool(self.compute())                     
   1373                                                             
   1374     __nonzero__ = __bool__  # python 2                      

/usr/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)                                                                 
    173         dask.base.compute                                   
    174         """                                                 
--> 175         (result,) = compute(self, traverse=False, **kwargs) 
    176         return result                                       
    177                                                             

/usr/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)                                                                
    444     keys = [x.__dask_keys__() for x in collections]         
    445     postcomputes = [x.__dask_postcompute__() for x in collections]                                                               
--> 446     results = schedule(dsk, keys, **kwargs)                 
    447     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])                                                        
    448        

Expected Output

None of the methods should actually compute:

Total number of computes: 0                                                                                                                                                                                                                                                    
Total number of computes: 0                                                                                                                                                                                                   
Total number of computes: 0   

Output of xr.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.7.4 (default, Jul 16 2019, 07:12:58) [GCC 9.1.0]
python-bits: 64
OS: Linux
OS-release: 5.2.9-arch1-1-ARCH
machine: x86_64
processor:
byteorder: little
LC_ALL: None
LANG: de_DE.utf8
LOCALE: de_DE.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.7.0

xarray: 0.12.3+63.g131f6022
pandas: 0.25.0
numpy: 1.17.0
scipy: 1.3.1
netCDF4: 1.5.1.2
pydap: None
h5netcdf: 0.7.4
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.3.4
nc_time_axis: None
PseudoNetCDF: None
rasterio: 1.0.25
cfgrib: None
iris: None
bottleneck: 1.2.1
dask: 2.1.0
distributed: 1.27.1
matplotlib: 3.1.1
cartopy: 0.17.0
seaborn: 0.9.0
numbagg: None
setuptools: 41.0.1
pip: 19.0.3
conda: None
pytest: 5.0.1
IPython: 7.6.1
sphinx: 2.2.0

@shoyer shoyer added the bug label Aug 21, 2019
@shoyer
Member

shoyer commented Aug 21, 2019

Yes, this is definitely a bug. Thanks for the clear example to reproduce it!

These helper functions were originally added back in #1883 to handle object dtype arrays properly.

So it would be nice to fix this for object arrays in dask, but for the much more common non-object dtype arrays we should really just be using dask.array.nanargmax.
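A minimal sketch of that suggestion, assuming dask may or may not be installed: non-object arrays go to dask.array.nanargmax when the input is a dask array and to numpy.nanargmax otherwise, so no mask ever has to be truth-tested. The name dispatch_nanargmax is hypothetical, and object-dtype arrays (the case #1883 was about) are deliberately left out.

```python
import numpy as np

try:
    import dask.array as dask_array
    dask_array_type = (dask_array.Array,)
except ImportError:  # dask is optional for this sketch
    dask_array = None
    dask_array_type = ()


def dispatch_nanargmax(a, axis=None):
    """Hypothetical helper: use each library's own nanargmax for its arrays."""
    if a.dtype.kind == "O":
        # Object arrays need separate handling, not covered by this sketch.
        raise NotImplementedError("object dtype is out of scope for this sketch")
    module = dask_array if isinstance(a, dask_array_type) else np
    # dask.array.nanargmax returns a lazy dask array; numpy.nanargmax is eager
    # and raises ValueError("All-NaN slice encountered") by itself.
    return module.nanargmax(a, axis=axis)


x = np.array([[1.0, np.nan, 3.0], [np.nan, 5.0, 4.0]])
print(dispatch_nanargmax(x, axis=-1))  # → [2 1]
```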

@ulijh
Contributor Author

ulijh commented Aug 22, 2019

These small changes do fix the MCVE, but they break at least one test. I don't understand xarray's (nan)ops logic well enough to work around the issue, but maybe this helps:

The change

diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index 9ba4eae2..784a1d01 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -91,17 +91,9 @@ def nanargmin(a, axis=None):
     fill_value = dtypes.get_pos_infinity(a.dtype)
     if a.dtype.kind == "O":
         return _nan_argminmax_object("argmin", fill_value, a, axis=axis)
-    a, mask = _replace_nan(a, fill_value)
-    if isinstance(a, dask_array_type):
-        res = dask_array.argmin(a, axis=axis)
-    else:
-        res = np.argmin(a, axis=axis)
 
-    if mask is not None:
-        mask = mask.all(axis=axis)
-        if mask.any():
-            raise ValueError("All-NaN slice encountered")
-    return res
+    module = dask_array if isinstance(a, dask_array_type) else nputils
+    return module.nanargmin(a, axis=axis)
 
 
 def nanargmax(a, axis=None):
@@ -109,17 +101,8 @@ def nanargmax(a, axis=None):
     if a.dtype.kind == "O":
         return _nan_argminmax_object("argmax", fill_value, a, axis=axis)
 
-    a, mask = _replace_nan(a, fill_value)
-    if isinstance(a, dask_array_type):
-        res = dask_array.argmax(a, axis=axis)
-    else:
-        res = np.argmax(a, axis=axis)
-
-    if mask is not None:
-        mask = mask.all(axis=axis)
-        if mask.any():
-            raise ValueError("All-NaN slice encountered")
-    return res
+    module = dask_array if isinstance(a, dask_array_type) else nputils
+    return module.nanargmax(a, axis=axis)
 
 
 def nansum(a, axis=None, dtype=None, out=None, min_count=None):
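For NumPy inputs the patch above should not change behaviour, because numpy.nanargmax already performs the all-NaN check that the removed lines did by hand. A sketch under that assumption: old_style_nanargmax is a hypothetical re-creation of the removed code path (replace NaNs with a fill value, take argmax, then check the mask), simplified to float arrays with -inf as the fill value; xarray's real _replace_nan and dtypes helpers may differ in detail.

```python
import numpy as np


def old_style_nanargmax(a, axis=None):
    """Hypothetical re-creation of the removed eager logic (float arrays only)."""
    a = np.asarray(a, dtype=float)
    mask = np.isnan(a)
    # Replace NaNs with -inf so they can never win the argmax.
    filled = np.where(mask, -np.inf, a)
    res = np.argmax(filled, axis=axis)
    # The step that forced a dask compute: truth-testing a reduced mask.
    if mask.all(axis=axis).any():
        raise ValueError("All-NaN slice encountered")
    return res


x = np.array([[1.0, np.nan, 3.0], [np.nan, 5.0, 4.0]])
# The hand-written version and numpy.nanargmax agree on this input.
print(old_style_nanargmax(x, axis=1), np.nanargmax(x, axis=1))  # → [2 1] [2 1]
```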

The failing test

...                                                                                                
___________ TestVariable.test_reduce ________________                                                                                                                                                     
...

    def f(values, axis=None, skipna=None, **kwargs):
        if kwargs.pop("out", None) is not None:
            raise TypeError("`out` is not valid for {}".format(name))
    
        values = asarray(values)
    
        if coerce_strings and values.dtype.kind in "SU":
            values = values.astype(object)
    
        func = None
        if skipna or (skipna is None and values.dtype.kind in "cfO"):
            nanname = "nan" + name
            func = getattr(nanops, nanname)
        else:
            func = _dask_or_eager_func(name)
    
        try:
            return func(values, axis=axis, **kwargs)
        except AttributeError:
            if isinstance(values, dask_array_type):
                try:  # dask/dask#3133 dask sometimes needs dtype argument
                    # if func does not accept dtype, then raises TypeError
                    return func(values, axis=axis, dtype=values.dtype, **kwargs)
                except (AttributeError, TypeError):
                    msg = "%s is not yet implemented on dask arrays" % name
            else:
                msg = (
                    "%s is not available with skipna=False with the "
                    "installed version of numpy; upgrade to numpy 1.12 "
                    "or newer to use skipna=True or skipna=None" % name
                )
>           raise NotImplementedError(msg)
E           NotImplementedError: argmax is not available with skipna=False with the installed version of numpy; upgrade to numpy 1.12 or newer to use skipna=True or skipna=None

...

Note: I have numpy 1.17 installed, so the error message here seems misleading.

@shoyer
Member

shoyer commented Aug 22, 2019

Thanks for sharing the patch! I dropped into a debugger by adding --pdb to the pytest command, which revealed what is going on here:

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> PDB post_mortem (IO-capturing turned off) >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
> /Users/shoyer/dev/xarray/xarray/core/duck_array_ops.py(295)f()
-> raise NotImplementedError(msg)
(Pdb) values
array(['2000-01-01T00:00:00.000000000', '2000-01-02T00:00:00.000000000',
       '2000-01-03T00:00:00.000000000'], dtype='datetime64[ns]')
(Pdb) func
<function nanargmax at 0x1156ff730>
(Pdb) func(values, axis=axis, **kwargs)
*** AttributeError: module 'xarray.core.nputils' has no attribute 'nanargmax'

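So it looks like nputils doesn't have nanargmax defined. The resulting AttributeError is caught by the except AttributeError branch in f, which, for a non-dask array, re-raises it as the misleading NotImplementedError about the NumPy version. Instead, we need to use nanargmax from NumPy.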
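A self-contained sketch of that control flow, where fake_nputils, patched_nanargmax and wrapper are hypothetical stand-ins for the xarray code quoted above (the wrapper keeps only the except AttributeError branch of f):

```python
import types

# Hypothetical stand-in for xarray.core.nputils before the fix: no nanargmax.
fake_nputils = types.SimpleNamespace()


def patched_nanargmax(a, axis=None):
    # Same shape as the patch: look the function up on the chosen module.
    module = fake_nputils
    return module.nanargmax(a, axis=axis)  # AttributeError: attribute missing


def wrapper(values, axis=None):
    """Simplified version of duck_array_ops' f(): same except-branch logic."""
    try:
        return patched_nanargmax(values, axis=axis)
    except AttributeError:
        # values is not a dask array here, so the numpy-version message wins,
        # even though the real problem is the missing helper.
        raise NotImplementedError(
            "argmax is not available with skipna=False with the installed "
            "version of numpy"
        )


try:
    wrapper([1, 2, 3])
except NotImplementedError as err:
    # The original AttributeError is still attached as implicit context.
    print(type(err.__context__).__name__)  # → AttributeError
```

Because the NotImplementedError is raised inside the except block, Python keeps the AttributeError as __context__, which is why a debugger (or the "During handling of the above exception" part of a full traceback) reveals the real cause.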

ulijh added a commit to ulijh/xarray that referenced this issue Aug 22, 2019
@ulijh
Contributor Author

ulijh commented Aug 22, 2019

Thanks @shoyer. Cool, that was easier than I expected. I added the patch, plus nanargmax/nanargmin in nputils, in #3244. What do you think?

dcherian pushed a commit that referenced this issue Sep 6, 2019
* Make argmin/max work lazy with dask (#3237).

* dask: Testing number of computes on reduce methods.

* what's new updated

* Fix typo

Co-Authored-By: Stephan Hoyer <shoyer@gmail.com>

* Be more explicit.

Co-Authored-By: Stephan Hoyer <shoyer@gmail.com>

* More explicit raise_if_dask_computes

* nanargmin/max: only set fill_value when needed
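The commit mentions a raise_if_dask_computes test helper. A minimal sketch of how such a helper can be built from the Scheduler class in the MCVE above, assuming dask is installed whenever the context manager is used; CountingScheduler and this raise_if_dask_computes are illustrative and may not match what was merged in #3244. The get parameter is an addition so the counting logic can be exercised without dask.

```python
class CountingScheduler:
    """Scheduler callable that refuses to run more than max_computes graphs."""

    def __init__(self, max_computes=0, get=None):
        self.max_computes = max_computes
        self.total_computes = 0
        self._get = get  # underlying scheduler; falls back to dask.get below

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                "Too many dask computations were scheduled: {}".format(
                    self.total_computes
                )
            )
        get = self._get
        if get is None:
            import dask  # imported lazily so the class itself needs no dask

            get = dask.get
        return get(dsk, keys, **kwargs)


def raise_if_dask_computes(max_computes=0):
    """Context manager: any dask compute beyond max_computes raises RuntimeError."""
    import dask

    # dask.config.set(...) can be used directly in a with statement.
    return dask.config.set(scheduler=CountingScheduler(max_computes))
```

Usage mirrors the MCVE: wrapping da.argmax(dim) in "with raise_if_dask_computes():" should pass silently once argmin/argmax stay lazy.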