Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Namespace-aware xarray.ufuncs #9776

Merged
merged 16 commits into from
Nov 18, 2024
Merged

Conversation

slevang
Copy link
Contributor

@slevang slevang commented Nov 13, 2024

Re-implement the old xarray.ufuncs module to allow generic ufunc handling for array types that don't implement __array_ufunc__:

import jax.numpy as jnp
import numpy as np
import xarray as xr
import xarray.ufuncs as xu

x = xr.DataArray(jnp.asarray([1, 2, 3]))
print(type(xu.sin(x).data))
print(type(np.sin(x).data))

# <class 'jaxlib.xla_extension.ArrayImpl'>
# <class 'numpy.ndarray'>

elif isinstance(obj, array_type("dask")):
_walk_array_namespaces(obj._meta, namespaces)
else:
namespace = getattr(obj, "__array_namespace__", None)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we ever prioritize dispatching with np.func via __array_ufunc__ (if it exists) over the library's __array_namespace__().func?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__array_ufunc__ is a more generic protocol, intended to support arbitrary new ufuncs without requiring an array library to be aware of them.

In practice there are very few examples of ufuncs defined outside of NumPy itself, and we wouldn't need to support them here because we are explicitly listing supported ufuncs in this module. I guess the one example that comes to mind would be the rare cases where NumPy adds a new ufunc and an array wrapping library like Dask hasn't written a wrapper yet.

At this point, I think going "all in" on __array_namespace__ is the right call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess my other comment would be the main reason to consider __array_ufunc__. Some duck arrays don't implement all ufuncs. So either of these approaches would solve the same problem.

xarray/ufuncs.py Outdated
)
func = getattr(np, self._name)

return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there ever a reason to use dask's ufuncs with dask="allowed" instead of the appropriate _meta array's namespace and dask="parallelized"? With jax for example, which doesn't have __array_ufunc__, this ends up converting to numpy. So it would have to be special cased.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In user code using xr.apply_ufunc there is - dask='allowed' can be used to rechunk along a core dimension e.g. by applying a dask reduction ufunc along that dimension. Not sure if that's relevant here though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are all elementwise so no core dimensions

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there ever a reason to use dask's ufuncs with dask="allowed" instead of the appropriate _meta array's namespace and dask="parallelized"?

Yes, this feels like a cleaner solution to me.

With jax for example, which doesn't have array_ufunc, this ends up converting to numpy. So it would have to be special cased

Is the concern here Dask wrapping JAX?

Generally, I think it's best for xarray to avoid introspecting into wrapped array types in Xarray, and leave nested wrapping up to other libraries, which can better understand their own implementation details. So Dask wrapping JAX should be fixed in Dask.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So Dask wrapping JAX should be fixed in Dask.

Totally fair. It looks like basically the same effort here would be required in dask then, because dask's ufuncs are all simple wrappers around the numpy version so they aren't aware of the namespace.

xarray/ufuncs.py Outdated


# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can hard code these if preferred?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I would suggest hard coding these if possible, ideally as something like:

sin = _unary_ufunc('sin')
add = _binary_ufunc('add')
...

The reason to prefer this is that static typing does not evaluate loops or dynamically defined functions. So otherwise xarray.ufuncs.sin() will not be recognized as valid by tools like mypy.

Ideally (could be done in a follow-up PR), these functions could be annotated to follow the appropriate type casting rules, so type checkers would know that xarray.ufuncs.sin(dataset) returns another Dataset.

In some cases, we use a script to generate all these special methods. I don't think that should be necessary here, but it still may be worth a look to understand the type casting rules:
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/util/generate_ops.py
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/core/_typed_ops.py

xarray/ufuncs.py Outdated

# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the ones that didn't immediately work. There are also other ufunc like things that aren't technically np.ufunc subclasses that we could add. I saw angle and iscomplex were special cased before.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth noting that the reason why matmul and vecdot doesn't work is that they are "generalized ufuncs" that use core dimensions.

divmod, frexp and modf doesn't work because they return multiple arrays.

I'm not sure why isnat didn't work for you. Did you test it with datetime dtypes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you test it with datetime dtypes?

No, this was just a real quick initial pass. Will add this with a special test case.

Do you have an opinion about adding any of the ones with multiple return values? Seems low priority to me.

Same question for the odd balls like angle, iscomplex, isreal, etc?

