Skip to content

Commit

Permalink
Add var_change_on_unflatten to public API
Browse files Browse the repository at this point in the history
  • Loading branch information
allen-adastra committed Sep 26, 2024
1 parent 6f16c55 commit 21d44ea
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xarray-jax"
version = "0.0.2"
version = "0.0.3"
description = ""
authors = ["Allen Wang <allenw@mit.edu>"]
readme = "README.md"
Expand Down
3 changes: 2 additions & 1 deletion xarray_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import xarray
from xarray_jax.custom_types import XjDataArray, XjDataset, XjVariable, from_xarray, to_xarray
from xarray_jax.register_pytrees import var_change_on_unflatten

xarray.set_options(keep_attrs=True) # Necessary for preserving PyTree structure.

__all__ = ["XjDataArray", "XjDataset", "XjVariable", "from_xarray", "to_xarray"]
__all__ = ["XjDataArray", "XjDataset", "XjVariable", "from_xarray", "to_xarray", "var_change_on_unflatten"]

0 comments on commit 21d44ea

Please sign in to comment.