From e99045381de5e159f52594a49ba95c5944d341af Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 1 Jun 2023 16:11:54 -0700 Subject: [PATCH] Update mentioning of `DeviceArray` and `ShardedDeviceArray` to `jax.Array` in the parallelism tutorial `jax.Array` is now a unified type for all kinds of arrays. PiperOrigin-RevId: 537155869 --- docs/jax-101/06-parallelism.ipynb | 112 +++++++++++++++--------------- docs/jax-101/06-parallelism.md | 8 +-- 2 files changed, 60 insertions(+), 60 deletions(-) diff --git a/docs/jax-101/06-parallelism.ipynb b/docs/jax-101/06-parallelism.ipynb index 9211efc2a7be..d48580a6fff1 100644 --- a/docs/jax-101/06-parallelism.ipynb +++ b/docs/jax-101/06-parallelism.ipynb @@ -94,7 +94,7 @@ { "data": { "text/plain": [ - "DeviceArray([11., 20., 29.], dtype=float32)" + "Array([11., 20., 29.], dtype=float32)" ] }, "execution_count": 5, @@ -217,14 +217,14 @@ { "data": { "text/plain": [ - "DeviceArray([[ 11., 20., 29.],\n", - " [ 56., 65., 74.],\n", - " [101., 110., 119.],\n", - " [146., 155., 164.],\n", - " [191., 200., 209.],\n", - " [236., 245., 254.],\n", - " [281., 290., 299.],\n", - " [326., 335., 344.]], dtype=float32)" + "Array([[ 11., 20., 29.],\n", + " [ 56., 65., 74.],\n", + " [101., 110., 119.],\n", + " [146., 155., 164.],\n", + " [191., 200., 209.],\n", + " [236., 245., 254.],\n", + " [281., 290., 299.],\n", + " [326., 335., 344.]], dtype=float32)" ] }, "execution_count": 8, @@ -258,14 +258,14 @@ { "data": { "text/plain": [ - "ShardedDeviceArray([[ 11., 20., 29.],\n", - " [ 56., 65., 74.],\n", - " [101., 110., 119.],\n", - " [146., 155., 164.],\n", - " [191., 200., 209.],\n", - " [236., 245., 254.],\n", - " [281., 290., 299.],\n", - " [326., 335., 344.]], dtype=float32)" + "Array([[ 11., 20., 29.],\n", + " [ 56., 65., 74.],\n", + " [101., 110., 119.],\n", + " [146., 155., 164.],\n", + " [191., 200., 209.],\n", + " [236., 245., 254.],\n", + " [281., 290., 299.],\n", + " [326., 335., 344.]], dtype=float32)" ] }, "execution_count": 9, @@ -285,7 +285,7 @@ "id": "E69cVxQPksxe" }, "source": [ - "Note that the parallelized `convolve` returns a `ShardedDeviceArray`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs." + "Note that the parallelized `convolve` returns a `jax.Array`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs." 
] }, { @@ -299,14 +299,14 @@ { "data": { "text/plain": [ - "ShardedDeviceArray([[ 78., 138., 198.],\n", - " [ 1188., 1383., 1578.],\n", - " [ 3648., 3978., 4308.],\n", - " [ 7458., 7923., 8388.],\n", - " [12618., 13218., 13818.],\n", - " [19128., 19863., 20598.],\n", - " [26988., 27858., 28728.],\n", - " [36198., 37203., 38208.]], dtype=float32)" + "Array([[ 78., 138., 198.],\n", + " [ 1188., 1383., 1578.],\n", + " [ 3648., 3978., 4308.],\n", + " [ 7458., 7923., 8388.],\n", + " [12618., 13218., 13818.],\n", + " [19128., 19863., 20598.],\n", + " [26988., 27858., 28728.],\n", + " [36198., 37203., 38208.]], dtype=float32)" ] }, "execution_count": 11, @@ -351,14 +351,14 @@ { "data": { "text/plain": [ - "ShardedDeviceArray([[ 11., 20., 29.],\n", - " [ 56., 65., 74.],\n", - " [101., 110., 119.],\n", - " [146., 155., 164.],\n", - " [191., 200., 209.],\n", - " [236., 245., 254.],\n", - " [281., 290., 299.],\n", - " [326., 335., 344.]], dtype=float32)" + "Array([[ 11., 20., 29.],\n", + " [ 56., 65., 74.],\n", + " [101., 110., 119.],\n", + " [146., 155., 164.],\n", + " [191., 200., 209.],\n", + " [236., 245., 254.],\n", + " [281., 290., 299.],\n", + " [326., 335., 344.]], dtype=float32)" ] }, "execution_count": 12, @@ -424,14 +424,14 @@ { "data": { "text/plain": [ - "ShardedDeviceArray([[0.00816024, 0.01408451, 0.019437 ],\n", - " [0.04154303, 0.04577465, 0.04959785],\n", - " [0.07492582, 0.07746479, 0.07975871],\n", - " [0.10830861, 0.10915492, 0.10991956],\n", - " [0.14169139, 0.14084506, 0.14008042],\n", - " [0.17507419, 0.17253521, 0.17024128],\n", - " [0.20845698, 0.20422535, 0.20040214],\n", - " [0.24183977, 0.23591548, 0.23056298]], dtype=float32)" + "Array([[0.00816024, 0.01408451, 0.019437 ],\n", + " [0.04154303, 0.04577465, 0.04959785],\n", + " [0.07492582, 0.07746479, 0.07975871],\n", + " [0.10830861, 0.10915492, 0.10991956],\n", + " [0.14169139, 0.14084506, 0.14008042],\n", + " [0.17507419, 0.17253521, 0.17024128],\n", + " [0.20845698, 0.20422535, 0.20040214],\n", + " [0.24183977, 0.23591548, 0.23056298]], dtype=float32)" ] }, "execution_count": 13, @@ -474,14 +474,14 @@ { "data": { "text/plain": [ - "DeviceArray([[0.00816024, 0.01408451, 0.019437 ],\n", - " [0.04154303, 0.04577465, 0.04959785],\n", - " [0.07492582, 0.07746479, 0.07975871],\n", - " [0.10830861, 0.10915492, 0.10991956],\n", - " [0.14169139, 0.14084506, 0.14008042],\n", - " [0.17507419, 0.17253521, 0.17024128],\n", - " [0.20845698, 0.20422535, 0.20040214],\n", - " [0.24183977, 0.23591548, 0.23056298]], dtype=float32)" + "Array([[0.00816024, 0.01408451, 0.019437 ],\n", + " [0.04154303, 0.04577465, 0.04959785],\n", + " [0.07492582, 0.07746479, 0.07975871],\n", + " [0.10830861, 0.10915492, 0.10991956],\n", + " [0.14169139, 0.14084506, 0.14008042],\n", + " [0.17507419, 0.17253521, 0.17024128],\n", + " [0.20845698, 0.20422535, 0.20040214],\n", + " [0.24183977, 0.23591548, 0.23056298]], dtype=float32)" ] }, "execution_count": 14, @@ -634,7 +634,7 @@ "id": "dmCMyLP9SV99" }, "source": [ - "So far, we've just constructed arrays with an additional leading dimension. The params are all still all on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently. You can tell because they are a DeviceArray, not a ShardedDeviceArray:" + "So far, we've just constructed arrays with an additional leading dimension. The params are all still all on the host (CPU). 
`pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently." ] }, { @@ -648,7 +648,7 @@ { "data": { "text/plain": [ - "jax.interpreters.xla._DeviceArray" + "jax.Array" ] }, "execution_count": 19, @@ -668,7 +668,7 @@ "id": "90VtjPbeY-hD" }, "source": [ - "The params will become a ShardedDeviceArray when they are returned by our pmapped `update()` (see further down)." + "The params will become a jax.Array when they are returned by our pmapped `update()` (see further down)." ] }, { @@ -734,8 +734,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "after first `update()`, `replicated_params.weight` is a \n", - "after first `update()`, `loss` is a \n", + "after first `update()`, `replicated_params.weight` is a \n", + "after first `update()`, `loss` is a \n", "after first `update()`, `x_split` is a \n", "Step 0, loss: 0.228\n", "Step 100, loss: 0.228\n", @@ -760,7 +760,7 @@ " # This is where the params and data gets communicated to devices:\n", " replicated_params, loss = update(replicated_params, x_split, y_split)\n", "\n", - " # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,\n", + " # The returned `replicated_params` and `loss` are now both jax.Arrays,\n", " # indicating that they're on the devices.\n", " # `x_split`, of course, remains a NumPy array on the host.\n", " if i == 0:\n", diff --git a/docs/jax-101/06-parallelism.md b/docs/jax-101/06-parallelism.md index 03688f08d551..a6fe3f1ac6b4 100644 --- a/docs/jax-101/06-parallelism.md +++ b/docs/jax-101/06-parallelism.md @@ -114,7 +114,7 @@ jax.pmap(convolve)(xs, ws) +++ {"id": "E69cVxQPksxe"} -Note that the parallelized `convolve` returns a `ShardedDeviceArray`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs. +Note that the parallelized `convolve` returns a `jax.Array`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs. ```{code-cell} ipython3 :id: P9dUyk-ciquy @@ -298,7 +298,7 @@ replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params) +++ {"id": "dmCMyLP9SV99"} -So far, we've just constructed arrays with an additional leading dimension. The params are all still all on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently. You can tell because they are a DeviceArray, not a ShardedDeviceArray: +So far, we've just constructed arrays with an additional leading dimension. The params are all still all on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently. ```{code-cell} ipython3 :id: YSCgHguTSdGW @@ -309,7 +309,7 @@ type(replicated_params.weight) +++ {"id": "90VtjPbeY-hD"} -The params will become a ShardedDeviceArray when they are returned by our pmapped `update()` (see further down). +The params will become a jax.Array when they are returned by our pmapped `update()` (see further down). 
+++ {"id": "eGVKxk1CV-m1"} @@ -347,7 +347,7 @@ for i in range(1000): # This is where the params and data gets communicated to devices: replicated_params, loss = update(replicated_params, x_split, y_split) - # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays, + # The returned `replicated_params` and `loss` are now both jax.Arrays, # indicating that they're on the devices. # `x_split`, of course, remains a NumPy array on the host. if i == 0: