
Commit

Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.

This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
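
For downstream code this is a drop-in import change. A minimal migration sketch (an illustration of the rename, not code from this commit; the axis name "x" is an arbitrary choice):

```python
import jax
import numpy as np

# Old, deprecated import locations removed throughout this commit:
#   from jax.experimental.maps import Mesh, NamedSharding
#   from jax.experimental.pjit import PartitionSpec
# New canonical locations:
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Quick smoke test on whatever devices are available.
mesh = Mesh(np.array(jax.local_devices()), ("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))
print(mesh.shape, sharding)
```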

PiperOrigin-RevId: 510027595
hawkinsp authored and jax authors committed Feb 16, 2023
1 parent 1b2a318 commit 0af9fff
Showing 5 changed files with 11 additions and 10 deletions.
9 changes: 5 additions & 4 deletions docs/Custom_Operation_for_GPUs.md
@@ -479,8 +479,8 @@ We are using `jax.experimental.pjit.pjit` for parallel execution on multiple dev
Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding `x` on all devices.
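The sharding step itself is collapsed in the hunk below; a hedged sketch of what sharding `x` over that 1D mesh can look like with the renamed `jax.sharding` APIs (the `device_put`/`NamedSharding` route and the array size are illustrative assumptions, the tutorial drives execution through `pjit`):
```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Same 1D mesh as in the tutorial: one named axis "x" over the local devices.
mesh = Mesh(jax.local_devices(), ("x",))

# Illustrative input whose leading dimension divides evenly across devices.
x = jnp.arange(8 * jax.local_device_count(), dtype=jnp.float32)

# Place x so that it is split along the "x" mesh axis on all devices.
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("x")))
print(x.sharding)
```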
```python
-from jax.experimental.maps import Mesh
-from jax.experimental.pjit import PartitionSpec, pjit
+from jax.sharding import Mesh, PartitionSpec
+from jax.experimental.pjit import pjit
mesh = Mesh(jax.local_devices(), ("x",))
@@ -777,11 +777,12 @@ import jax.numpy as jnp
from build import gpu_ops
from jax import core, dtypes
from jax.abstract_arrays import ShapedArray
-from jax.experimental.maps import Mesh, xmap
-from jax.experimental.pjit import PartitionSpec, pjit
+from jax.experimental.maps import xmap
+from jax.experimental.pjit import pjit
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir
from jax.lib import xla_client
+from jax.sharding import Mesh, PartitionSpec
from jaxlib.mhlo_helpers import custom_call
4 changes: 2 additions & 2 deletions docs/notebooks/xmap_tutorial.ipynb
@@ -169,7 +169,7 @@
"source": [
"import jax\n",
"import numpy as np\n",
"from jax.experimental.maps import Mesh\n",
"from jax.sharding import Mesh\n",
"\n",
"loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],\n",
" axis_resources={'batch': 'x'})\n",
@@ -790,7 +790,7 @@
},
"outputs": [],
"source": [
"from jax.experimental.maps import Mesh\n",
"from jax.sharding import Mesh\n",
"\n",
"local = local_matmul(x, x) # The local function doesn't require the mesh definition\n",
"with Mesh(*mesh_def): # Makes the mesh axis names available as resources\n",
4 changes: 2 additions & 2 deletions docs/notebooks/xmap_tutorial.md
@@ -120,7 +120,7 @@ But on a whim we can decide to parallelize over the batch axis:
import jax
import numpy as np
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
axis_resources={'batch': 'x'})
@@ -536,7 +536,7 @@ To introduce the resources in a scope, use the `with Mesh` context manager:
```{code-cell} ipython3
:id: kYdoeaSS9m9f
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
local = local_matmul(x, x) # The local function doesn't require the mesh definition
with Mesh(*mesh_def): # Makes the mesh axis names available as resources
2 changes: 1 addition & 1 deletion jax/experimental/custom_partitioning.py
@@ -213,7 +213,7 @@ def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
from jax.numpy.fft import fft
import regex as re
import numpy as np
2 changes: 1 addition & 1 deletion jax/experimental/gda_serialization/serialization.py
@@ -29,7 +29,7 @@
from jax._src import array
from jax._src import sharding
from jax._src import typing
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
import jax.numpy as jnp
import numpy as np
import tensorstore as ts
