From a8ea1be906e182bbaed31e78c7d982ea5299d88e Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Mon, 22 Jul 2024 17:32:38 -0400 Subject: [PATCH] Change deprecated jax.tree_util.tree_map to jax.tree.map. Fix argument passed to jax.numpy.finfo call. --- mctx/_src/policies.py | 2 +- mctx/_src/search.py | 6 +++--- mctx/_src/tests/policies_test.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mctx/_src/policies.py b/mctx/_src/policies.py index 06f9670..f5f6538 100644 --- a/mctx/_src/policies.py +++ b/mctx/_src/policies.py @@ -385,7 +385,7 @@ def _mask_invalid_actions(logits, invalid_actions): def _get_logits_from_probs(probs): - tiny = jnp.finfo(probs).tiny + tiny = jnp.finfo(probs.dtype).tiny return jnp.log(jnp.maximum(probs, tiny)) diff --git a/mctx/_src/search.py b/mctx/_src/search.py index 9ab3c76..18303e8 100644 --- a/mctx/_src/search.py +++ b/mctx/_src/search.py @@ -219,7 +219,7 @@ def expand( chex.assert_shape([parent_index, action, next_node_index], (batch_size,)) # Retrieve states for nodes to be evaluated. - embedding = jax.tree_util.tree_map( + embedding = jax.tree.map( lambda x: x[batch_range, parent_index], tree.embeddings) # Evaluate and create a new node. @@ -335,7 +335,7 @@ def update_tree_node( tree.node_values, value, node_index), node_visits=batch_update( tree.node_visits, new_visit, node_index), - embeddings=jax.tree_util.tree_map( + embeddings=jax.tree.map( lambda t, s: batch_update(t, s, node_index), tree.embeddings, embedding)) @@ -375,7 +375,7 @@ def _zeros(x): children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32), children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype), children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype), - embeddings=jax.tree_util.tree_map(_zeros, root.embedding), + embeddings=jax.tree.map(_zeros, root.embedding), root_invalid_actions=root_invalid_actions, extra_data=extra_data) diff --git a/mctx/_src/tests/policies_test.py b/mctx/_src/tests/policies_test.py index c08a093..df24eef 100644 --- a/mctx/_src/tests/policies_test.py +++ b/mctx/_src/tests/policies_test.py @@ -245,7 +245,7 @@ def test_gumbel_muzero_policy(self): # Testing max_depth. leaf, max_found_depth = _get_deepest_leaf( - jax.tree_util.tree_map(lambda x: x[0], policy_output.search_tree), + jax.tree.map(lambda x: x[0], policy_output.search_tree), policy_output.search_tree.ROOT_INDEX) self.assertEqual(max_depth, max_found_depth) self.assertEqual(6, policy_output.search_tree.node_visits[0, leaf])