
Commit

Update mentions of DeviceArray and ShardedDeviceArray to `jax.Array` in the parallelism tutorial

`jax.Array` is now a unified type for all kinds of arrays.
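
For context, a minimal sketch (not part of the diff below, and assuming a JAX release new enough to ship the unified array type) of what the commit message describes: ordinary results and `pmap` outputs are both instances of `jax.Array`.

```python
import jax
import jax.numpy as jnp

# A single-device result is a jax.Array.
x = jnp.array([11., 20., 29.])
print(isinstance(x, jax.Array))  # True

# A pmapped result is also a jax.Array, just sharded across the local devices.
n = jax.local_device_count()
y = jax.pmap(lambda a: a * 2.0)(jnp.arange(n * 3.0).reshape(n, 3))
print(isinstance(y, jax.Array))  # True
```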

PiperOrigin-RevId: 537155869
jax authors committed Jun 1, 2023
1 parent 5c2070c commit e990453
Showing 2 changed files with 60 additions and 60 deletions.
112 changes: 56 additions & 56 deletions docs/jax-101/06-parallelism.ipynb
@@ -94,7 +94,7 @@
{
"data": {
"text/plain": [
"DeviceArray([11., 20., 29.], dtype=float32)"
"Array([11., 20., 29.], dtype=float32)"
]
},
"execution_count": 5,
@@ -217,14 +217,14 @@
{
"data": {
"text/plain": [
"DeviceArray([[ 11., 20., 29.],\n",
" [ 56., 65., 74.],\n",
" [101., 110., 119.],\n",
" [146., 155., 164.],\n",
" [191., 200., 209.],\n",
" [236., 245., 254.],\n",
" [281., 290., 299.],\n",
" [326., 335., 344.]], dtype=float32)"
"Array([[ 11., 20., 29.],\n",
" [ 56., 65., 74.],\n",
" [101., 110., 119.],\n",
" [146., 155., 164.],\n",
" [191., 200., 209.],\n",
" [236., 245., 254.],\n",
" [281., 290., 299.],\n",
" [326., 335., 344.]], dtype=float32)"
]
},
"execution_count": 8,
@@ -258,14 +258,14 @@
{
"data": {
"text/plain": [
"ShardedDeviceArray([[ 11., 20., 29.],\n",
" [ 56., 65., 74.],\n",
" [101., 110., 119.],\n",
" [146., 155., 164.],\n",
" [191., 200., 209.],\n",
" [236., 245., 254.],\n",
" [281., 290., 299.],\n",
" [326., 335., 344.]], dtype=float32)"
"Array([[ 11., 20., 29.],\n",
" [ 56., 65., 74.],\n",
" [101., 110., 119.],\n",
" [146., 155., 164.],\n",
" [191., 200., 209.],\n",
" [236., 245., 254.],\n",
" [281., 290., 299.],\n",
" [326., 335., 344.]], dtype=float32)"
]
},
"execution_count": 9,
@@ -285,7 +285,7 @@
"id": "E69cVxQPksxe"
},
"source": [
"Note that the parallelized `convolve` returns a `ShardedDeviceArray`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs."
"Note that the parallelized `convolve` returns a `jax.Array`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs."
]
},
{
@@ -299,14 +299,14 @@
{
"data": {
"text/plain": [
"ShardedDeviceArray([[ 78., 138., 198.],\n",
" [ 1188., 1383., 1578.],\n",
" [ 3648., 3978., 4308.],\n",
" [ 7458., 7923., 8388.],\n",
" [12618., 13218., 13818.],\n",
" [19128., 19863., 20598.],\n",
" [26988., 27858., 28728.],\n",
" [36198., 37203., 38208.]], dtype=float32)"
"Array([[ 78., 138., 198.],\n",
" [ 1188., 1383., 1578.],\n",
" [ 3648., 3978., 4308.],\n",
" [ 7458., 7923., 8388.],\n",
" [12618., 13218., 13818.],\n",
" [19128., 19863., 20598.],\n",
" [26988., 27858., 28728.],\n",
" [36198., 37203., 38208.]], dtype=float32)"
]
},
"execution_count": 11,
@@ -351,14 +351,14 @@
{
"data": {
"text/plain": [
"ShardedDeviceArray([[ 11., 20., 29.],\n",
" [ 56., 65., 74.],\n",
" [101., 110., 119.],\n",
" [146., 155., 164.],\n",
" [191., 200., 209.],\n",
" [236., 245., 254.],\n",
" [281., 290., 299.],\n",
" [326., 335., 344.]], dtype=float32)"
"Array([[ 11., 20., 29.],\n",
" [ 56., 65., 74.],\n",
" [101., 110., 119.],\n",
" [146., 155., 164.],\n",
" [191., 200., 209.],\n",
" [236., 245., 254.],\n",
" [281., 290., 299.],\n",
" [326., 335., 344.]], dtype=float32)"
]
},
"execution_count": 12,
@@ -424,14 +424,14 @@
{
"data": {
"text/plain": [
"ShardedDeviceArray([[0.00816024, 0.01408451, 0.019437 ],\n",
" [0.04154303, 0.04577465, 0.04959785],\n",
" [0.07492582, 0.07746479, 0.07975871],\n",
" [0.10830861, 0.10915492, 0.10991956],\n",
" [0.14169139, 0.14084506, 0.14008042],\n",
" [0.17507419, 0.17253521, 0.17024128],\n",
" [0.20845698, 0.20422535, 0.20040214],\n",
" [0.24183977, 0.23591548, 0.23056298]], dtype=float32)"
"Array([[0.00816024, 0.01408451, 0.019437 ],\n",
" [0.04154303, 0.04577465, 0.04959785],\n",
" [0.07492582, 0.07746479, 0.07975871],\n",
" [0.10830861, 0.10915492, 0.10991956],\n",
" [0.14169139, 0.14084506, 0.14008042],\n",
" [0.17507419, 0.17253521, 0.17024128],\n",
" [0.20845698, 0.20422535, 0.20040214],\n",
" [0.24183977, 0.23591548, 0.23056298]], dtype=float32)"
]
},
"execution_count": 13,
@@ -474,14 +474,14 @@
{
"data": {
"text/plain": [
"DeviceArray([[0.00816024, 0.01408451, 0.019437 ],\n",
" [0.04154303, 0.04577465, 0.04959785],\n",
" [0.07492582, 0.07746479, 0.07975871],\n",
" [0.10830861, 0.10915492, 0.10991956],\n",
" [0.14169139, 0.14084506, 0.14008042],\n",
" [0.17507419, 0.17253521, 0.17024128],\n",
" [0.20845698, 0.20422535, 0.20040214],\n",
" [0.24183977, 0.23591548, 0.23056298]], dtype=float32)"
"Array([[0.00816024, 0.01408451, 0.019437 ],\n",
" [0.04154303, 0.04577465, 0.04959785],\n",
" [0.07492582, 0.07746479, 0.07975871],\n",
" [0.10830861, 0.10915492, 0.10991956],\n",
" [0.14169139, 0.14084506, 0.14008042],\n",
" [0.17507419, 0.17253521, 0.17024128],\n",
" [0.20845698, 0.20422535, 0.20040214],\n",
" [0.24183977, 0.23591548, 0.23056298]], dtype=float32)"
]
},
"execution_count": 14,
@@ -634,7 +634,7 @@
"id": "dmCMyLP9SV99"
},
"source": [
"So far, we've just constructed arrays with an additional leading dimension. The params are all still all on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently. You can tell because they are a DeviceArray, not a ShardedDeviceArray:"
"So far, we've just constructed arrays with an additional leading dimension. The params are all still all on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently."
]
},
{
Expand All @@ -648,7 +648,7 @@
{
"data": {
"text/plain": [
"jax.interpreters.xla._DeviceArray"
"jax.Array"
]
},
"execution_count": 19,
@@ -668,7 +668,7 @@
"id": "90VtjPbeY-hD"
},
"source": [
"The params will become a ShardedDeviceArray when they are returned by our pmapped `update()` (see further down)."
"The params will become a jax.Array when they are returned by our pmapped `update()` (see further down)."
]
},
{
Expand Down Expand Up @@ -734,8 +734,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"after first `update()`, `replicated_params.weight` is a <class 'jax.interpreters.pxla.ShardedDeviceArray'>\n",
"after first `update()`, `loss` is a <class 'jax.interpreters.pxla.ShardedDeviceArray'>\n",
"after first `update()`, `replicated_params.weight` is a <class 'jax.Array'>\n",
"after first `update()`, `loss` is a <class 'jax.Array'>\n",
"after first `update()`, `x_split` is a <class 'numpy.ndarray'>\n",
"Step 0, loss: 0.228\n",
"Step 100, loss: 0.228\n",
Expand All @@ -760,7 +760,7 @@
" # This is where the params and data gets communicated to devices:\n",
" replicated_params, loss = update(replicated_params, x_split, y_split)\n",
"\n",
" # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,\n",
" # The returned `replicated_params` and `loss` are now both jax.Arrays,\n",
" # indicating that they're on the devices.\n",
" # `x_split`, of course, remains a NumPy array on the host.\n",
" if i == 0:\n",
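As an illustrative sketch of the host-to-device pattern the notebook hunks above describe (the `update` body here is a placeholder, not the tutorial's loss-and-gradient step): replicated params start as NumPy arrays on the host and become device-resident `jax.Array`s once the first pmapped call returns.

```python
import numpy as np
import jax

n_devices = jax.local_device_count()
params = np.zeros(3, dtype=np.float32)               # host-side NumPy params
replicated_params = np.stack([params] * n_devices)   # add a leading device axis

@jax.pmap
def update(p):
    return p + 1.0  # placeholder for a real gradient update

out = update(replicated_params)
print(type(replicated_params))      # <class 'numpy.ndarray'>: still on the host
print(isinstance(out, jax.Array))   # True: now sharded across the devices
```
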
8 changes: 4 additions & 4 deletions docs/jax-101/06-parallelism.md
@@ -114,7 +114,7 @@ jax.pmap(convolve)(xs, ws)

+++ {"id": "E69cVxQPksxe"}

- Note that the parallelized `convolve` returns a `ShardedDeviceArray`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs.
+ Note that the parallelized `convolve` returns a `jax.Array`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs.

```{code-cell} ipython3
:id: P9dUyk-ciquy
@@ -298,7 +298,7 @@ replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)

+++ {"id": "dmCMyLP9SV99"}

- So far, we've just constructed arrays with an additional leading dimension. The params are all still on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently. You can tell because they are a DeviceArray, not a ShardedDeviceArray:
+ So far, we've just constructed arrays with an additional leading dimension. The params are all still on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently.

```{code-cell} ipython3
:id: YSCgHguTSdGW
@@ -309,7 +309,7 @@

+++ {"id": "90VtjPbeY-hD"}

- The params will become a ShardedDeviceArray when they are returned by our pmapped `update()` (see further down).
+ The params will become a jax.Array when they are returned by our pmapped `update()` (see further down).

+++ {"id": "eGVKxk1CV-m1"}

@@ -347,7 +347,7 @@ for i in range(1000):
# This is where the params and data get communicated to devices:
replicated_params, loss = update(replicated_params, x_split, y_split)
- # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,
+ # The returned `replicated_params` and `loss` are now both jax.Arrays,
# indicating that they're on the devices.
# `x_split`, of course, remains a NumPy array on the host.
if i == 0:
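A small usage sketch of the replication step visible in the hunk context above (the parameter leaves here are made up): each leaf gets a new leading device axis so that `pmap` can map over it. Newer JAX spells the tree-mapping helper `jax.tree_util.tree_map`; the tutorial uses the `jax.tree_map` alias.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
params = {"weight": jnp.ones((3,)), "bias": jnp.zeros(())}  # hypothetical params

# Stack every leaf along a new leading device dimension.
replicated_params = jax.tree_util.tree_map(
    lambda x: jnp.array([x] * n_devices), params)
print(replicated_params["weight"].shape)  # (n_devices, 3)
```
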
