Skip to content

Commit

Permalink
Add var_change_on_unflatten examples
Browse files Browse the repository at this point in the history
  • Loading branch information
allen-adastra committed Oct 3, 2024
1 parent d1e2294 commit 1d9d97b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
26 changes: 10 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ da = some_function(da)
# Construct a xr.DataArray with dummy data (useful for tree manipulation).
da_mask = jax.tree.map(lambda _: True, da)

# Take the gradient of a jitted function.
@eqx.filter_jit
# Take the gradient of a jitted function. @eqx.filter_jit
def fn(data):
return (data**2.0).sum().data

Expand All @@ -39,6 +38,14 @@ xj_da = xj.from_xarray(da)

# Convert back to a xr.DataArray.
da = xj.to_xarray(xj_da)

# Use xj.var_change_on_unflatten to allow us to expand the dimensions of the DataArray.
def add_dim_to_var(var):
var._dims = ("new_dim", *var._dims)
return var

with xj.var_change_on_unflatten(add_dim_to_var):
da = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), da)
```
## Installation
```bash
Expand All @@ -56,26 +63,13 @@ pip install xarray_jax
- [x] `XjDataset`
- [x] `xj.from_xarray` and `xj.to_xarray` functions to go between `xj` and `xr` types.
- [x] Support for `xr` types with dummy data (useful for tree manipulation).
- [ ] Support for transformations that change the dimensionality of the data.
- [x] Support for transformations that change the dimensionality of the data using the `var_change_on_unflatten` context manager.

## Sharp Edges

### Prefer `eqx.filter_jit` over `jax.jit`
There are some edge cases with metadata that `eqx.filter_jit` handles but `jax.jit` does not.

### Operations that Increase the Dimensionality of the Data
Operations that increase the dimensionality of the data (e.g. `jnp.expand_dims`) will cause problems downstream.

``` python
var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))

# This will not error.
var = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), var)

# The error from expanding the dimensionality will be triggered here.
var = var + 1
```

### Dispatching to jnp is not supported yet
Pending resolution of https://github.com/pydata/xarray/issues/7848.
``` python
Expand Down
8 changes: 8 additions & 0 deletions xarray_jax/tests/test_readme_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def fn(data):
# Convert back to a xr.DataArray.
da = xj.to_xarray(xj_da)

# Use xj.var_change_on_unflatten to allow us to expand the dimensions of the DataArray.
def add_dim_to_var(var):
var._dims = ("new_dim", *var._dims)
return var

with xj.var_change_on_unflatten(add_dim_to_var):
da = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), da)


def test_fail_example():
var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))
Expand Down

0 comments on commit 1d9d97b

Please sign in to comment.