From d37d83493ff06084c614f67631e0b59372a92493 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 15 May 2024 15:28:02 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 634095385 --- CHANGELOG.md | 6 +++--- docs/flip/2396-rnn.md | 2 +- docs/flip/2434-general-metadata.md | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e0a570266..2bbcd73620 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,8 +44,8 @@ vNext - Add SimpleCell. by @carlosgmartin in https://github.com/google/flax/pull/3697 - fix Module.module_paths docstring by @cgarciae in https://github.com/google/flax/pull/3709 - Guarantee the latest JAX version on CI by @cgarciae in https://github.com/google/flax/pull/3705 -- Replace deprecated API `jax.tree_map` by @copybara-service in https://github.com/google/flax/pull/3715 -- Use `jax.tree_util.tree_map` instead of deprecated `jax.tree_map`. by @copybara-service in https://github.com/google/flax/pull/3714 +- Replace deprecated API `jax.tree.map` by @copybara-service in https://github.com/google/flax/pull/3715 +- Use `jax.tree_util.tree_map` instead of deprecated `jax.tree.map`. by @copybara-service in https://github.com/google/flax/pull/3714 - [nnx] simplify readme by @cgarciae in https://github.com/google/flax/pull/3707 - [nnx] add demo.ipynb by @cgarciae in https://github.com/google/flax/pull/3680 - Fix Tabulate's compute_flops by @cgarciae in https://github.com/google/flax/pull/3721 @@ -340,7 +340,7 @@ Breaking changes: New features: - Add lifted conditional `nn.cond`. - Improved error messages: parameters not found, loading checkpoints. -- Replace `jax.tree_multimap` (deprecated) with `jax.tree_map`. +- Replace `jax.tree_multimap` (deprecated) with `jax.tree.map`. - Add the "Module Lifecycle" design note. - Add support for JAX dynamic stack-based named_call diff --git a/docs/flip/2396-rnn.md b/docs/flip/2396-rnn.md index 94a7e78a16..06fa5ab18a 100644 --- a/docs/flip/2396-rnn.md +++ b/docs/flip/2396-rnn.md @@ -162,7 +162,7 @@ def __call__(self, inputs, seq_lengths): keep_order=True, # but return the sequence in the original order ) # Merge both sequences. - outputs = jax.tree_map(self.merge_fn, outputs_forward, outputs_backward) + outputs = jax.tree.map(self.merge_fn, outputs_forward, outputs_backward) return (carry_forward, carry_backward), outputs ``` diff --git a/docs/flip/2434-general-metadata.md b/docs/flip/2434-general-metadata.md index 7758fda0ee..8f21e378b9 100644 --- a/docs/flip/2434-general-metadata.md +++ b/docs/flip/2434-general-metadata.md @@ -123,12 +123,12 @@ This should make the API future proof and modular. The ``add_axis`` and ``remove_axis`` method return an instance of their own type instead of mutating in-place. Typically, an implementation would be a ``flax.struct.PyTreeNode`` because the box should still be a valid JAX value and must therefore be handled by the PyTree API. -Calling ``jax.tree_map`` on a boxed value will simply map over the value in the box. -The lifted transforms that need to handle metadata will call ``jax.tree_map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))`` to find the AxisMetadata instances within a PyTree. +Calling ``jax.tree.map`` on a boxed value will simply map over the value in the box. +The lifted transforms that need to handle metadata will call ``jax.tree.map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))`` to find the AxisMetadata instances within a PyTree. Advantages of the boxing approach: 1. Boxing can be used outside of Flax and metadata is automatically "inherited". For example, the optimizer state will - have the same partitioning spec as the parameters, because the state is initialized using a ``jax.tree_map`` over the boxed parameters. + have the same partitioning spec as the parameters, because the state is initialized using a ``jax.tree.map`` over the boxed parameters. 2. Boxes are composable. 3. Boxing avoids string manipulation and generally avoids having to handle additional auxiliary collections like "param_axes" in the current partitioning API. @@ -184,7 +184,7 @@ Initializing a model that creates partitioned weights would result in the follow ```python variables = partitioned_dense.init(rng, jnp.ones((4,))) -jax.tree_map(np.shape, variables) # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), bias: (8,)}} +jax.tree.map(np.shape, variables) # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), bias: (8,)}} ``` The variable tree with metadata can be used to integrate with other libraries and APIs. @@ -199,7 +199,7 @@ def to_sharding_spec(x): return PartitionSpec() # Result: {"params": {"kernel": PartitionSpec(None, "data"), bias: PartitionSpec()}} -variables_pspec = jax.tree_map(to_sharding_spec, variables, is_leaf=lambda x: isinstance(x, Partitioned)) +variables_pspec = jax.tree.map(to_sharding_spec, variables, is_leaf=lambda x: isinstance(x, Partitioned)) ``` ### Unbox syntax