Skip to content

Commit

Permalink
Improve hashing situation, add xr.Dataset as pytree node, implement i…
Browse files Browse the repository at this point in the history
…dentity transform tests and grads
  • Loading branch information
allen-adastra committed Sep 21, 2024
1 parent 7a6e080 commit c5cdb33
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 222 deletions.
29 changes: 8 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,41 +25,28 @@ da = some_function(da)

# Construct a xr.DataArray with dummy data (useful for tree manipulation).
da_mask = jax.tree.map(lambda _: bool, data)
```

`xr.Dataset` is currently not supported, but you can do conversions to/from a custom `xj.Dataset` types inside jit-compiled functions.
``` python
ds = xr.tutorial.load_dataset("air_temperature")

# Use jax.grad.
@eqx.filter_jit
def some_function(xjds: xj.XjDataset):
# Convert to xr.Dataset.
xrds = xj.to_xarray(xjds)
def fn(data):
return (data**2.0).sum().data

# Do some operation.
xrds = -1.0 * xrds

# Convert back to xj.Dataset.
return xj.from_xarray(xrds)

xjds = some_function(xj.from_xarray(ds))
ds_new = xj.to_xarray(xjds)
grad = jax.grad(fn)(da)
```



## Status
- [ ] PyTree node registrations
- [x] PyTree node registrations
- [x] `xr.Variable`
- [x] `xr.IndexVariable`
- [x] `xr.DataArray`
- [ ] `xr.Dataset`
- [x] `xr.Dataset`
- [x] Minimal shadow types implemented as [equinox modules](https://github.com/patrick-kidger/equinox) to handle edge cases (Note: these types are merely data structures that contain the data of these types. They don't have any of the methods of the xarray types).
- [x] `xj.Variable`
- [x] `xj.DataArray`
- [x] `xj.Dataset`
- [x] `xj.from_xarray` and `xj.to_xarray` functions to go between `xj` and `xr` types inside jit-compiled functions.
- [x] `xj.from_xarray` and `xj.to_xarray` functions to go between `xj` and `xr` types.
- [x] Support for `xr` types with dummy data (useful for tree manipulation).
- [ ] Support for transformations that change the dimensionality of the data.

## Sharp Edges

Expand Down
156 changes: 0 additions & 156 deletions tests/test_types.py

This file was deleted.

22 changes: 20 additions & 2 deletions xarray_jax/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@
import equinox as eqx
import jax
import xarray
import jax.numpy as jnp


def error_if_inside_jit():
is_jit = isinstance(jnp.array(0), jax.core.Tracer)
if is_jit:
raise ValueError(
"This function should not be called inside a jax.jit'ed function."
)


def maybe_hash_coords(coords):
if isinstance(coords, _HashableCoords):
return coords
return _HashableCoords(coords)


class _HashableCoords(collections.abc.Mapping):
Expand Down Expand Up @@ -77,6 +92,7 @@ def __init__(
self.attrs = attrs

def to_xarray(self) -> xarray.Variable:
error_if_inside_jit()
if self.data is None:
return None
return xarray.Variable(dims=self.dims, data=self.data, attrs=self.attrs)
Expand All @@ -102,6 +118,7 @@ def __init__(
self.name = name

def to_xarray(self) -> xarray.DataArray:
error_if_inside_jit()
var = self.variable.to_xarray()
if var is None:
return None
Expand All @@ -110,7 +127,7 @@ def to_xarray(self) -> xarray.DataArray:
@classmethod
def from_xarray(cls, da: xarray.DataArray) -> "XjDataArray":
return cls(
XjVariable.from_xarray(da.variable), _HashableCoords(da.coords), da.name
XjVariable.from_xarray(da.variable), maybe_hash_coords(da.coords), da.name
)


Expand All @@ -130,6 +147,7 @@ def __init__(
self.attrs = attrs

def to_xarray(self) -> xarray.Dataset:
error_if_inside_jit()
data_vars = {name: var.to_xarray() for name, var in self.variables.items()}

data_vars = {name: var for name, var in data_vars.items() if var is not None}
Expand All @@ -147,7 +165,7 @@ def from_xarray(cls, ds: xarray.Dataset) -> "XjDataset":
name: XjVariable.from_xarray(da.variable)
for name, da in ds.data_vars.items()
},
_HashableCoords(ds.coords),
maybe_hash_coords(ds.coords),
ds.attrs,
)

Expand Down
36 changes: 17 additions & 19 deletions xarray_jax/register_pytrees.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import xarray
import jax
from .custom_types import _HashableCoords
from .custom_types import maybe_hash_coords, _HashableCoords
from typing import Tuple, Hashable, Mapping

import numpy as np


def _flatten_variable(
v: xarray.Variable,
Expand All @@ -12,7 +14,6 @@ def _flatten_variable(
v._dims,
v._attrs,
)
assert isinstance(aux, Hashable)
return children, aux


Expand All @@ -39,8 +40,7 @@ def _flatten_data_array(
da: xarray.DataArray,
):
children = (da._variable,)
aux = (da._name, da._coords, da._indexes)
assert isinstance(aux, Hashable)
aux = (da._name, maybe_hash_coords(da._coords), da._indexes)
return children, aux


Expand All @@ -53,19 +53,22 @@ def _unflatten_data_array(
da = object.__new__(xarray.DataArray)
da._variable = variable
da._name = name
da._coords = coords
da._coords = dict(coords)
da._indexes = indexes
return da


def _flatten_dataset(
ds: xarray.Dataset,
):
data_vars = {name: data_array.variable for name, data_array in ds.data_vars.items()}
coord_names = ds._coord_names
variables = ds._variables

coords = {name: variables[name] for name in coord_names}
data_vars = {name: variables[name] for name in variables if name not in coord_names}

data_var_leaves, data_var_treedef = jax.tree.flatten(data_vars)
children = data_var_leaves
aux = (ds.coords, data_var_treedef, ds._indexes, ds._dims, ds._attrs)
children = (data_vars,)
aux = (maybe_hash_coords(coords), ds._indexes, ds._dims, ds._attrs)
assert isinstance(aux, Hashable)
return children, aux

Expand All @@ -77,15 +80,13 @@ def _unflatten_dataset(
],
) -> xarray.Dataset:
"""Unflattens a Dataset for jax.tree_util."""
data_var_leaves = children
coords, data_var_treedef, indexes, dims, attrs = aux

data_vars = jax.tree.unflatten(data_var_treedef, data_var_leaves)
data_vars = children[0]
coords, indexes, dims, attrs = aux

ds = object.__new__(xarray.Dataset)
ds._dims = dims
ds._variables = data_vars | dict(coords)
ds._coord_names = list(coords.keys())
ds._coord_names = set(coords.keys())
ds._attrs = attrs
ds._indexes = indexes
ds._encoding = None
Expand All @@ -96,11 +97,8 @@ def _unflatten_dataset(
jax.tree_util.register_pytree_node(
xarray.Variable, _flatten_variable, _unflatten_variable
)
jax.tree_util.register_pytree_node(
xarray.IndexVariable, _flatten_variable, _unflatten_variable
)
jax.tree_util.register_static(xarray.IndexVariable)
jax.tree_util.register_pytree_node(
xarray.DataArray, _flatten_data_array, _unflatten_data_array
)

# TODO(allenw): fix xarray.Dataset pytree registration
jax.tree_util.register_pytree_node(xarray.Dataset, _flatten_dataset, _unflatten_dataset)
File renamed without changes.
Loading

0 comments on commit c5cdb33

Please sign in to comment.