Support for __array_function__ implementers (sparse arrays) [WIP] #3117
Conversation
Codecov Report
@@           Coverage Diff            @@
##           master   #3117     +/-   ##
==========================================
+ Coverage    95.7%   95.73%   +0.03%
==========================================
  Files          63       63
  Lines       12861    12870       +9
==========================================
+ Hits        12308    12321      +13
+ Misses        553      549       -4
@nvictus, thank you for your work! I just tried this on CuPy arrays, and it seems to be working during array creation:

In [1]: import cupy as cp
In [2]: import xarray as xr
In [3]: x = cp.arange(6).reshape(2, 3).astype('f')
In [4]: y = cp.ones((2, 3), dtype='int')
In [5]: x
Out[5]:
array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32)
In [6]: y
Out[6]:
array([[1, 1, 1],
[1, 1, 1]])
In [7]: y.device
Out[7]: <CUDA Device 0>
In [8]: x.device
Out[8]: <CUDA Device 0>
In [9]: ds = xr.Dataset()
In [10]: ds['x'] = xr.DataArray(x, dims=['lat', 'lon'])
In [11]: ds['y'] = xr.DataArray(y, dims=['lat', 'lon'])
In [12]: ds
Out[12]:
<xarray.Dataset>
Dimensions: (lat: 2, lon: 3)
Dimensions without coordinates: lat, lon
Data variables:
x (lat, lon) float32 ...
y (lat, lon) int64 ...
In [13]: ds.x.data.device
Out[13]: <CUDA Device 0>
In [14]: ds.y.data.device
Out[14]: <CUDA Device 0>

Even though it failed when I tried applying an operation to the dataset, this is still awesome! I am very excited and looking forward to seeing this feature in xarray:

In [15]: m = ds.mean(dim='lat')
-------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-15-8e4d5e7d5ee3> in <module>
----> 1 m = ds.mean(dim='lat')
/glade/work/abanihi/devel/pangeo/xarray/xarray/core/common.py in wrapped_func(self, dim, skipna, **kwargs)
65 return self.reduce(func, dim, skipna=skipna,
66 numeric_only=numeric_only, allow_lazy=True,
---> 67 **kwargs)
68 else:
69 def wrapped_func(self, dim=None, **kwargs): # type: ignore
/glade/work/abanihi/devel/pangeo/xarray/xarray/core/dataset.py in reduce(self, func, dim, keep_attrs, keepdims, numeric_only, allow_lazy, **kwargs)
3532 keepdims=keepdims,
3533 allow_lazy=allow_lazy,
-> 3534 **kwargs)
3535
3536 coord_names = set(k for k in self.coords if k in variables)
/glade/work/abanihi/devel/pangeo/xarray/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, allow_lazy, **kwargs)
1392 input_data = self.data if allow_lazy else self.values
1393 if axis is not None:
-> 1394 data = func(input_data, axis=axis, **kwargs)
1395 else:
1396 data = func(input_data, **kwargs)
/glade/work/abanihi/devel/pangeo/xarray/xarray/core/duck_array_ops.py in mean(array, axis, skipna, **kwargs)
370 return _to_pytimedelta(mean_timedeltas, unit='us') + offset
371 else:
--> 372 return _mean(array, axis=axis, skipna=skipna, **kwargs)
373
374
/glade/work/abanihi/devel/pangeo/xarray/xarray/core/duck_array_ops.py in f(values, axis, skipna, **kwargs)
257
258 try:
--> 259 return func(values, axis=axis, **kwargs)
260 except AttributeError:
261 if isinstance(values, dask_array_type):
/glade/work/abanihi/devel/pangeo/xarray/xarray/core/nanops.py in nanmean(a, axis, dtype, out)
158 return dask_array.nanmean(a, axis=axis, dtype=dtype)
159
--> 160 return np.nanmean(a, axis=axis, dtype=dtype)
161
162
/glade/work/abanihi/softwares/miniconda3/envs/xarray-tests/lib/python3.7/site-packages/numpy/core/overrides.py in public_api(*args, **kwargs)
163 relevant_args = dispatcher(*args, **kwargs)
164 return implement_array_function(
--> 165 implementation, public_api, relevant_args, args, kwargs)
166
167 if module is not None:
TypeError: no implementation found for 'numpy.nanmean' on types that implement __array_function__: [<class 'cupy.core.core.ndarray'>]
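This TypeError is NEP 18's defined failure mode: when an array type implements `__array_function__` but returns `NotImplemented` for the requested function, NumPy raises instead of silently coercing to a dense ndarray. A minimal sketch of that dispatch behavior (the `Duck` class is hypothetical, not from this PR or CuPy):

```python
import numpy as np

class Duck:
    """Hypothetical minimal NEP-18 duck array: implements only np.mean."""
    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func is np.mean:
            # Unwrap Duck arguments, then delegate to NumPy.
            args = [a.data if isinstance(a, Duck) else a for a in args]
            return np.mean(*args, **kwargs)
        # Returning NotImplemented makes NumPy raise the TypeError seen above.
        return NotImplemented

d = Duck([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
print(np.mean(d))  # dispatches through __array_function__ -> 2.5
try:
    np.nanmean(d)
except TypeError as exc:
    print("nanmean failed:", exc)
```

The same mechanism explains why array creation above works but `ds.mean()` does not: only the functions a library chooses to implement are available through the protocol.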
I am really excited about this!
Yes, it really is! For this specific failure, we should think about adding an option for the default value of `skipna`. If someone is using xarray to wrap a computation-oriented library like CuPy, they probably almost always want to set `skipna=False`.
xarray/tests/test_nep18.py (Outdated)

    sparse = pytest.importorskip('sparse')
Xarray does automatic alignment, so it would also be really good to add a test case with multiple duck arrays with different coordinate labels, to verify that alignment works. xarray/tests/test_dask.py also has a bunch of good examples of things to test.
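To make the alignment suggestion concrete, here is a toy model of what such a test would exercise: an inner join on coordinate labels before a binary operation. The `aligned_add` helper is hypothetical (plain dicts standing in for DataArrays), not xarray API:

```python
import numpy as np

def aligned_add(a, b):
    """Toy model of xarray's automatic alignment: inner-join the
    coordinate labels along 'x', then add the matching values."""
    common = [lab for lab in a["x"] if lab in b["x"]]   # label intersection
    ia = [a["x"].index(lab) for lab in common]
    ib = [b["x"].index(lab) for lab in common]
    return {"x": common, "data": a["data"][ia] + b["data"][ib]}

a = {"x": ["a", "b", "c"], "data": np.array([1.0, 2.0, 3.0])}
b = {"x": ["b", "c", "d"], "data": np.array([10.0, 20.0, 30.0])}
out = aligned_add(a, b)
print(out["x"], out["data"])  # ['b', 'c'] [12. 23.]
```

A real test would build two DataArrays backed by duck arrays with partially overlapping coordinates and check that the result is still a duck array of the aligned shape.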
This version should be much more compatible out of the box with duck typing.
After writing more tests, it turns out ... With a serendipitous shape and density of a sparse array, there were the right number of ...

A simple fix is to special-case ... Would it make sense to just assume that all non-DataArray NEP-18 compliant arrays do not contain an xarray-compliant ...?

Hmm, looks like the ...
xarray/tests/test_nep18.py (Outdated)

    assert isinstance(func(A).data, sparse.SparseArray)

    class TestSparseVariable:
An aside (don't let this slow you down): these are tests that we could run across any eligible array, including the current numpy & dask arrays. You could pull tests from the current dask tests if that would be a shortcut to building up this test suite from scratch. We could do that by inheriting this class, or having `var` as a fixture.
Thanks @max-sixty. I did copy them over from the dask tests. I'm still figuring out which ones are transferable to this use case since I'm not a regular xarray user myself (yet) :)
Great! LMK if I can be helpful at all. Excited about the PR!
Yes, let's switch ... to ...
Thanks for bumping this @mrocklin! I've put in some extra work on my free time, which hasn't been pushed yet. I'll try to write up a summary of my findings today. Briefly, though, it seems like the two limiting factors for NEP-18 duck array support are: ...

I think NEP-18-backed xarray structures can be supported in principle, but that won't prevent some operations from simply failing in some contexts. So maybe xarray will need to define a minimum required implementation subset of the array API for duck arrays.
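One way to make such a "minimum required subset" concrete is to probe a duck array through the protocol and record which functions actually dispatch — essentially the expected-failure matrix the tests in this PR end up encoding. A sketch with a hypothetical `OnlySum` type (not from the PR):

```python
import numpy as np

class OnlySum:
    """Hypothetical duck array implementing a single NumPy function."""
    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func is np.sum:
            args = [a.data if isinstance(a, OnlySum) else a for a in args]
            return np.sum(*args, **kwargs)
        return NotImplemented

def support_matrix(arr, funcs):
    """Record which unary reductions dispatch successfully for `arr`.
    NEP-18 signals a missing implementation with TypeError."""
    result = {}
    for f in funcs:
        try:
            f(arr)
            result[f.__name__] = True
        except TypeError:
            result[f.__name__] = False
    return result

print(support_matrix(OnlySum([1, 2, 3]), [np.sum, np.mean, np.nanmean]))
# {'sum': True, 'mean': False, 'nanmean': False}
```

Running such a probe against sparse or CuPy arrays would give a data-driven picture of what xarray can rely on, rather than trying to define the subset up front.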
This is totally fine for now, as long as there are clear errors when attempting to do an unsupported operation. We can write unit tests with expected failures, which should provide a clear roadmap for things to fix upstream in sparse. We could attempt to define a minimum required implementation, but in practice I suspect this will be hard to nail down definitively. The ultimate determinant of what works will be xarray's implementation.
So, tests are passing now and I've documented the expected failures on sparse arrays. :) As mentioned before, most fall into the categories of (1) implicit coercion to dense and (2) missing operations on sparse arrays. Turns out a lot of the implicit coercion is due to the use of routines from ... I also modified ... If ...
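The implicit-coercion failure mode can be reproduced without sparse installed: a duck array can refuse densification by raising from `__array__`, so any helper that calls `np.asarray` internally fails, while protocol-dispatched functions keep working. A sketch with a hypothetical `NoDensify` class (loosely modeled on how sparse arrays guard against auto-densification):

```python
import numpy as np

class NoDensify:
    """Hypothetical duck array that forbids implicit densification."""
    def __init__(self, data):
        self._data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # np.asarray() lands here; refuse, the way sparse arrays may.
        raise RuntimeError("implicit conversion to dense is not allowed")

    def __array_function__(self, func, types, args, kwargs):
        if func is np.sum:
            args = [a._data if isinstance(a, NoDensify) else a for a in args]
            return np.sum(*args, **kwargs)
        return NotImplemented

x = NoDensify([1, 2, 3])
print(np.sum(x))   # protocol path: works, returns 6
try:
    np.asarray(x)  # coercion path: the source of many expected failures
except RuntimeError as exc:
    print(exc)
```

This is why replacing `np.asarray`-based helpers with protocol-dispatching calls is the recurring fix in this PR.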
Looks great to me! I think we can merge this in after syncing again with master.
I probably wouldn't call this ready for widespread use until we get a few more features working (e.g., `concat`), but this is a great start, and it will be easier for others to contribute once this is merged.
xarray/tests/test_sparse.py (Outdated)

    (do('astype', int), True),
    (do('broadcast_equals', make_xrarray({'x': 10, 'y': 5})), False),
    (do('clip', min=0, max=1), True),
    # (do('close'), False),
I'm guessing the commented out cases don't work yet? :)
Oops, forgot to finish those! :)
I've added tests for all except those I couldn't figure out how to use, like `swap_dims`, `sel_points` and `isel_points`.

I think the right behavior is probably for ...
thanks!
xarray/tests/test_sparse.py (Outdated)

    def test_binary_op(self):
        assert_sparse_eq((2 * self.var).data, 2 * self.data)
        assert_sparse_eq((self.var + self.var).data, self.data + self.data)
        # assert_eq((self.var[0] + self.var).data, self.data[0] + self.data)
remove?
@nvictus, are we good to go ahead and merge, and do follow-ups in other PRs?

Sounds good!

Thanks a lot, @nvictus!

Woot! Thanks @nvictus!
* master:
  - pyupgrade one-off run (pydata#3190)
  - mfdataset, concat now support the 'join' kwarg. (pydata#3102)
  - reduce the size of example dataset in dask docs (pydata#3187)
  - add climpred to related-projects (pydata#3188)
  - bump rasterio to 1.0.24 in doc building environment (pydata#3186)
  - More annotations (pydata#3177)
  - Support for __array_function__ implementers (sparse arrays) [WIP] (pydata#3117)
  - Internal clean-up of isnull() to avoid relying on pandas (pydata#3132)
  - Call darray.compute() in plot() (pydata#3183)
  - BUG: fix + test open_mfdataset fails on variable attributes with list… (pydata#3181)
@@ -21,6 +21,7 @@ dependencies:
   - pip
   - scipy
   - seaborn
+  - sparse
   - toolz
Why did you exclude py36 and Windows?
Thanks @nvictus!
Doing a SciPy sprint. Working towards #1375.
Together with pydata/sparse#261, it seems to work.
- `whats-new.rst` for all changes and `api.rst` for new API