diff --git a/README.md b/README.md index 7130e99..f9ad013 100644 --- a/README.md +++ b/README.md @@ -19,19 +19,19 @@ da = xr.DataArray( @eqx.filter_jit def some_function(data): neg_data = -1.0 * data - return neg_data * neg_data.coords["y"] # Multiply data by coords. + return neg_data * neg_data.coords["y"] # Multiply data by coords. da = some_function(da) # Construct a xr.DataArray with dummy data (useful for tree manipulation). -da_mask = jax.tree.map(lambda _: True, data) +da_mask = jax.tree.map(lambda _: True, da) -# Use jax.grad. +# Take the gradient of a jitted function. @eqx.filter_jit def fn(data): return (data**2.0).sum().data -grad = jax.grad(fn)(da) +da_grad = jax.grad(fn)(da) # Convert to a custom XjDataArray, implemented as an equinox module. # (Useful for avoiding potentially weird xarray interactions with JAX). @@ -39,7 +39,6 @@ xj_da = xj.from_xarray(da) # Convert back to a xr.DataArray. da = xj.to_xarray(xj_da) - ``` ## Installation ```bash diff --git a/xarray_jax/tests/test_readme_examples.py b/xarray_jax/tests/test_readme_examples.py index 1cb5cb3..54ee764 100644 --- a/xarray_jax/tests/test_readme_examples.py +++ b/xarray_jax/tests/test_readme_examples.py @@ -9,7 +9,12 @@ """ -def test_dataarray_example(): +def test_main_example(): + import jax.numpy as jnp + import xarray as xr + import xarray_jax as xj + + # Construct a DataArray. da = xr.DataArray( xr.Variable(["x", "y"], jnp.ones((2, 3))), coords={"x": [1, 2], "y": [3, 4, 5]}, @@ -17,21 +22,29 @@ def test_dataarray_example(): attrs={"attr1": "value1"}, ) + # Do some operations inside a JIT compiled function. @eqx.filter_jit def some_function(data): - neg_data = -1.0 * data # Multiply data by -1. + neg_data = -1.0 * data return neg_data * neg_data.coords["y"] # Multiply data by coords. - da_new = some_function(da) + da = some_function(da) - assert da_new.equals(-1.0 * da * da.coords["y"]) + # Construct a xr.DataArray with dummy data (useful for tree manipulation). + da_mask = jax.tree.map(lambda _: True, da) - @eqx.filter_jit + # Take the gradient of a jitted function. @eqx.filter_jit def fn(data): return (data**2.0).sum().data - grad = jax.grad(fn)(da) - assert grad.equals(2.0 * da) + da_grad = jax.grad(fn)(da) + + # Convert to a custom XjDataArray, implemented as an equinox module. + # (Useful for avoiding potentially weird xarray interactions with JAX). + xj_da = xj.from_xarray(da) + + # Convert back to a xr.DataArray. + da = xj.to_xarray(xj_da) def test_fail_example(): @@ -43,3 +56,12 @@ def test_fail_example(): with pytest.raises(ValueError): # This will fail. var = var + 1 + + var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3))) + + # This will fail. + with pytest.raises(TypeError): + jnp.square(var) + + # This will work. + xr.apply_ufunc(jnp.square, var)