Skip to content

Commit

Permalink
Update examples and make sure they are tested
Browse files Browse the repository at this point in the history
  • Loading branch information
allen-adastra committed Oct 3, 2024
1 parent 5465c0c commit d1e2294
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,26 @@ da = xr.DataArray(
@eqx.filter_jit
def some_function(data):
neg_data = -1.0 * data
return neg_data * neg_data.coords["y"] # Multiply data by coords.
return neg_data * neg_data.coords["y"] # Multiply data by coords.

da = some_function(da)

# Construct a xr.DataArray with dummy data (useful for tree manipulation).
da_mask = jax.tree.map(lambda _: True, data)
da_mask = jax.tree.map(lambda _: True, da)

# Use jax.grad.
# Take the gradient of a jitted function.
@eqx.filter_jit
def fn(data):
return (data**2.0).sum().data

grad = jax.grad(fn)(da)
da_grad = jax.grad(fn)(da)

# Convert to a custom XjDataArray, implemented as an equinox module.
# (Useful for avoiding potentially weird xarray interactions with JAX).
xj_da = xj.from_xarray(da)

# Convert back to a xr.DataArray.
da = xj.to_xarray(xj_da)

```
## Installation
```bash
Expand Down
36 changes: 29 additions & 7 deletions xarray_jax/tests/test_readme_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,42 @@
"""


def test_dataarray_example():
def test_main_example():
import jax.numpy as jnp
import xarray as xr
import xarray_jax as xj

# Construct a DataArray.
da = xr.DataArray(
xr.Variable(["x", "y"], jnp.ones((2, 3))),
coords={"x": [1, 2], "y": [3, 4, 5]},
name="foo",
attrs={"attr1": "value1"},
)

# Do some operations inside a JIT compiled function.
@eqx.filter_jit
def some_function(data):
neg_data = -1.0 * data # Multiply data by -1.
neg_data = -1.0 * data
return neg_data * neg_data.coords["y"] # Multiply data by coords.

da_new = some_function(da)
da = some_function(da)

assert da_new.equals(-1.0 * da * da.coords["y"])
# Construct a xr.DataArray with dummy data (useful for tree manipulation).
da_mask = jax.tree.map(lambda _: True, da)

@eqx.filter_jit
# Take the gradient of a jitted function. @eqx.filter_jit
def fn(data):
return (data**2.0).sum().data

grad = jax.grad(fn)(da)
assert grad.equals(2.0 * da)
da_grad = jax.grad(fn)(da)

# Convert to a custom XjDataArray, implemented as an equinox module.
# (Useful for avoiding potentially weird xarray interactions with JAX).
xj_da = xj.from_xarray(da)

# Convert back to a xr.DataArray.
da = xj.to_xarray(xj_da)


def test_fail_example():
Expand All @@ -43,3 +56,12 @@ def test_fail_example():
with pytest.raises(ValueError):
# This will fail.
var = var + 1

var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))

# This will fail.
with pytest.raises(TypeError):
jnp.square(var)

# This will work.
xr.apply_ufunc(jnp.square, var)

0 comments on commit d1e2294

Please sign in to comment.