From f9de409cb4d410560da1e7ea97ab04f213684dbb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 8 Jun 2023 02:41:21 -0700 Subject: [PATCH] Remove references to deprecated submodule jax.abstract_arrays Use jax.core instead (see https://github.com/google/jax/pull/16271) PiperOrigin-RevId: 538729607 --- oryx/core/interpreters/harvest.py | 7 +++---- oryx/core/interpreters/inverse/core.py | 7 +++---- oryx/core/ppl/effect_handler_test.py | 3 +-- oryx/core/ppl/transformations_test.py | 3 +-- oryx/core/primitive.py | 3 +-- oryx/core/trace_util.py | 5 ++--- 6 files changed, 11 insertions(+), 17 deletions(-) diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index 186b88e..66c47b4 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -138,7 +138,6 @@ def f(x): import functools from typing import Any, Callable, Dict, FrozenSet, Hashable, Iterable, List, Optional, Tuple, Union -from jax import abstract_arrays from jax import api_util from jax import lax from jax import linear_util as lu @@ -419,7 +418,7 @@ def __init__(self, trace: 'HarvestTrace', val: Value): @property def aval(self): - return abstract_arrays.raise_to_shaped(jax_core.get_aval(self.val)) + return jax_core.raise_to_shaped(jax_core.get_aval(self.val)) def full_lower(self): return self @@ -512,7 +511,7 @@ def handle_sow(self, *values, name, tag, tree, mode): raise ValueError(f'Variable has already been reaped: {name}') avals = tree_util.tree_unflatten( tree, - [abstract_arrays.raise_to_shaped(jax_core.get_aval(v)) for v in values]) + [jax_core.raise_to_shaped(jax_core.get_aval(v)) for v in values]) self.reaps[name] = Reap( tree_util.tree_unflatten(tree, values), dict(mode=mode, aval=avals)) return values @@ -781,7 +780,7 @@ def _get_harvest_metadata(closed_jaxpr, settings, *args): flat_args, in_tree = tree_util.tree_flatten(args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) in_avals = jax_util.safe_map( - lambda a: abstract_arrays.raise_to_shaped(jax_core.get_aval(a)), + lambda a: jax_core.raise_to_shaped(jax_core.get_aval(a)), flat_args) pe.trace_to_jaxpr_final(flat_fun, in_avals) metadata = aux() diff --git a/oryx/core/interpreters/inverse/core.py b/oryx/core/interpreters/inverse/core.py index 447a3ca..0407c74 100644 --- a/oryx/core/interpreters/inverse/core.py +++ b/oryx/core/interpreters/inverse/core.py @@ -17,7 +17,6 @@ from typing import Iterable import jax -from jax import abstract_arrays from jax import tree_util from jax import util as jax_util from jax._src import core as jax_core @@ -142,7 +141,7 @@ def unknown(cls, aval): def new(cls, val): val = np.array(val) aval = jax_core.get_aval(val) - aval = abstract_arrays.raise_to_shaped(aval) + aval = jax_core.raise_to_shaped(aval) ndslice = NDSlice.new(val, np.zeros_like(val)) return InverseAndILDJ(aval, frozenset([ndslice])) @@ -319,8 +318,8 @@ def map_ildj(prim, incells, outcells, **params): f, incells = incells[0], incells[1:] def slice_aval(aval): - return abstract_arrays.ShapedArray(aval.shape[1:], aval.dtype, - aval.weak_type) + return jax_core.ShapedArray(aval.shape[1:], aval.dtype, + aval.weak_type) def add_slice(cell, old_cell): new_slices = [ diff --git a/oryx/core/ppl/effect_handler_test.py b/oryx/core/ppl/effect_handler_test.py index 56eeb8b..fb00502 100644 --- a/oryx/core/ppl/effect_handler_test.py +++ b/oryx/core/ppl/effect_handler_test.py @@ -15,7 +15,6 @@ """Tests for oryx.core.ppl.effect_handler.""" from absl.testing import absltest import jax -from jax import abstract_arrays from jax import random import jax.numpy as np @@ -40,7 +39,7 @@ def _random_normal_impl(key, loc, scale): @random_normal_p.def_abstract_eval def _random_normal_abstract(key, loc, scale): del key, loc, scale - return [abstract_arrays.ShapedArray((), np.float32)] + return [jax.core.ShapedArray((), np.float32)] class EffectHandlerTest(test_util.TestCase): diff --git a/oryx/core/ppl/transformations_test.py b/oryx/core/ppl/transformations_test.py index 829e299..7bfc8f9 100644 --- a/oryx/core/ppl/transformations_test.py +++ b/oryx/core/ppl/transformations_test.py @@ -16,7 +16,6 @@ from absl.testing import absltest import jax -from jax import abstract_arrays from jax import random from jax._src import core as jax_core from jax.interpreters import batching @@ -61,7 +60,7 @@ def random_normal_impl(rng, *, batch_ndims): def random_normal_abstract(key, **_): del key - return abstract_arrays.ShapedArray((), jnp.float32) + return jax_core.ShapedArray((), jnp.float32) def random_normal_log_prob_rule(incells, outcells, *, batch_ndims, **_): diff --git a/oryx/core/primitive.py b/oryx/core/primitive.py index 3efeded..7617be2 100644 --- a/oryx/core/primitive.py +++ b/oryx/core/primitive.py @@ -16,7 +16,6 @@ import itertools as it from typing import Callable -from jax import abstract_arrays from jax import api_util from jax import linear_util as lu from jax import tree_util @@ -237,7 +236,7 @@ def subcall(self, name): tie_all_p.multiple_results = True tie_all_p.def_impl(lambda *args: args) tie_all_p.def_abstract_eval(lambda *args: safe_map( # pylint: disable=g-long-lambda - abstract_arrays.raise_to_shaped, args)) + jax_core.raise_to_shaped, args)) mlir.register_lowering(tie_all_p, lambda c, *args: args) diff --git a/oryx/core/trace_util.py b/oryx/core/trace_util.py index 994a0c3..5a98ebe 100644 --- a/oryx/core/trace_util.py +++ b/oryx/core/trace_util.py @@ -17,7 +17,6 @@ import threading from typing import Any, Dict, Generator, List -from jax import abstract_arrays from jax import api_util from jax import linear_util as lu from jax import tree_util @@ -41,9 +40,9 @@ def get_shaped_aval(x): """Converts a JAX value type into a shaped abstract value.""" if hasattr(x, 'dtype') and hasattr(x, 'shape'): - return abstract_arrays.ShapedArray( + return jax_core.ShapedArray( x.shape, dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)) - return abstract_arrays.raise_to_shaped(jax_core.get_aval(x)) + return jax_core.raise_to_shaped(jax_core.get_aval(x)) def pv_like(x, abstract=True):