From 21d44ea998f3cb02fd343754c82d6809d5dcec1b Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Wed, 25 Sep 2024 22:41:43 -0400 Subject: [PATCH] Add var_change_on_unflatten to public API --- pyproject.toml | 2 +- xarray_jax/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 594ac00..9e6e64d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "xarray-jax" -version = "0.0.2" +version = "0.0.3" description = "" authors = ["Allen Wang "] readme = "README.md" diff --git a/xarray_jax/__init__.py b/xarray_jax/__init__.py index d753f06..d70131f 100644 --- a/xarray_jax/__init__.py +++ b/xarray_jax/__init__.py @@ -1,6 +1,7 @@ import xarray from xarray_jax.custom_types import XjDataArray, XjDataset, XjVariable, from_xarray, to_xarray +from xarray_jax.register_pytrees import var_change_on_unflatten xarray.set_options(keep_attrs=True) # Necessary for preserving PyTree structure. -__all__ = ["XjDataArray", "XjDataset", "XjVariable", "from_xarray", "to_xarray"] \ No newline at end of file +__all__ = ["XjDataArray", "XjDataset", "XjVariable", "from_xarray", "to_xarray", "var_change_on_unflatten"] \ No newline at end of file