jax and xarray integration for automatic differentiation? #17107
Comments
Hi, I think I recall this kind of thing coming up before. I don't know of any effort to do the full integration of JAX and xarray that you have in mind. The problem is that xarray is fundamentally built on the assumption that its arrays are NumPy arrays, and so, for example, … To move forward, one of two things would have to happen: …

Short of a team of people undertaking one of those very big projects, I don't think there's any good way to do what you have in mind.
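To show why a NumPy assumption clashes with `jax.grad`, here is a minimal sketch that uses only NumPy and JAX (no xarray). Inside a transformation JAX passes tracers instead of concrete values, and asking for a NumPy array from a tracer raises `jax.errors.TracerArrayConversionError`, which is the same error reported in this thread. The names `numpy_path` and `jax_path` are hypothetical and chosen for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_path(x):
    # np.asarray asks for a concrete NumPy buffer; a JAX tracer cannot supply one.
    doubled = np.asarray(x * 2.0)
    return jnp.sum(doubled)


def jax_path(x):
    # Staying inside jax.numpy keeps the value traceable.
    return jnp.sum(jnp.asarray(x * 2.0))


# Without a transformation, x is a plain float, so the NumPy path works.
forward = numpy_path(1.0)

# Under jax.grad, x is a tracer: the jnp path differentiates fine...
grad_ok = jax.grad(jax_path)(1.0)

# ...but the NumPy path fails at the np.asarray call.
raised = False
try:
    jax.grad(numpy_path)(1.0)
except jax.errors.TracerArrayConversionError:
    raised = True

print(forward, grad_ok, raised)
```

Any library code path that quietly calls `np.asarray` (or otherwise converts to NumPy) on its inputs will break the backward pass in this way, even when the forward pass succeeds.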
I think this would be quite exciting! I think the Python Array API standard would probably be the way to go. Xarray's support for the API standard is pretty close to complete, and most missing features would not be hard to add. Xarray in fact already supports wrapping many types of non-NumPy arrays, so supporting JAX arrays as well would not be a big lift. To get Xarray objects working with JAX transforms like … DeepMind's GraphCast project contains a bundled Xarray-JAX wrapper, which I think already does some version of both of these (maybe in a non-ideal way).
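To illustrate the idea behind the Array API approach, here is a minimal sketch that assumes only one thing: the same function body runs against whichever array namespace it is handed, in this case NumPy or `jax.numpy`, passed explicitly as `xp`. `weighted_total` is a hypothetical helper. Compliant arrays can also advertise their namespace through `__array_namespace__()`, but this sketch deliberately does not rely on that, so it does not depend on specific NumPy or JAX versions.

```python
import numpy as np
import jax
import jax.numpy as jnp


def weighted_total(values, weights, xp):
    # Uses only functions that exist in both the NumPy and jax.numpy namespaces.
    return xp.sum(values * weights)


# Same function, NumPy namespace: an ordinary eager computation.
total_np = weighted_total(np.arange(3.0), np.ones(3), np)

# Same function, jax.numpy namespace: now it can be differentiated.
grads = jax.grad(lambda w: weighted_total(jnp.arange(3.0), w, jnp))(jnp.ones(3))
print(total_np, grads)
```

A namespace-generic library would presumably look up `xp` from the wrapped array rather than take it as an argument. The point is that nothing in the function body forces a conversion to NumPy.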
(Side note: for the Array API approach, we'd also have to land some version of #16099 to make JAX compliant.)
CC @mjwillson, who wrote the Xarray-JAX wrapper in GraphCast.
Thanks @shoyer! I'll have to study that GraphCast code; I tried something similar but could never get it working properly.
I played around a bit with this GraphCast wrapper. It worked for the intended use case of applying … Unfortunately, for … It's pretty likely I'm doing something wrong here, so if @mjwillson or @shoyer spots something wrong, let me know!

```python
import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr
from graphcast.xarray_jax import DataArray  # GraphCast wrapper class

shape = (3, 4, 5)
values = np.random.random(shape)
coords = {dim: np.arange(length).tolist() for dim, length in zip('xyz', shape)}
xarr_jax = DataArray(values, dims=('x', 'y', 'z'), coords=coords)  # note: GraphCast wrapper class

def f(x):
    val = x * xarr_jax                    # scale the wrapped array
    val = val.interp(x=1, y=1, z=1)       # interpolate at a single point
    val = jnp.array(val.data)             # pull the result back out as a JAX array
    return jnp.sum(val)

f(1.0)            # works
jax.grad(f)(1.0)  # TracerArrayConversionError
```
The error is happening because the gradient computation results in calling … There's no way to fix this without changing how xarray is implemented.
This isn't true: xarray supports a number of duck arrays. As soon as JAX implements … If you use the GraphCast Xarray-JAX wrapper, you need to use its special constructors for DataArray/Dataset.
Oh, good to know! Progress on …
Could you explain a bit more?
@jakevdp Indeed, Xarray doesn't rely on the mutation APIs (unless a user tries to mutate an array).

@tylerflex I see, it looks like you were already using the GraphCast wrapper. I don't know exactly what's going on, then.
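For readers unfamiliar with why mutation matters here: JAX arrays are immutable, so in-place item assignment raises a `TypeError`, and updates are written functionally with the `.at[...]` syntax, which returns a new array. The following minimal sketch shows both behaviours. Code that only reads arrays and builds new ones, which is the pattern described above, never hits this restriction.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# In-place assignment is rejected: JAX arrays are immutable.
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# Functional update: returns a new array and leaves x unchanged.
y = x.at[0].set(1.0)
print(mutated, x, y)
```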
Hiya, firstly, just to note that xarray_jax isn't something we're officially supporting outside the GraphCast project for now: it has some rough edges and is partly a stop-gap measure until JAX supports the new array protocol, which will let it integrate better with xarray. That said, regarding your example, you'll find the following very similar code works:

```python
import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr
from graphcast import xarray_jax

shape = (3, 4, 5)
# The data is converted to a JAX array before wrapping.
values = jnp.asarray(np.random.random(shape))
coords = {dim: np.arange(length).tolist() for dim, length in zip('xyz', shape)}
# Built with the wrapper's own constructor rather than xr.DataArray.
xarr_jax = xarray_jax.DataArray(values, dims=('x', 'y', 'z'), coords=coords)

def f(x):
    val = x * xarr_jax
    # Take the underlying data back out of the wrapper before reducing it.
    val = xarray_jax.unwrap_data(val)
    return jnp.sum(val)

f(1.0)
jax.jit(f)(1.0)
jax.grad(f)(1.0)
```

Some issues in your code were: …
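For intuition about why a wrapper needs its own constructors and an explicit unwrap step, here is a hypothetical minimal labeled array registered as a JAX pytree. Registration lets `jax.grad` and `jax.jit` flatten the object to its underlying array and rebuild it afterwards. `LabeledArray` and its `isel` method are made up for this sketch and are not the xarray or GraphCast API. The sketch also assumes, without claiming anything about the GraphCast source, that a pytree-based design is roughly the kind of approach such a wrapper takes.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class LabeledArray:
    """Hypothetical minimal labeled array: dimension names plus a JAX array."""

    def __init__(self, data, dims):
        self.data = data
        self.dims = tuple(dims)

    def tree_flatten(self):
        # The array is the (traceable) leaf; dimension names are static metadata.
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        return cls(children[0], dims)

    def __mul__(self, other):
        return LabeledArray(self.data * other, self.dims)

    def isel(self, **indexers):
        # Integer indexing by dimension name, built only from jnp operations.
        index = tuple(indexers.get(d, slice(None)) for d in self.dims)
        kept = tuple(d for d in self.dims if d not in indexers)
        return LabeledArray(self.data[index], kept)


values = jnp.asarray(np.random.default_rng(0).random((3, 4, 5)))
arr = LabeledArray(values, ("x", "y", "z"))


def f(scale):
    point = (arr * scale).isel(x=1, y=1, z=1)
    return jnp.sum(point.data)


# Gradient with respect to a scalar that scales the labeled array.
dscale = jax.grad(f)(1.0)

# Because LabeledArray is a pytree, it can also be the argument (and gradient) itself.
darr = jax.grad(lambda a: jnp.sum(a.isel(x=0).data))(arr)
print(dscale, darr.dims, darr.data.shape)
```

The dimension names travel as static auxiliary data, so they are never traced. Only the array leaf is, which keeps every operation inside JAX.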
I've been wondering if there has been any recent progress in integrating `jax` and `xarray`, specifically for automatic differentiation. For context, we have a simulation project that relies on `xarray` for our simulation output data but recently added `jax` support so users can automatically differentiate through these simulations. To make this work, we added code to emulate `xr.DataArray` functionality but with `jax` internals. However, this approach has been a headache to maintain and extend. It would be amazing if `xarray` had native support for gradient tracking in JAX.

As an example, the code snippet below multiplies a JAX-traced value by an `xarray.DataArray`, does an interpolation, and then a JAX-traced operation. It would be great if we could differentiate through this. The forward pass works, but the backward pass gives a `TracerArrayConversionError`.

I've tried many other workarounds based on issues such as this one and some other discussions, but without any luck. Are there any updates on the status of this, whether it would be possible eventually, or suggestions for possible workarounds? Any discussion or pointers towards a good approach are really appreciated.
@shoyer
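As one possible workaround for the interpolation case, here is a minimal sketch that keeps the differentiable part in JAX and uses the coordinate labels only to compute static indices with NumPy. It assumes monotonically increasing numeric coordinates and linear interpolation. `label_to_index` and `POINT` are hypothetical, and `jax.scipy.ndimage.map_coordinates` (which supports only `order=0` or `order=1`) stands in for `DataArray.interp`. This is a swapped-in technique, not xarray's own implementation.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

shape = (3, 4, 5)
values = jnp.asarray(np.random.default_rng(0).random(shape))
# Hypothetical coordinates: plain NumPy, never traced.
coords = {dim: np.arange(n, dtype=float) for dim, n in zip("xyz", shape)}
POINT = {"x": 1.5, "y": 1.0, "z": 2.0}  # hypothetical query location


def label_to_index(coord, label):
    # Static NumPy: map a coordinate label to a fractional array index.
    # Assumes coord is monotonically increasing.
    return np.interp(label, coord, np.arange(coord.size))


def f(scale):
    idx = [jnp.atleast_1d(label_to_index(coords[d], POINT[d])) for d in "xyz"]
    # Linear interpolation on the traced array; gradients flow through it.
    interpolated = map_coordinates(scale * values, idx, order=1)
    return jnp.sum(interpolated)


print(f(1.0), jax.grad(f)(1.0))
```

Because the label-to-index lookup only touches static coordinates, it can stay in NumPy. Only the data values need to be traced.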