Skip to content

Commit

Permalink
Remove references to deprecated submodule jax.abstract_arrays
Browse files Browse the repository at this point in the history
Use jax.core instead (see jax-ml/jax#16271)

PiperOrigin-RevId: 538822402
  • Loading branch information
Jake VanderPlas authored and The oryx Authors committed Jun 8, 2023
1 parent 92b0585 commit 3b89034
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 17 deletions.
7 changes: 3 additions & 4 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def f(x):
import functools
from typing import Any, Callable, Dict, FrozenSet, Hashable, Iterable, List, Optional, Tuple, Union

from jax import abstract_arrays
from jax import api_util
from jax import lax
from jax import linear_util as lu
Expand Down Expand Up @@ -419,7 +418,7 @@ def __init__(self, trace: 'HarvestTrace', val: Value):

@property
def aval(self):
return abstract_arrays.raise_to_shaped(jax_core.get_aval(self.val))
return jax_core.raise_to_shaped(jax_core.get_aval(self.val))

def full_lower(self):
return self
Expand Down Expand Up @@ -512,7 +511,7 @@ def handle_sow(self, *values, name, tag, tree, mode):
raise ValueError(f'Variable has already been reaped: {name}')
avals = tree_util.tree_unflatten(
tree,
[abstract_arrays.raise_to_shaped(jax_core.get_aval(v)) for v in values])
[jax_core.raise_to_shaped(jax_core.get_aval(v)) for v in values])
self.reaps[name] = Reap(
tree_util.tree_unflatten(tree, values), dict(mode=mode, aval=avals))
return values
Expand Down Expand Up @@ -781,7 +780,7 @@ def _get_harvest_metadata(closed_jaxpr, settings, *args):
flat_args, in_tree = tree_util.tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
in_avals = jax_util.safe_map(
lambda a: abstract_arrays.raise_to_shaped(jax_core.get_aval(a)),
lambda a: jax_core.raise_to_shaped(jax_core.get_aval(a)),
flat_args)
pe.trace_to_jaxpr_final(flat_fun, in_avals)
metadata = aux()
Expand Down
7 changes: 3 additions & 4 deletions oryx/core/interpreters/inverse/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Iterable

import jax
from jax import abstract_arrays
from jax import tree_util
from jax import util as jax_util
from jax._src import core as jax_core
Expand Down Expand Up @@ -142,7 +141,7 @@ def unknown(cls, aval):
def new(cls, val):
val = np.array(val)
aval = jax_core.get_aval(val)
aval = abstract_arrays.raise_to_shaped(aval)
aval = jax_core.raise_to_shaped(aval)
ndslice = NDSlice.new(val, np.zeros_like(val))
return InverseAndILDJ(aval, frozenset([ndslice]))

Expand Down Expand Up @@ -319,8 +318,8 @@ def map_ildj(prim, incells, outcells, **params):
f, incells = incells[0], incells[1:]

def slice_aval(aval):
return abstract_arrays.ShapedArray(aval.shape[1:], aval.dtype,
aval.weak_type)
return jax_core.ShapedArray(aval.shape[1:], aval.dtype,
aval.weak_type)

def add_slice(cell, old_cell):
new_slices = [
Expand Down
3 changes: 1 addition & 2 deletions oryx/core/ppl/effect_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tests for oryx.core.ppl.effect_handler."""
from absl.testing import absltest
import jax
from jax import abstract_arrays
from jax import random
import jax.numpy as np

Expand All @@ -40,7 +39,7 @@ def _random_normal_impl(key, loc, scale):
@random_normal_p.def_abstract_eval
def _random_normal_abstract(key, loc, scale):
del key, loc, scale
return [abstract_arrays.ShapedArray((), np.float32)]
return [jax.core.ShapedArray((), np.float32)]


class EffectHandlerTest(test_util.TestCase):
Expand Down
3 changes: 1 addition & 2 deletions oryx/core/ppl/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from absl.testing import absltest

import jax
from jax import abstract_arrays
from jax import random
from jax._src import core as jax_core
from jax.interpreters import batching
Expand Down Expand Up @@ -61,7 +60,7 @@ def random_normal_impl(rng, *, batch_ndims):

def random_normal_abstract(key, **_):
del key
return abstract_arrays.ShapedArray((), jnp.float32)
return jax_core.ShapedArray((), jnp.float32)


def random_normal_log_prob_rule(incells, outcells, *, batch_ndims, **_):
Expand Down
3 changes: 1 addition & 2 deletions oryx/core/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import itertools as it
from typing import Callable

from jax import abstract_arrays
from jax import api_util
from jax import linear_util as lu
from jax import tree_util
Expand Down Expand Up @@ -237,7 +236,7 @@ def subcall(self, name):
tie_all_p.multiple_results = True
tie_all_p.def_impl(lambda *args: args)
tie_all_p.def_abstract_eval(lambda *args: safe_map( # pylint: disable=g-long-lambda
abstract_arrays.raise_to_shaped, args))
jax_core.raise_to_shaped, args))

mlir.register_lowering(tie_all_p, lambda c, *args: args)

Expand Down
5 changes: 2 additions & 3 deletions oryx/core/trace_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import threading
from typing import Any, Dict, Generator, List

from jax import abstract_arrays
from jax import api_util
from jax import linear_util as lu
from jax import tree_util
Expand All @@ -41,9 +40,9 @@
def get_shaped_aval(x):
"""Converts a JAX value type into a shaped abstract value."""
if hasattr(x, 'dtype') and hasattr(x, 'shape'):
return abstract_arrays.ShapedArray(
return jax_core.ShapedArray(
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True))
return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))
return jax_core.raise_to_shaped(jax_core.get_aval(x))


def pv_like(x, abstract=True):
Expand Down

0 comments on commit 3b89034

Please sign in to comment.