Skip to content

Commit

Permalink
Mark jax.abstract_arrays as deprecated
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 8, 2023
1 parent 0ec9f3c commit 47ae5bd
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 18 deletions.
20 changes: 10 additions & 10 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.12

* Deprecations
* The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
* `jax.abstract_arrays` and its contents are now deprecated. See related
functionality in :mod:`jax.core`.
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.

## jaxlib 0.4.12

Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/How_JAX_primitives_work.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@
}
],
"source": [
"from jax._src import abstract_arrays\n",
"from jax import core\n",
"@trace(\"multiply_add_abstract_eval\")\n",
"def multiply_add_abstract_eval(xs, ys, zs):\n",
" \"\"\"Abstract evaluation of the primitive.\n",
Expand All @@ -533,7 +533,7 @@
" \"\"\"\n",
" assert xs.shape == ys.shape\n",
" assert xs.shape == zs.shape\n",
" return abstract_arrays.ShapedArray(xs.shape, xs.dtype)\n",
" return core.ShapedArray(xs.shape, xs.dtype)\n",
"\n",
"# Now we register the abstract evaluation with JAX\n",
"multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)"
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/How_JAX_primitives_work.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ In the latter case, JAX uses the actual concrete value wrapped as an abstract va
:id: ctQmEeckIbdo
:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2
from jax._src import abstract_arrays
from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
"""Abstract evaluation of the primitive.
Expand All @@ -322,7 +322,7 @@ def multiply_add_abstract_eval(xs, ys, zs):
"""
assert xs.shape == ys.shape
assert xs.shape == zs.shape
return abstract_arrays.ShapedArray(xs.shape, xs.dtype)
return core.ShapedArray(xs.shape, xs.dtype)
# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
Expand Down
8 changes: 7 additions & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@

# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from jax import abstract_arrays as abstract_arrays
from jax import abstract_arrays as _deprecated_abstract_arrays
from jax import custom_derivatives as custom_derivatives
from jax import custom_batching as custom_batching
from jax import custom_transpose as custom_transpose
Expand Down Expand Up @@ -186,6 +186,11 @@
del _ccache

_deprecations = {
# Added 06 June 2023
"abstract_arrays": (
"jax.abstract_arrays is deprecated. Refer to jax.core.",
_deprecated_abstract_arrays
),
# Added 28 March 2023
"ShapedArray": (
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
Expand Down Expand Up @@ -219,6 +224,7 @@

import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src import abstract_arrays as abstract_arrays
from jax._src.core import ShapedArray as ShapedArray
from jax.interpreters import ad as ad
from jax.interpreters import partial_eval as partial_eval
Expand Down
33 changes: 30 additions & 3 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,35 @@

# TODO(phawkins): fix users of these aliases and delete this file.

from jax._src.abstract_arrays import array_types
from jax._src.abstract_arrays import array_types as _deprecated_array_types
from jax._src.core import (
ShapedArray,
raise_to_shaped,
ShapedArray as _deprecated_ShapedArray,
raise_to_shaped as _deprecated_raise_to_shaped,
)

_deprecations = {
# Added 06 June 2023
"array_types": (
"jax.abstract_arrays.array_types is deprecated.",
_deprecated_array_types,
),
"ShapedArray": (
"jax.abstract_arrays.ShapedArray is deprecated. Use jax.core.ShapedArray.",
_deprecated_ShapedArray,
),
"raise_to_shaped": (
"jax.abstract_arrays.raise_to_shaped is deprecated. Use jax.core.raise_to_shaped.",
_deprecated_raise_to_shaped,
),
}

import typing
if typing.TYPE_CHECKING:
from jax._src.abstract_arrays import array_types as array_types
from jax._src.core import ShapedArray as ShapedArray
from jax._src.core import raise_to_shaped as raise_to_shaped
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

0 comments on commit 47ae5bd

Please sign in to comment.