diff --git a/README.md b/README.md index 1c35206..296e5f0 100644 --- a/README.md +++ b/README.md @@ -25,41 +25,28 @@ da = some_function(da) # Construct a xr.DataArray with dummy data (useful for tree manipulation). da_mask = jax.tree.map(lambda _: bool, data) -``` - -`xr.Dataset` is currently not supported, but you can do conversions to/from a custom `xj.Dataset` types inside jit-compiled functions. -``` python -ds = xr.tutorial.load_dataset("air_temperature") +# Use jax.grad. @eqx.filter_jit -def some_function(xjds: xj.XjDataset): - # Convert to xr.Dataset. - xrds = xj.to_xarray(xjds) +def fn(data): + return (data**2.0).sum().data - # Do some operation. - xrds = -1.0 * xrds - - # Convert back to xj.Dataset. - return xj.from_xarray(xrds) - -xjds = some_function(xj.from_xarray(ds)) -ds_new = xj.to_xarray(xjds) +grad = jax.grad(fn)(da) ``` - ## Status -- [ ] PyTree node registrations +- [x] PyTree node registrations - [x] `xr.Variable` - - [x] `xr.IndexVariable` - [x] `xr.DataArray` - - [ ] `xr.Dataset` + - [x] `xr.Dataset` - [x] Minimal shadow types implemented as [equinox modules](https://github.com/patrick-kidger/equinox) to handle edge cases (Note: these types are merely data structures that contain the data of these types. They don't have any of the methods of the xarray types). - [x] `xj.Variable` - [x] `xj.DataArray` - [x] `xj.Dataset` -- [x] `xj.from_xarray` and `xj.to_xarray` functions to go between `xj` and `xr` types inside jit-compiled functions. +- [x] `xj.from_xarray` and `xj.to_xarray` functions to go between `xj` and `xr` types. - [x] Support for `xr` types with dummy data (useful for tree manipulation). +- [ ] Support for transformations that change the dimensionality of the data. ## Sharp Edges diff --git a/tests/test_types.py b/tests/test_types.py deleted file mode 100644 index 5e3c9d1..0000000 --- a/tests/test_types.py +++ /dev/null @@ -1,156 +0,0 @@ -import xarray as xr -from hypothesis import given, settings -import hypothesis.strategies as st -import jax.numpy as jnp -import jax -from hypothesis.extra.array_api import make_strategies_namespace -import xarray.testing.strategies as xrst -from hypothesis.strategies import sampled_from - -import xarray_jax -import pytest -import equinox as eqx -from jaxtyping import PyTree -import numpy as np - -jax.config.update("jax_enable_x64", True) - -jnps = make_strategies_namespace(jnp) - -xp_variables = xrst.variables( - array_strategy_fn=jnps.arrays, - dtype=jnps.scalar_dtypes(), -) - -index_variables = xrst.variables( - array_strategy_fn=jnps.arrays, - dtype=jnps.scalar_dtypes(), - dims=xrst.dimension_names(min_dims=1, max_dims=1), -) - -# TODO(allenw): for now, just construct simple DataArrays from Variables. -# Pending xarray developers adding strategies for DataArrays. -# https://github.com/pydata/xarray/pull/6908 -data_arrays = xp_variables.map( - lambda var: xr.DataArray( - var, - coords={ - "dummy_coord": (var.dims, np.ones(var.data.shape)), - "dummy_coord2": (var.dims, np.asarray(var.data)), - }, - ) -) - -# Creating a strategy for ufuncs to test. -ufuncs = [jnp.abs, jnp.sin, jnp.cos, jnp.exp, jnp.log] -ufunc_strat = sampled_from(ufuncs) - - -@given(var=xp_variables, ufunc=ufunc_strat) -@settings(deadline=None) -def test_variables(var: xr.Variable, ufunc): - # Test that we can wrap a jax array in a variable. - assert isinstance(var.data, jax.Array) - - # - # Test that we can flatten and unflatten the variable and get the same result. - # - leaves, treedef = jax.tree.flatten(var) - var_unflattened = jax.tree.unflatten(treedef, leaves) - assert var.equals(var_unflattened) - - # - # Test that we can apply a jitted ufunc to the variable and get the correct result as if we applied it directly to the data. - # - @eqx.filter_jit - def fn(v): - return xr.apply_ufunc(ufunc, v) - - result_var = fn(var) - assert isinstance(result_var.data, jax.Array) - assert result_var.equals(xr.Variable(var.dims, ufunc(var.data), var.attrs)) - - # - # Test that we can create a boolean mask tree. - # - var_mask = jax.tree.map(lambda _: True, var) - assert var_mask._data is True - assert var_mask._dims == var._dims - assert var_mask._attrs == var._attrs - - -@given(var=index_variables, ufunc=ufunc_strat) -@settings(deadline=None) -def test_index_variables(var: xr.Variable, ufunc): - var = var.to_index_variable() - - # Index variables should not be wrapped in a jax array. - assert not isinstance(var.data, jax.Array) - - -@given(da=data_arrays, ufunc=ufunc_strat) -@settings(deadline=None) -def test_dataarrays(da: xr.DataArray, ufunc): - # - # Test that the data is a jax array. - # - assert isinstance(da.variable.data, jax.Array) - - # - # Test that we can flatten and unflatten the da and get the same result. - # - leaves, treedef = jax.tree.flatten(da) - da_unflattened = jax.tree.unflatten(treedef, leaves) - assert da.equals(da_unflattened) - - # - # Test that we can apply a jitted ufunc to the da and get the correct result as if we applied it directly to the data. - # - @eqx.filter_jit - def fn(da_): - return xr.apply_ufunc(ufunc, da_) - - result_da = fn(da) - - expected_da = xr.DataArray( - xr.Variable(da.variable.dims, ufunc(da.variable.data), da.variable.attrs), - coords=da.coords, - ) - assert result_da.equals(expected_da) - - -@given(xr_data=st.one_of(xp_variables, index_variables, data_arrays)) -@settings(deadline=None) -def test_roundtrip(xr_data): - # - # Test that we can convert an xarray object to a xarray_jax object and back. - # - xj_data = xarray_jax.from_xarray(xr_data) - assert isinstance(xj_data, PyTree) - xr_data_roundtrip = xarray_jax.to_xarray(xj_data) - assert xr_data.equals(xr_data_roundtrip) - - # - # Test that we can go back and forth inside a jit-compiled function. - # - @eqx.filter_jit - def fn(xj_data_): - xr_data_ = xarray_jax.to_xarray(xj_data_) - xr_data_ = -1.0 * xr_data_ # Some operation. - return xarray_jax.from_xarray(xr_data_) - - xj_data_roundtrip_neg = fn(xj_data) - xr_data_roundtrip_neg = xarray_jax.to_xarray(xj_data_roundtrip_neg) - assert xr_data_roundtrip_neg.equals(-1.0 * xr_data) - - -@pytest.mark.skip(reason="Dataset is not yet supported.") -def test_ds(): - ds = xr.tutorial.load_dataset("air_temperature") - - ds_mask = jax.tree.map(lambda x: True, ds) - - for k, v in ds_mask.data_vars.items(): - assert v._data is True # TODO(allenw): RecursionError! - assert v._dims == ds[k]._dims - assert v._attrs == ds[k]._attrs diff --git a/xarray_jax/custom_types.py b/xarray_jax/custom_types.py index 1a1a637..fe43ac6 100644 --- a/xarray_jax/custom_types.py +++ b/xarray_jax/custom_types.py @@ -8,6 +8,21 @@ import equinox as eqx import jax import xarray +import jax.numpy as jnp + + +def error_if_inside_jit(): + is_jit = isinstance(jnp.array(0), jax.core.Tracer) + if is_jit: + raise ValueError( + "This function should not be called inside a jax.jit'ed function." + ) + + +def maybe_hash_coords(coords): + if isinstance(coords, _HashableCoords): + return coords + return _HashableCoords(coords) class _HashableCoords(collections.abc.Mapping): @@ -77,6 +92,7 @@ def __init__( self.attrs = attrs def to_xarray(self) -> xarray.Variable: + error_if_inside_jit() if self.data is None: return None return xarray.Variable(dims=self.dims, data=self.data, attrs=self.attrs) @@ -102,6 +118,7 @@ def __init__( self.name = name def to_xarray(self) -> xarray.DataArray: + error_if_inside_jit() var = self.variable.to_xarray() if var is None: return None @@ -110,7 +127,7 @@ def to_xarray(self) -> xarray.DataArray: @classmethod def from_xarray(cls, da: xarray.DataArray) -> "XjDataArray": return cls( - XjVariable.from_xarray(da.variable), _HashableCoords(da.coords), da.name + XjVariable.from_xarray(da.variable), maybe_hash_coords(da.coords), da.name ) @@ -130,6 +147,7 @@ def __init__( self.attrs = attrs def to_xarray(self) -> xarray.Dataset: + error_if_inside_jit() data_vars = {name: var.to_xarray() for name, var in self.variables.items()} data_vars = {name: var for name, var in data_vars.items() if var is not None} @@ -147,7 +165,7 @@ def from_xarray(cls, ds: xarray.Dataset) -> "XjDataset": name: XjVariable.from_xarray(da.variable) for name, da in ds.data_vars.items() }, - _HashableCoords(ds.coords), + maybe_hash_coords(ds.coords), ds.attrs, ) diff --git a/xarray_jax/register_pytrees.py b/xarray_jax/register_pytrees.py index c67d9a8..a754f25 100644 --- a/xarray_jax/register_pytrees.py +++ b/xarray_jax/register_pytrees.py @@ -1,8 +1,10 @@ import xarray import jax -from .custom_types import _HashableCoords +from .custom_types import maybe_hash_coords, _HashableCoords from typing import Tuple, Hashable, Mapping +import numpy as np + def _flatten_variable( v: xarray.Variable, @@ -12,7 +14,6 @@ def _flatten_variable( v._dims, v._attrs, ) - assert isinstance(aux, Hashable) return children, aux @@ -39,8 +40,7 @@ def _flatten_data_array( da: xarray.DataArray, ): children = (da._variable,) - aux = (da._name, da._coords, da._indexes) - assert isinstance(aux, Hashable) + aux = (da._name, maybe_hash_coords(da._coords), da._indexes) return children, aux @@ -53,7 +53,7 @@ def _unflatten_data_array( da = object.__new__(xarray.DataArray) da._variable = variable da._name = name - da._coords = coords + da._coords = dict(coords) da._indexes = indexes return da @@ -61,11 +61,14 @@ def _unflatten_data_array( def _flatten_dataset( ds: xarray.Dataset, ): - data_vars = {name: data_array.variable for name, data_array in ds.data_vars.items()} + coord_names = ds._coord_names + variables = ds._variables + + coords = {name: variables[name] for name in coord_names} + data_vars = {name: variables[name] for name in variables if name not in coord_names} - data_var_leaves, data_var_treedef = jax.tree.flatten(data_vars) - children = data_var_leaves - aux = (ds.coords, data_var_treedef, ds._indexes, ds._dims, ds._attrs) + children = (data_vars,) + aux = (maybe_hash_coords(coords), ds._indexes, ds._dims, ds._attrs) assert isinstance(aux, Hashable) return children, aux @@ -77,15 +80,13 @@ def _unflatten_dataset( ], ) -> xarray.Dataset: """Unflattens a Dataset for jax.tree_util.""" - data_var_leaves = children - coords, data_var_treedef, indexes, dims, attrs = aux - - data_vars = jax.tree.unflatten(data_var_treedef, data_var_leaves) + data_vars = children[0] + coords, indexes, dims, attrs = aux ds = object.__new__(xarray.Dataset) ds._dims = dims ds._variables = data_vars | dict(coords) - ds._coord_names = list(coords.keys()) + ds._coord_names = set(coords.keys()) ds._attrs = attrs ds._indexes = indexes ds._encoding = None @@ -96,11 +97,8 @@ def _unflatten_dataset( jax.tree_util.register_pytree_node( xarray.Variable, _flatten_variable, _unflatten_variable ) -jax.tree_util.register_pytree_node( - xarray.IndexVariable, _flatten_variable, _unflatten_variable -) +jax.tree_util.register_static(xarray.IndexVariable) jax.tree_util.register_pytree_node( xarray.DataArray, _flatten_data_array, _unflatten_data_array ) - -# TODO(allenw): fix xarray.Dataset pytree registration \ No newline at end of file +jax.tree_util.register_pytree_node(xarray.Dataset, _flatten_dataset, _unflatten_dataset) diff --git a/tests/__init__.py b/xarray_jax/tests/__init__.py similarity index 100% rename from tests/__init__.py rename to xarray_jax/tests/__init__.py diff --git a/xarray_jax/tests/strategies.py b/xarray_jax/tests/strategies.py new file mode 100644 index 0000000..4ac7a5d --- /dev/null +++ b/xarray_jax/tests/strategies.py @@ -0,0 +1,112 @@ +import xarray as xr +import hypothesis.strategies as st +import jax.numpy as jnp +from hypothesis.extra.array_api import make_strategies_namespace +import xarray.testing.strategies as xrst +from hypothesis.strategies import sampled_from +import numpy as np +import xarray_jax +import equinox as eqx +import jax + +jnps = make_strategies_namespace(jnp) + +xp_variables = xrst.variables( + array_strategy_fn=jnps.arrays, + dtype=jnps.scalar_dtypes(), +) + +xp_variables_float = xrst.variables( + array_strategy_fn=jnps.arrays, dtype=jnps.floating_dtypes() +) + +# TODO(allenw): for now, just construct simple DataArrays from Variables. +# Pending xarray developers adding strategies for DataArrays. +# https://github.com/pydata/xarray/pull/6908 +data_arrays = xp_variables.map( + lambda var: xr.DataArray( + var, + coords={ + "dummy_coord": (var.dims, np.ones(var.data.shape)), + "dummy_coord2": (var.dims, np.asarray(var.data)), + }, + ) +) + +data_arrays_float = xp_variables_float.map( + lambda var: xr.DataArray( + var, + coords={ + "dummy_coord": (var.dims, np.ones(var.data.shape)), + "dummy_coord2": (var.dims, np.asarray(var.data)), + }, + ) +) + +generic_xr_strat = st.one_of( + xp_variables, + data_arrays, + st.just(xr.tutorial.load_dataset("air_temperature")), +) + +float_vars_and_das = st.one_of( + xp_variables_float, + data_arrays_float, +) + +# Creating a strategy for ufuncs to test. +ufuncs = [jnp.abs, jnp.sin, jnp.cos, jnp.exp, jnp.log] +ufunc_strat = sampled_from(ufuncs) + + +""" +Strategy for sampling from identity transformations. +""" + + +def xj_roundtrip(xr_data): + xj_data = xarray_jax.from_xarray(xr_data) + xj_data_roundtrip = xarray_jax.to_xarray(xj_data) + return xj_data_roundtrip + + +def flatten_unflatten(x): + leaves, treedef = jax.tree.flatten(x) + return jax.tree.unflatten(treedef, leaves) + + +def jit_identity(x): + @eqx.filter_jit + def fn(x_): + return x_ + + return fn(x) + + +def vmap_identity(x): + @eqx.filter_vmap + def fn(x_): + return x_ + + return fn(x) + + +def partition(x): + out, _ = eqx.partition(x, lambda _: True) + return out + + +def filt(x): + return eqx.filter(x, lambda _: True) + + +identity_transforms = sampled_from( + [ + xj_roundtrip, + flatten_unflatten, + jit_identity, + vmap_identity, + partition, + filt, + ] +) diff --git a/tests/test_readme_examples.py b/xarray_jax/tests/test_readme_examples.py similarity index 55% rename from tests/test_readme_examples.py rename to xarray_jax/tests/test_readme_examples.py index 0d62d4b..c42da38 100644 --- a/tests/test_readme_examples.py +++ b/xarray_jax/tests/test_readme_examples.py @@ -9,8 +9,8 @@ Tests for the examples in the README. """ -def test_dataarray_example(): +def test_dataarray_example(): da = xr.DataArray( xr.Variable(["x", "y"], jnp.ones((2, 3))), coords={"x": [1, 2], "y": [3, 4, 5]}, @@ -20,13 +20,21 @@ def test_dataarray_example(): @eqx.filter_jit def some_function(data): - neg_data = -1.0 * data # Multiply data by -1. - return neg_data * neg_data.coords["y"] # Multiply data by coords. + neg_data = -1.0 * data # Multiply data by -1. + return neg_data * neg_data.coords["y"] # Multiply data by coords. da_new = some_function(da) assert da_new.equals(-1.0 * da * da.coords["y"]) + @eqx.filter_jit + def fn(data): + return (data**2.0).sum().data + + grad = jax.grad(fn)(da) + assert grad.equals(2.0 * da) + + def test_fail_example(): var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3))) @@ -35,24 +43,4 @@ def test_fail_example(): with pytest.raises(ValueError): # This will fail. - var = var + 1 - -def test_dataset_example(): - ds = xr.tutorial.load_dataset("air_temperature") - - @eqx.filter_jit - def some_function(xjds: xj.XjDataset): - # Convert to xr.Dataset. - xrds = xj.to_xarray(xjds) - - # Do some operation. - xrds = -1.0 * xrds - - # Convert back to xj.Dataset. - return xj.from_xarray(xrds) - - xjds = some_function(xj.from_xarray(ds)) - ds_new = xj.to_xarray(xjds) - - assert isinstance(ds, xr.Dataset) - assert ds_new.equals(ds * -1.0) \ No newline at end of file + var = var + 1 diff --git a/xarray_jax/tests/test_types.py b/xarray_jax/tests/test_types.py new file mode 100644 index 0000000..610092a --- /dev/null +++ b/xarray_jax/tests/test_types.py @@ -0,0 +1,101 @@ +import xarray as xr +from hypothesis import given, settings +import jax + +from xarray_jax.tests.strategies import ( + generic_xr_strat, + ufunc_strat, + xp_variables, + identity_transforms, + xj_roundtrip, + float_vars_and_das, +) +import equinox as eqx + + +jax.config.update("jax_enable_x64", True) + + +@given(var=xp_variables, ufunc=ufunc_strat) +@settings(deadline=None) +def test_variables(var: xr.Variable, ufunc): + # + # Test that we can create a boolean mask tree. + # + var_mask = jax.tree.map(lambda _: True, var) + assert var_mask._data is True + assert var_mask._dims == var._dims + assert var_mask._attrs == var._attrs + + +@given( + xr_data=generic_xr_strat, + ufunc=ufunc_strat, +) +@settings(deadline=None) +def test_ufunc(xr_data, ufunc): + """ + Test that we can apply a jitted ufunc to the da and get the correct result as if we applied it directly to the data. + """ + + @eqx.filter_jit + def fn(data): + return xr.apply_ufunc(ufunc, data) + + result = fn(xr_data) + expected = xr.apply_ufunc(ufunc, xr_data) + assert result.equals(expected) + + +@given( + xr_data=generic_xr_strat, +) +@settings(deadline=None) +def test_boolean_mask(xr_data): + """ + Simple test that we can construct a boolean mask tree with all trues and apply eqx.filter. + """ + + @eqx.filter_jit + def fn(data): + mask = jax.tree.map(lambda _: True, data) + return eqx.filter(data, mask) + + result = fn(xr_data) + assert xr_data.equals(result) + + +@given( + xr_data=generic_xr_strat, + transform=identity_transforms, +) +@settings(deadline=None) +def test_identity_transforms(xr_data, transform): + """ + Test for identity JAX transformations. + """ + # Run the transform without JIT. + out = transform(xr_data) + assert xr_data.equals(out) + # Do another round trip to test that we can call the xr constructor on the output. + reconstructed = xj_roundtrip(out) + assert xr_data.equals(reconstructed) + + +@given( + xr_data=float_vars_and_das, +) +@settings(deadline=None) +def test_grads(xr_data): + # Test the gradient of sum(x**2), which is 2*x. + @eqx.filter_jit + def fn(data): + return (data**2.0).sum().data # Requires .data to get the value. + + grad = jax.grad(fn)(xr_data) + expected = 2 * xr_data + assert grad.equals(expected) + + val, grad = eqx.filter_value_and_grad(fn)(xr_data) + assert val == (xr_data**2.0).sum().data + assert grad.equals(expected)