@slevang slevang marked this pull request as ready for review November 13, 2024 14:17
@TomNicholas TomNicholas added topic-arrays related to flexible array support array API standard Support for the Python array API standard labels Nov 14, 2024
@dcherian dcherian requested a review from keewis November 15, 2024 16:23
xarray/ufuncs.py Outdated
Comment on lines 55 to 60
if func is None:
warnings.warn(
f"Function {self._name} not found in {xp.__name__}, falling back to numpy",
stacklevel=2,
)
func = getattr(np, self._name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would lean towards skipping this fall-back, unless there are particularly motivating cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Motivation here would be duck arrays that implement __array_ufunc__ and don't implement the full suite of numpy ufuncs. I ran into this with sparse. Not sure the full delta list, but I see they don't have sin/cos for example. In this case, np.cos(x_sparse) works but xp.cos(x_sparse) fails, which is a little weird. Not the most elegant solution though, I agree.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's intentional than xp.cos(x_sparse) fails, because cos(0) != 0, so the result is no longer sparse.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I was wrong, there is a sparse.cos and a bunch of others, although they don't appear in the API docs. It seems sparse's general approach to these is compute elementwise on the valid data, and then modify the fill_value as required for the empty regions.

There are still 40-some functions that fail without this fallback, although generally more niche. With the fallback, all work and output a sparse array (no auto densification):

absolute, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh, bitwise_count, cbrt, conj, conjugate, copysign, deg2rad, degrees, exp2, expm1x, fabs, float_power, fmax, fmin, fmod, gcd, heaviside, hypot, invert, isreal, lcm, ldexp, left_shift, logaddexp2, maximum, minimum, mod, nextafter, power, rad2deg, radians, reciprocal, right_shift, rint, signbit, spacing, true_divide

xarray/ufuncs.py Outdated
Comment on lines 24 to 25
elif isinstance(obj, array_type("dask")):
_walk_array_namespaces(obj._meta, namespaces)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not sure Dask's __array_namespace__ instead? That feels a little cleaner than special case logic for dask.array.

xarray/ufuncs.py Outdated
)
func = getattr(np, self._name)

return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there ever a reason to use dask's ufuncs with dask="allowed" instead of the appropriate _meta array's namespace and dask="parallelized"?

Yes, this feels like a cleaner solution to me.

With jax for example, which doesn't have array_ufunc, this ends up converting to numpy. So it would have to be special cased

Is the concern here Dask wrapping JAX?

Generally, I think it's best for xarray to avoid introspecting into wrapped array types in Xarray, and leave nested wrapping up to other libraries, which can better understand their own implementation details. So Dask wrapping JAX should be fixed in Dask.

xarray/ufuncs.py Outdated


# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I would suggest hard coding these if possible, ideally as something like:

sin = _unary_ufunc('sin')
add = _binary_ufunc('add')
...

The reason to prefer this is that static typing does not evaluate loops or dynamically defined functions. So otherwise xarray.ufuncs.sin() will not be recognized as valid by tools like mypy.

Ideally (could be done in a follow-up PR), these functions could be annotated to follow the appropriate type casting rules, so type checkers would know that xarray.ufuncs.sin(dataset) returns another Dataset.

In some cases, we use a script to generate all these special methods. I don't think that should be necessary here, but it still may be worth a look to understand the type casting rules:
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/util/generate_ops.py
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/core/_typed_ops.py

