Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634134993
  • Loading branch information
Jake VanderPlas authored and Dinosaur authors committed May 16, 2024
1 parent b862520 commit 23f9264
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion dinosaur/coordinate_systems_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_spectral_intrpolate_fn(
interpolate_fn = coordinate_systems.get_spectral_interpolate_fn(
coords, save_coords)
actual = interpolate_fn(input_state)
for x, y in zip(jax.tree_leaves(actual), jax.tree_leaves(expected_state)):
for x, y in zip(jax.tree.leaves(actual), jax.tree.leaves(expected_state)):
np.testing.assert_allclose(
save_coords.horizontal.to_nodal(x),
save_coords.horizontal.to_nodal(y), atol=atol)
Expand Down
8 changes: 4 additions & 4 deletions dinosaur/pytree_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ def test_split_and_concat(self, pytree, idx, axis):
"""Tests that split_along_axis, concat_along_axis return expected shapes."""
split_a, split_b = pytree_utils.split_along_axis(pytree, idx, axis, False)
with self.subTest('split_shape'):
self.assertEqual(jax.tree_leaves(split_a)[0].shape[axis], idx)
self.assertEqual(jax.tree.leaves(split_a)[0].shape[axis], idx)

reconstruction = pytree_utils.concat_along_axis([split_a, split_b], axis)
with self.subTest('split_concat_roundtrip'):
chex.assert_trees_all_close(reconstruction, pytree)

same_ndims = len(set(a.ndim for a in jax.tree_leaves(reconstruction))) == 1
same_ndims = len(set(a.ndim for a in jax.tree.leaves(reconstruction))) == 1
if not same_ndims:
with self.subTest('raises_when_wrong_ndims'):
with self.assertRaisesRegex(ValueError, 'arrays in `inputs` expected'):
Expand All @@ -120,8 +120,8 @@ def test_split_and_concat(self, pytree, idx, axis):
with self.subTest('multiple_concat_shape'):
arrays = [split_a, split_a, split_b, split_b]
double_concat = pytree_utils.concat_along_axis(arrays, axis)
actual_shape = jax.tree_leaves(double_concat)[0].shape[axis]
expected_shape = jax.tree_leaves(pytree)[0].shape[axis] * 2
actual_shape = jax.tree.leaves(double_concat)[0].shape[axis]
expected_shape = jax.tree.leaves(pytree)[0].shape[axis] * 2
self.assertEqual(actual_shape, expected_shape)

def test_pytree_cache(self):
Expand Down
4 changes: 2 additions & 2 deletions dinosaur/shallow_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ def from_si(

def state_to_nodal(state: State, grid: spherical_harmonic.Grid) -> State:
"""Converts a state to the spatial/nodal basis."""
return jax.tree_map(
return jax.tree.map(
lambda x: grid.to_nodal(grid.clip_wavenumbers(x)), state)


def state_to_modal(state: State, grid: spherical_harmonic.Grid) -> State:
"""Converts a state to the spectral/modal basis."""
return jax.tree_map(grid.to_modal, state)
return jax.tree.map(grid.to_modal, state)


def get_density_ratios(density: Array) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion dinosaur/shallow_water_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def barotropic_instability_tc(
def random_state_fn(rng_key: jnp.ndarray) -> shallow_water.State:
parameters = get_random_parameters(rng_key, default_parameters)
# make sure that all parameters are non-dimensionalized.
parameters = jax.tree_map(physics_specs.nondimensionalize, parameters)
parameters = jax.tree.map(physics_specs.nondimensionalize, parameters)
# The initial condition is computed by findind a steady state solution and
# then adding a small 'bump' to the potential.
zonal_velocity = jnp.stack([get_zonal_velocity(lat, parameters)
Expand Down
4 changes: 2 additions & 2 deletions dinosaur/shallow_water_states_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def testLongitudeSubtraction(self, a, b, expected):
def testZonalVelocity(self, seed):
parameters = shallow_water_states.get_random_parameters(
jax.random.PRNGKey(seed), shallow_water_states.get_default_parameters())
parameters = jax.tree_map(DEFAULT_SCALE.nondimensionalize, parameters)
parameters = jax.tree.map(DEFAULT_SCALE.nondimensionalize, parameters)

latitude = np.linspace(-np.pi / 2, np.pi / 2, 101)
zonal_velocity = shallow_water_states.get_zonal_velocity(
Expand Down Expand Up @@ -89,7 +89,7 @@ def testZonalVelocity(self, seed):
def testHeight(self, seed):
parameters = shallow_water_states.get_random_parameters(
jax.random.PRNGKey(seed), shallow_water_states.get_default_parameters())
parameters = jax.tree_map(DEFAULT_SCALE.nondimensionalize, parameters)
parameters = jax.tree.map(DEFAULT_SCALE.nondimensionalize, parameters)
longitude = np.linspace(0, 2 * np.pi, 101)
latitude = np.linspace(-np.pi / 2, np.pi / 2, 101)
longitude, latitude = np.meshgrid(longitude, latitude, indexing='ij')
Expand Down
14 changes: 7 additions & 7 deletions dinosaur/xarray_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def seed_stream(s):


def shape_structure(inputs):
return jax.tree_map(lambda x: x.shape, inputs)
return jax.tree.map(lambda x: x.shape, inputs)


class XarrayUtilsTest(parameterized.TestCase):
Expand Down Expand Up @@ -105,9 +105,9 @@ def test_primitive_eq_data_to_xarray(
'test_tracer': primitive_equations_states.gaussian_scalar(
coords, physics_specs, amplitude=0.1)
}
trajectory = jax.tree_map(
trajectory = jax.tree.map(
lambda *args: np.stack(args), *([state,] * time_steps))
batch_of_trajectories = jax.tree_map(
batch_of_trajectories = jax.tree.map(
lambda *args: np.stack(args), *([trajectory,] * samples))
expected_grid_attrs = coords.asdict()

Expand Down Expand Up @@ -176,7 +176,7 @@ def test_primitive_eq_data_to_xarray(
}
times = dt * np.arange(time_steps)
sample_ids = np.arange(samples)
nodal_batch_of_trajectories = jax.tree_map(
nodal_batch_of_trajectories = jax.tree.map(
grid.to_nodal, batch_of_trajectories)
ds = xarray_utils.data_to_xarray(
nodal_batch_of_trajectories.asdict(), sample_ids=sample_ids,
Expand Down Expand Up @@ -226,9 +226,9 @@ def test_shallow_water_eq_data_to_xarray(
state = shallow_water_states.multi_layer(
velocity, physics_specs.densities, coords)

trajectory = jax.tree_map(
trajectory = jax.tree.map(
lambda *args: np.stack(args), *([state,] * time_steps))
batch_of_trajectories = jax.tree_map(
batch_of_trajectories = jax.tree.map(
lambda *args: np.stack(args), *([trajectory,] * samples))
expected_grid_attrs = coords.asdict()

Expand Down Expand Up @@ -290,7 +290,7 @@ def test_shallow_water_eq_data_to_xarray(
}
times = dt * np.arange(time_steps)
sample_ids = np.arange(samples)
nodal_batch_of_trajectories = jax.tree_map(
nodal_batch_of_trajectories = jax.tree.map(
grid.to_nodal, batch_of_trajectories)
ds = xarray_utils.data_to_xarray(
nodal_batch_of_trajectories.asdict(), sample_ids=sample_ids,
Expand Down

0 comments on commit 23f9264

Please sign in to comment.