diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md
index 052c809df94c..cd21a8f64c07 100644
--- a/docs/Custom_Operation_for_GPUs.md
+++ b/docs/Custom_Operation_for_GPUs.md
@@ -479,8 +479,8 @@ We are using `jax.experimental.pjit.pjit` for parallel execution on multiple dev
 Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding `x` on all devices.
 
 ```python
-from jax.experimental.maps import Mesh
-from jax.experimental.pjit import PartitionSpec, pjit
+from jax.sharding import Mesh, PartitionSpec
+from jax.experimental.pjit import pjit
 
 mesh = Mesh(jax.local_devices(), ("x",))
 
@@ -777,11 +777,12 @@ import jax.numpy as jnp
 from build import gpu_ops
 from jax import core, dtypes
 from jax.abstract_arrays import ShapedArray
-from jax.experimental.maps import Mesh, xmap
-from jax.experimental.pjit import PartitionSpec, pjit
+from jax.experimental.maps import xmap
+from jax.experimental.pjit import pjit
 from jax.interpreters import mlir, xla
 from jax.interpreters.mlir import ir
 from jax.lib import xla_client
+from jax.sharding import Mesh, PartitionSpec
 from jaxlib.mhlo_helpers import custom_call
 
 
diff --git a/docs/notebooks/xmap_tutorial.ipynb b/docs/notebooks/xmap_tutorial.ipynb
index 8ed53f5413c4..a156ecaef390 100644
--- a/docs/notebooks/xmap_tutorial.ipynb
+++ b/docs/notebooks/xmap_tutorial.ipynb
@@ -169,7 +169,7 @@
    "source": [
     "import jax\n",
     "import numpy as np\n",
-    "from jax.experimental.maps import Mesh\n",
+    "from jax.sharding import Mesh\n",
     "\n",
     "loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],\n",
     "            axis_resources={'batch': 'x'})\n",
@@ -790,7 +790,7 @@
    },
    "outputs": [],
    "source": [
-    "from jax.experimental.maps import Mesh\n",
+    "from jax.sharding import Mesh\n",
     "\n",
     "local = local_matmul(x, x)  # The local function doesn't require the mesh definition\n",
     "with Mesh(*mesh_def):  # Makes the mesh axis names available as resources\n",
diff --git a/docs/notebooks/xmap_tutorial.md b/docs/notebooks/xmap_tutorial.md
index 5135c209c795..4d243a91b28a 100644
--- a/docs/notebooks/xmap_tutorial.md
+++ b/docs/notebooks/xmap_tutorial.md
@@ -120,7 +120,7 @@ But on a whim we can decide to parallelize over the batch axis:
 
 import jax
 import numpy as np
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
 
 loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
             axis_resources={'batch': 'x'})
@@ -536,7 +536,7 @@ To introduce the resources in a scope, use the `with Mesh` context manager:
 ```{code-cell} ipython3
 :id: kYdoeaSS9m9f
 
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
 
 local = local_matmul(x, x)  # The local function doesn't require the mesh definition
 with Mesh(*mesh_def):  # Makes the mesh axis names available as resources
diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py
index dddaecf5304c..012b322d786b 100644
--- a/jax/experimental/custom_partitioning.py
+++ b/jax/experimental/custom_partitioning.py
@@ -213,7 +213,7 @@ def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
     from jax.experimental.custom_partitioning import custom_partitioning
     from jax.experimental.pjit import pjit
     from jax.sharding import PartitionSpec as P
-    from jax.experimental.maps import Mesh
+    from jax.sharding import Mesh
     from jax.numpy.fft import fft
     import regex as re
     import numpy as np
diff --git a/jax/experimental/gda_serialization/serialization.py b/jax/experimental/gda_serialization/serialization.py
index 89dfae2030b3..b84c32f86b04 100644
--- a/jax/experimental/gda_serialization/serialization.py
+++ b/jax/experimental/gda_serialization/serialization.py
@@ -29,7 +29,7 @@
 from jax._src import array
 from jax._src import sharding
 from jax._src import typing
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
 import jax.numpy as jnp
 import numpy as np
 import tensorstore as ts
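
For reviewers trying the change locally, here is a minimal, hypothetical sketch of the post-patch import locations in use. The `double` function, the 8-element input, and the single 1D mesh are illustrative assumptions, not part of the patch; the array length is assumed to divide the local device count, and the `in_axis_resources`/`out_axis_resources` keywords reflect the pjit API of this era (later JAX releases rename them to `in_shardings`/`out_shardings`).

```python
# Sketch only: `double` and the input shape are illustrative, not from the patch.
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec  # post-patch import location


def double(x):
    return 2 * x


# A 1D mesh over all local devices, mirroring the doc example touched above.
mesh = Mesh(jax.local_devices(), ("x",))

with mesh:
    # Shard the leading axis of the input and output over the "x" mesh axis.
    sharded_double = pjit(
        double,
        in_axis_resources=PartitionSpec("x"),
        out_axis_resources=PartitionSpec("x"),
    )
    # Assumes 8 is divisible by the number of local devices.
    y = sharded_double(jnp.arange(8.0))

print(y)
```

On a single-device machine the same snippet still runs, since pjit over a one-device mesh behaves like an ordinary jit.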