xarray/ufuncs.py Outdated Show resolved Hide resolved
elif isinstance(obj, array_type("dask")):
_walk_array_namespaces(obj._meta, namespaces)
else:
namespace = getattr(obj, "__array_namespace__", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__array_ufunc__ is a more generic protocol, intended to support arbitrary new ufuncs without requiring an array library to be aware of them.

In practice there are very few examples of ufuncs defined outside of NumPy itself, and we wouldn't need to support them here because we are explicitly listing supported ufuncs in this module. I guess the one example that comes to mind would be the rare cases where NumPy adds a new ufunc and an array wrapping library like Dask hasn't written a wrapper yet.

At this point, I think going "all in" on __array_namespace__ is the right call.

xarray/ufuncs.py Outdated

# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth noting that the reason why matmul and vecdot doesn't work is that they are "generalized ufuncs" that use core dimensions.

divmod, frexp and modf doesn't work because they return multiple arrays.

I'm not sure why isnat didn't work for you. Did you test it with datetime dtypes?

xarray/ufuncs.py Outdated
def __init__(self, name):
self._name = name

def __call__(self, x, **kwargs):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shoyer lmk if this is what you had in mind for a separating unary from binary funcs. I assume this is for typing purposes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Close! I added a more detailed note below.

xarray/ufuncs.py Outdated
Comment on lines 116 to 140
def _create_op(name):
if not hasattr(np, name):
# handle older numpy versions with missing array api standard aliases
if np.lib.NumpyVersion(np.__version__) < "2.0.0":
return _UnavailableUfunc(name)
raise ValueError(f"'{name}' is not a valid numpy function")

np_func = getattr(np, name)
if hasattr(np_func, "nin") and np_func.nin == 2:
func = _BinaryUfunc(name)
else:
func = _UnaryUfunc(name)

func.__name__ = name
doc = getattr(np, name).__doc__

doc = _remove_unused_reference_labels(_skip_signature(_dedent(doc), name))

func.__doc__ = (
f"xarray specific variant of numpy.{name}. Handles "
"xarray objects by dispatching to the appropriate "
"function for the underlying array type.\n\n"
f"Documentation from numpy:\n\n{doc}"
)
return func
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type checkers can't evaluate code at runtime, so you really want to write something that prescribes the type signature statically, like:

abs = _UnaryUfunc('abs')

or

def _unary_ufunc(name: str) -> _UnaryUfunc:
    func = _UnaryUfunc(name)
    func.__doc__ = ...
    return func

abs = _unary_ufunc('abs')

When you write abs = _create_op('abs'), type checkers think the type of abs could be any of _UnavailableUfunc or _BinaryUfunc or _UnaryUfunc. (You can check this with reveal_type(abs) if you're curious.)

xarray/ufuncs.py Outdated Show resolved Hide resolved
xarray/ufuncs.py Outdated Show resolved Hide resolved
xarray/ufuncs.py Outdated
def __init__(self, name):
self._name = name

def __call__(self, x, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Close! I added a more detailed note below.

xarray/ufuncs.py Outdated Show resolved Hide resolved
@shoyer
Copy link
Member

shoyer commented Nov 17, 2024

This looks great! Could you please also add a brief note to whats-new.rst?

@slevang
Copy link
Contributor Author

slevang commented Nov 17, 2024

This looks great! Could you please also add a brief note to whats-new.rst?

Yes will do. Shall we add these back to the main api docs as well, same as before minus the deprecation notice?

@shoyer
Copy link
Member

shoyer commented Nov 18, 2024 via email

@dcherian dcherian mentioned this pull request Nov 18, 2024
6 tasks
xarray/ufuncs.py Outdated Show resolved Hide resolved
@dcherian dcherian merged commit 5d70f4d into pydata:main Nov 18, 2024
29 checks passed
@slevang slevang mentioned this pull request Nov 18, 2024
4 tasks
dcherian added a commit that referenced this pull request Nov 19, 2024
* main: (24 commits)
  Bump minimum versions (#9796)
  Namespace-aware `xarray.ufuncs` (#9776)
  Add prettier and pygrep hooks to pre-commit hooks (#9644)
  `rolling.construct`: Add `sliding_window_kwargs` to pipe arguments down to `sliding_window_view` (#9720)
  Bump codecov/codecov-action from 4.6.0 to 5.0.2 in the actions group (#9793)
  Buffer types (#9787)
  Add download stats badges (#9786)
  Fix open_mfdataset for list of fsspec files (#9785)
  add 'User-Agent'-header to pooch.retrieve (#9782)
  Optimize `ffill`, `bfill` with dask when `limit` is specified (#9771)
  fix cf decoding of grid_mapping (#9765)
  Allow wrapping `np.ndarray` subclasses (#9760)
  Optimize polyfit (#9766)
  Use `map_overlap` for rolling reductions with Dask (#9770)
  fix html repr indexes section (#9768)
  Bump pypa/gh-action-pypi-publish from 1.11.0 to 1.12.2 in the actions group (#9763)
  unpin array-api-strict, as issues are resolved upstream (#9762)
  rewrite the `min_deps_check` script (#9754)
  CI runs ruff instead of pep8speaks (#9759)
  Specify copyright holders in main license file (#9756)
  ...
dcherian added a commit to dcherian/xarray that referenced this pull request Nov 19, 2024
* main:
  Bump minimum versions (pydata#9796)
  Namespace-aware `xarray.ufuncs` (pydata#9776)
  Add prettier and pygrep hooks to pre-commit hooks (pydata#9644)
  `rolling.construct`: Add `sliding_window_kwargs` to pipe arguments down to `sliding_window_view` (pydata#9720)
  Bump codecov/codecov-action from 4.6.0 to 5.0.2 in the actions group (pydata#9793)
  Buffer types (pydata#9787)
  Add download stats badges (pydata#9786)
  Fix open_mfdataset for list of fsspec files (pydata#9785)
  add 'User-Agent'-header to pooch.retrieve (pydata#9782)
  Optimize `ffill`, `bfill` with dask when `limit` is specified (pydata#9771)
  fix cf decoding of grid_mapping (pydata#9765)
  Allow wrapping `np.ndarray` subclasses (pydata#9760)
  Optimize polyfit (pydata#9766)
  Use `map_overlap` for rolling reductions with Dask (pydata#9770)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
array API standard Support for the Python array API standard topic-arrays related to flexible array support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Compatibility with the Array API standard
4 participants