diff --git a/docs/guides/flax_on_pjit.ipynb b/docs/guides/flax_on_pjit.ipynb index 160f334d09..d5eced46e7 100644 --- a/docs/guides/flax_on_pjit.ipynb +++ b/docs/guides/flax_on_pjit.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "2a9f78765c0c" @@ -12,23 +13,27 @@ ] }, { + "attachments": {}, "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "b1e0e5fc8bc1" + }, "source": [ "## Flax and `jax.jit` scaled up\n", "\n", - "`jax.jit` follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm and automatically compiles your code to run it on multiple devices. You need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n", - "\n", - "To learn more about `jax.jit` APIs for scaling-up, refer to [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", + "[`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm and automatically compiles your code to run it on multiple devices. You need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n", "\n", "Flax provides several functionalities that can help you use auto-SPMD on [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html), including:\n", "\n", "1. An interface to specify partitions of your data when defining [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module).\n", "2. Utility functions to generate the sharding information that `jax.jit` requires to run.\n", - "3. An interface to customize your axis names called \"logical axis annotations\" to decouple both your Module code and partition plan to experiment with different partition layouts more easily." + "3. An interface to customize your axis names called \"logical axis annotations\" to decouple both your Module code and partition plan to experiment with different partition layouts more easily.\n", + "\n", + "You can learn more about `jax.jit` APIs for scaling up in [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) on JAX's documentation site." ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "a9601432b448" @@ -38,22 +43,7 @@ "\n", "Import some necessary dependencies.\n", "\n", - "**Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don't need this if you are already running on a multi-device TPU environment." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "867203db3bef", - "tags": [ - "skip-execution" - ] - }, - "outputs": [], - "source": [ - "# Once Flax v0.6.10 is released, no need to do this.\n", - "# ! pip3 install -qq \"git+https://github.com/google/flax.git@main#egg=flax\"" + "**Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don't need this if you are already using a multi-device TPU environment." ] }, { @@ -94,7 +84,9 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "id": "bcc30de1d6eb" + }, "outputs": [ { "name": "stdout", @@ -109,20 +101,21 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "c0d280def897" }, "source": [ - "Import and set up the JAX-level device API following [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html):\n", + "The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide:\n", "\n", - "1. Start a 2x4 device `mesh` (8 devices)—this is the same as the layout of [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board).\n", + "1. Start a 2x4 device `mesh` (8 devices) using JAX's `mesh_utils.create_device_mesh`. This layout is the same as that of a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board).\n", "\n", - "2. Annotate each axis with a name. A typical way to annotate axis names is `('data', 'model')`, where:\n", + "2. Annotate each axis with a name using the `axis_names` parameter in `jax.sharding.Mesh`. A typical way to annotate axis names is `axis_names=('data', 'model')`, where:\n", " * `'data'`: the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations.\n", " * `'model'`: the mesh dimension used for sharding parameters of the model across devices.\n", " \n", - "3. Make a simple util `mesh_sharding` to generate a sharding object from the mesh and any layout." + "3. Make a simple utility function `mesh_sharding` for generating a sharding object from the mesh and any layout." ] }, { @@ -141,7 +134,9 @@ { "cell_type": "code", "execution_count": 6, - "metadata": {}, + "metadata": { + "id": "4589d7a6d4bb" + }, "outputs": [ { "name": "stdout", @@ -163,7 +158,7 @@ "print(mesh)\n", "\n", "def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:\n", - " return NamedSharding(mesh, pspec)" + " return NamedSharding(mesh, pspec)" ] }, { @@ -175,13 +170,13 @@ "source": [ "## Define a layer\n", "\n", - "Before defining a model, create an example layer called `DotReluDot` (by subclassing `flax.linen.Module`), which creates two parameters `W1` and `W2` for dot product multiplication, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", + "Before defining a simple model, create an example layer called `DotReluDot` (by subclassing `flax.linen.Module`). The layer creates two parameters `W1` and `W2` for dot product multiplication, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", "\n", "To shard the parameters efficiently, apply the following APIs to annotate the parameters and intermediate variables:\n", "\n", "1. 
Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters.\n", "\n", - "2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n", + "2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n", "\n", " * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless." ] @@ -223,12 +218,13 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "cbac5321c08e" }, "source": [ - "Note that device axis names like `'data'`, `'model'` or `None` are passed into both `flax.linen.with_partitioning` and `with_sharding_constraint` API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n", + "Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n", "\n", "For example:\n", "\n", @@ -244,6 +240,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "b8389c11af79" @@ -251,14 +248,14 @@ "source": [ "## Define a model with `flax.linen.scan` lifted transformation\n", "\n", - "Having created `DotReluDot`, define the `MLP` model (by subclassing `flax.linen.Module`) as multiple layers of `DotReluDot`.\n", + "Having created `DotReluDot`, you can now define the `MLP` model (by subclassing [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module)) as multiple layers of `DotReluDot`.\n", "\n", - "To replicate identical layers, you can either use `flax.linen.scan`, or a for-loop:\n", + "To replicate identical layers, you can either use [`flax.linen.scan`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.scan.html), or a for-loop:\n", "\n", - "* `flax.linen.scan` can offer faster compilation times.\n", + "* `flax.linen.scan` can provide faster compilation times.\n", "* The for-loop can be faster on runtime.\n", "\n", - "The code below shows how to apply both methods, and default with the for-loop, so that all the parameters are two-dimentional and we can visualize their sharding. \n", + "The code below shows how to apply both methods, and default with the for-loop, so that all the parameters are two-dimensional and you can visualize their sharding. \n", "\n", "The `flax.linen.scan` code is just to show that this API works with [Flax lifted transforms](https://flax.readthedocs.io/en/latest/developer_notes/lift.html#supported-transformations)." 
] @@ -290,16 +287,21 @@ ] }, { + "attachments": {}, "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "44395b62561d" + }, "source": [ - "Now we make a `model` instance, and a sample input `x`." + "Now, create a `model` instance, and a sample input `x`." ] }, { "cell_type": "code", "execution_count": 45, - "metadata": {}, + "metadata": { + "id": "5686299b4839" + }, "outputs": [], "source": [ "# MLP hyperparameters.\n", @@ -316,6 +318,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "5b3abfef359d" @@ -323,11 +326,11 @@ }, "source": [ "## Specify sharding\n", "\n", - "Next, we need to tell `jax.jit` how to share our data across devices.\n", + "Next, you need to tell `jax.jit` how to shard your data across devices.\n", "\n", - "### Input's sharding\n", + "### The input's sharding\n", "\n", - "For data parallelism, we shard the batched _input_ `x` across the `data` axis by denoting the batch axis as `data`. Then, use `jax.device_put` to place it into the correct devices." + "For data parallelism, you can shard the batched _input_ `x` across the `data` axis by denoting the batch axis as `'data'`. Then, use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to place it onto the correct devices." ] }, { @@ -378,27 +381,30 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "06d134795ae1" }, "source": [ - "### Output's sharding\n", + "### The output's sharding\n", "\n", - "We want to compile `model.init()`, and its output is a pytree of parameters. Sometimes we even wrap it with a `flax.training.train_state` to track other variables like optimizer states, and that makes the output an even more complex pytree.\n", + "You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.init)), and its output is a pytree of parameters. Additionally, you may sometimes need to wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree.\n", "\n", - "Luckily we don't have to hardcode the output's sharding by hand. We do:\n", + "Luckily, you don't have to hardcode the output's sharding by hand. Instead, you can:\n", "\n", "1. Evaluate `model.init` (in this case, a wrapper of it) abstractly using [`jax.eval_shape`](https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html).\n", "\n", "1. Use [`flax.linen.get_sharding`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.get_sharding.html) to automatically generate the `jax.sharding.NamedSharding`.\n", - " * This steps utilizes the `nn.with_partitioning` annotations in earlier definition to genereate the correct sharding for the params." + " * This step utilizes the [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) annotations in the earlier definition to generate the correct sharding for the parameters."
] }, { "cell_type": "code", "execution_count": 47, - "metadata": {}, + "metadata": { + "id": "19094ec63385" + }, "outputs": [], "source": [ "def init_fn(k, x, model, optimizer):\n", @@ -413,7 +419,9 @@ { "cell_type": "code", "execution_count": 48, - "metadata": {}, + "metadata": { + "id": "e49264a3c78e" + }, "outputs": [ { "data": { @@ -508,7 +516,7 @@ ], "source": [ "# Create an abstract closure to wrap the function before feeding it in\n", - "# because `jax.eval_shape` only takes pytrees as arguments`.\n", + "# because `jax.eval_shape` only takes pytrees as arguments.\n", "abstract_variables = jax.eval_shape(\n", " functools.partial(init_fn, model=model, optimizer=optimizer), k, x)\n", "\n", @@ -519,6 +527,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "2ec24614050b" @@ -534,7 +543,9 @@ { "cell_type": "code", "execution_count": 49, - "metadata": {}, + "metadata": { + "id": "5b6e699df733" + }, "outputs": [ { "data": { @@ -613,6 +624,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "8f74b009f11f" @@ -651,8 +663,11 @@ ] }, { + "attachments": {}, "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "2beee7d27bdb" + }, "source": [ "You can also check the underlying [`jax.sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html) of each parameter, which is now more internal than `NamedSharding`. Note that numbers like `initialized_state.step` are replicated across all devices." ] @@ -682,7 +697,9 @@ { "cell_type": "code", "execution_count": 16, - "metadata": {}, + "metadata": { + "id": "d7cf0baa334b" + }, "outputs": [ { "name": "stdout", @@ -708,6 +725,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "273547d3ab89" @@ -749,6 +767,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "f7e1ccb14c6b" @@ -756,7 +775,7 @@ "source": [ "## Compile the train step and inference \n", "\n", - "Now, you create a `jit`ted training step:" + "Create a `jit`ted training step as follows:" ] }, { @@ -786,7 +805,9 @@ { "cell_type": "code", "execution_count": 19, - "metadata": {}, + "metadata": { + "id": "91c6c2662c12" + }, "outputs": [ { "name": "stdout", @@ -873,12 +894,13 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "2bae79e2e71b" }, "source": [ - "And a compiled inference step. Note that the output is also sharded along `(data, None)`." + "Then, create a compiled inference step. Note that the output is also sharded along `(data, None)`." ] }, { @@ -946,6 +968,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "7daa9e6e6eb4" @@ -983,6 +1006,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "51420b514d53" @@ -990,13 +1014,13 @@ "source": [ "## Logical axis annotation\n", "\n", - "JAX auto SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`). \n", + "JAX's automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`). \n", "\n", "The `LogicalDotReluDot` and `LogicalMLP` Module definition below are similar to the Modules you created earlier, except for the following:\n", "\n", "1. 
All axes are annotated with more concrete, meaningful names, such as `'embed'`, `'hidden'`, `'batch'` and `'layer'`. These names are referred to as _logical axis names_ in Flax. They make the dimensional changes inside model definitions more readable.\n", "\n", - "2. `nn.with_logical_partitioning` replaces `nn.with_partitioning`; and `nn.with_logical_constraint` replaces `jax.lax.with_sharding_constraint`, to recognize the logical axis names." + "2. [`flax.linen.with_logical_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_logical_partitioning.html) replaces `flax.linen.with_partitioning`; and [`flax.linen.with_logical_constraint`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_logical_constraint.html#flax-linen-with-logical-constraint) replaces `jax.lax.with_sharding_constraint`, to recognize the logical axis names." ] }, { @@ -1050,14 +1074,15 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "0de93ec6cbd6" }, "source": [ - "Now initiate a model and try to figure out what sharding its `state` should have.\n", + "Now, initiate a model and try to figure out what sharding its `state` should have.\n", "\n", - "To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis `'data'` or `'model'`. This rule is a list of (`logical_axis_name`, `device_axis_name`) tuples, and `nn.logical_to_mesh_sharding` will convert them to the sharding that the device mesh understands.\n", + "To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis `'data'` or `'model'`. This rule is a list of (`logical_axis_name`, `device_axis_name`) tuples, and [`flax.linen.logical_to_mesh_sharding`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.logical_to_mesh_sharding.html#flax-linen-logical-to-mesh-sharding) will convert them to the kind of sharding that the device mesh can understand.\n", "\n", "This allows you to change the rules and try out new partition layouts without modifying the model definition." ] }, { @@ -1097,12 +1122,13 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "58475fffb2de" }, "source": [ - "You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous (\"non-logical\") example. This allows you to `jax.jit` your module's `init` and `apply`, same as above." + "You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous (\"non-logical\") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.apply) the same way as above."
] }, { @@ -1145,7 +1171,9 @@ { "cell_type": "code", "execution_count": 54, - "metadata": {}, + "metadata": { + "id": "fb53bc20e0f9" + }, "outputs": [ { "name": "stdout", @@ -1232,6 +1260,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "ae1754a3031d" @@ -1239,16 +1268,17 @@ "source": [ "## When to use device axis / logical axis\n", "\n", - "Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model.\n", + "Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model:\n", "\n", - "If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming.\n", + "* **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", - "On the other hand, the __logical naming__ helpers are useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.\n", + "* **logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.\n", "\n", - "In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. When people wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful." + "* **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful." ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "576bdd5cd782" @@ -1256,9 +1286,9 @@ "source": [ "## Save the data\n", "\n", - "You can use [`flax.training.checkpoints`](https://flax.readthedocs.io/en/latest/_modules/flax/training/checkpoints.html) to save the cross-device array, as shown in the [Save and load checkpoints guide - Multi-host/multi-process checkpointing](https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#multi-host-multi-process-checkpointing). This is especially required if you are running on a multi-host environment (for example, a TPU pod).\n", + "To save the cross-device array, you can use [`flax.training.checkpoints`](https://flax.readthedocs.io/en/latest/_modules/flax/training/checkpoints.html), as shown in the [Save and load checkpoints guide - Multi-host/multi-process checkpointing](https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#multi-host-multi-process-checkpointing). This is especially required if you are running on a multi-host environment (for example, a TPU pod).\n", "\n", - "Keep in mind that to restore the arrays to the desired partition, you need to provide a sample `target` pytree that has the same structure and has the desired `jax.sharding.Sharding` in place for each JAX array. 
The sharding you use to restore the array doesn't necessarily need to be the same as the ones you used to store the array." + "Keep in mind that to restore the arrays to the desired partition, you need to provide a sample `target` pytree that has the same structure and has the desired [`jax.sharding.Sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Sharding) in place for each JAX array. The sharding you use to restore the array doesn't necessarily need to be the same as the ones you used to store the array." ] } ], @@ -1267,18 +1297,9 @@ "formats": "ipynb,md:myst" }, "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" + "name": "python" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 0 } diff --git a/docs/guides/flax_on_pjit.md b/docs/guides/flax_on_pjit.md index 5cc652a041..dc36ed6d11 100644 --- a/docs/guides/flax_on_pjit.md +++ b/docs/guides/flax_on_pjit.md @@ -14,13 +14,11 @@ jupytext: This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html). -+++ ++++ {"id": "b1e0e5fc8bc1"} ## Flax and `jax.jit` scaled up -`jax.jit` follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm and automatically compiles your code to run it on multiple devices. You need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications. - -To learn more about `jax.jit` APIs for scaling-up, refer to [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). +[`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm and automatically compiles your code to run it on multiple devices. You need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications. Flax provides several functionalities that can help you use auto-SPMD on [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html), including: @@ -28,30 +26,24 @@ Flax provides several functionalities that can help you use auto-SPMD on [Flax M 2. Utility functions to generate the sharding information that `jax.jit` requires to run. 3. An interface to customize your axis names called "logical axis annotations" to decouple both your Module code and partition plan to experiment with different partition layouts more easily. 
+You can learn more about `jax.jit` APIs for scaling up in [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) on JAX's documentation site. + +++ {"id": "a9601432b448"} ## Setup Import some necessary dependencies. -**Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don't need this if you are already running on a multi-device TPU environment. - -```{code-cell} ipython3 -:id: 867203db3bef -:tags: [skip-execution] - -# Once Flax v0.6.10 is released, no need to do this. -# ! pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax" -``` +**Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don't need this if you are already using a multi-device TPU environment. -```{code-cell} ipython3 +```{code-cell} :id: f8f42d1174e5 import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' ``` -```{code-cell} ipython3 +```{code-cell} :id: b8da40732f0b import functools import numpy as np import jax from jax import lax, random, numpy as jnp import flax from flax import struct, traverse_util, linen as nn from flax.linen import spmd # Flax Linen SPMD. from flax.core import freeze, unfreeze from flax.training import train_state, checkpoints import optax # Optax for common losses and optimizers. ``` -```{code-cell} ipython3 +```{code-cell} +:id: bcc30de1d6eb + print(f'We have 8 fake JAX devices now: {jax.devices()}') ``` +++ {"id": "c0d280def897"} -Import and set up the JAX-level device API following [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): +The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide: -1. Start a 2x4 device `mesh` (8 devices)—this is the same as the layout of [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board). +1. Start a 2x4 device `mesh` (8 devices) using JAX's `mesh_utils.create_device_mesh`. This layout is the same as that of a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board). -2. Annotate each axis with a name. A typical way to annotate axis names is `('data', 'model')`, where: +2. Annotate each axis with a name using the `axis_names` parameter in `jax.sharding.Mesh`. A typical way to annotate axis names is `axis_names=('data', 'model')`, where: * `'data'`: the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations. * `'model'`: the mesh dimension used for sharding parameters of the model across devices. -3. Make a simple util `mesh_sharding` to generate a sharding object from the mesh and any layout. +3. Make a simple utility function `mesh_sharding` for generating a sharding object from the mesh and any layout. -```{code-cell} ipython3 :id: 684fe9fe13a0 from jax.sharding import Mesh, PartitionSpec, NamedSharding @@ -93,7 +87,9 @@ from jax.lax import with_sharding_constraint from jax.experimental import mesh_utils ``` -```{code-cell} ipython3 +```{code-cell} +:id: 4589d7a6d4bb + # Create a mesh and annotate each axis with a name.
device_mesh = mesh_utils.create_device_mesh((2, 4)) print(device_mesh) @@ -102,24 +98,24 @@ mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) print(mesh) def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: - return NamedSharding(mesh, pspec) + return NamedSharding(mesh, pspec) ``` +++ {"id": "307d39db6d94"} ## Define a layer -Before defining a model, create an example layer called `DotReluDot` (by subclassing `flax.linen.Module`), which creates two parameters `W1` and `W2` for dot product multiplication, and uses the `jax.nn.relu` (ReLU) activation function in-between. +Before defining a simple model, create an example layer called `DotReluDot` (by subclassing `flax.linen.Module`). The layer creates two parameters `W1` and `W2` for dot product multiplication, and uses the `jax.nn.relu` (ReLU) activation function in-between. To shard the parameters efficiently, apply the following APIs to annotate the parameters and intermediate variables: 1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters. -2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known. +2. Apply [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known. * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless. -```{code-cell} ipython3 +```{code-cell} :id: b74c049968dc class DotReluDot(nn.Module): @@ -152,7 +148,7 @@ class DotReluDot(nn.Module): +++ {"id": "cbac5321c08e"} -Note that device axis names like `'data'`, `'model'` or `None` are passed into both `flax.linen.with_partitioning` and `with_sharding_constraint` API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all. +Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all. For example: @@ -170,18 +166,18 @@ For example: ## Define a model with `flax.linen.scan` lifted transformation -Having created `DotReluDot`, define the `MLP` model (by subclassing `flax.linen.Module`) as multiple layers of `DotReluDot`. +Having created `DotReluDot`, you can now define the `MLP` model (by subclassing [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module)) as multiple layers of `DotReluDot`. 
-To replicate identical layers, you can either use `flax.linen.scan`, or a for-loop: +To replicate identical layers, you can either use [`flax.linen.scan`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.scan.html), or a for-loop: -* `flax.linen.scan` can offer faster compilation times. +* `flax.linen.scan` can provide faster compilation times. * The for-loop can be faster on runtime. -The code below shows how to apply both methods, and default with the for-loop, so that all the parameters are two-dimentional and we can visualize their sharding. +The code below shows how to apply both methods, and default with the for-loop, so that all the parameters are two-dimensional and you can visualize their sharding. The `flax.linen.scan` code is just to show that this API works with [Flax lifted transforms](https://flax.readthedocs.io/en/latest/developer_notes/lift.html#supported-transformations). -```{code-cell} ipython3 +```{code-cell} :id: a0ea0dcccbc3 class MLP(nn.Module): @@ -202,9 +198,13 @@ class MLP(nn.Module): return x ``` -Now we make a `model` instance, and a sample input `x`. ++++ {"id": "44395b62561d"} + +Now, create a `model` instance, and a sample input `x`. + +```{code-cell} +:id: 5686299b4839 -```{code-cell} ipython3 # MLP hyperparameters. BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False # Create fake inputs. @@ -222,13 +222,13 @@ model = MLP(LAYERS, DEPTH, USE_SCAN) ## Specify sharding -Next, we need to tell `jax.jit` how to share our data across devices. +Next, you need to tell `jax.jit` how to shard your data across devices. -### Input's sharding +### The input's sharding -For data parallelism, we shard the batched _input_ `x` across the `data` axis by denoting the batch axis as `data`. Then, use `jax.device_put` to place it into the correct devices. +For data parallelism, you can shard the batched _input_ `x` across the `data` axis by denoting the batch axis as `'data'`. Then, use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to place it onto the correct devices. -```{code-cell} ipython3 +```{code-cell} :id: 8b913a2e57d3 x_sharding = mesh_sharding(PartitionSpec('data', None)) # dimensions: (batch, length) @@ -238,18 +238,20 @@ jax.debug.visualize_array_sharding(x) +++ {"id": "06d134795ae1"} -### Output's sharding +### The output's sharding -We want to compile `model.init()`, and its output is a pytree of parameters. Sometimes we even wrap it with a `flax.training.train_state` to track other variables like optimizer states, and that makes the output an even more complex pytree. +You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.init)), and its output is a pytree of parameters. Additionally, you may sometimes need to wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree. -Luckily we don't have to hardcode the output's sharding by hand. We do: +Luckily, you don't have to hardcode the output's sharding by hand. Instead, you can: 1. Evaluate `model.init` (in this case, a wrapper of it) abstractly using [`jax.eval_shape`](https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html). 1. 
Use [`flax.linen.get_sharding`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.get_sharding.html) to automatically generate the `jax.sharding.NamedSharding`. - * This steps utilizes the `nn.with_partitioning` annotations in earlier definition to genereate the correct sharding for the params. + * This step utilizes the [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) annotations in the earlier definition to generate the correct sharding for the parameters. + +```{code-cell} +:id: 19094ec63385 -```{code-cell} ipython3 def init_fn(k, x, model, optimizer): variables = model.init(k, x) # Initialize the model. state = train_state.TrainState.create( # Create a `TrainState`. @@ -259,9 +261,11 @@ def init_fn(k, x, model, optimizer): return state ``` -```{code-cell} ipython3 +```{code-cell} +:id: e49264a3c78e + # Create an abstract closure to wrap the function before feeding it in -# because `jax.eval_shape` only takes pytrees as arguments`. +# because `jax.eval_shape` only takes pytrees as arguments. abstract_variables = jax.eval_shape( functools.partial(init_fn, model=model, optimizer=optimizer), k, x) @@ -279,7 +283,9 @@ Now you can apply [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-ji Run it to get the `initialized_state`, in which parameters are sharded exactly as instructed: -```{code-cell} ipython3 +```{code-cell} +:id: 5b6e699df733 + jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), in_shardings=(mesh_sharding(None), x_sharding), # PRNG key and x out_shardings=state_sharding) @@ -300,7 +306,7 @@ Note that in the output of `initialized_state`, the `params` `W1` and `W2` are o You can access the raw `jax.Array` by adding `.value` when outside `jit`, or by `.unbox()` when inside. -```{code-cell} ipython3 +```{code-cell} :id: 19243982c892 print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'])) @@ -309,15 +315,19 @@ print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names) print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape) ``` ++++ {"id": "2beee7d27bdb"} + You can also check the underlying [`jax.sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html) of each parameter, which is now more internal than `NamedSharding`. Note that numbers like `initialized_state.step` are replicated across all devices. -```{code-cell} ipython3 +```{code-cell} :id: 2067c419a826 initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding ``` -```{code-cell} ipython3 +```{code-cell} +:id: d7cf0baa334b + print(initialized_state.step) initialized_state.step.sharding ``` @@ -326,7 +336,7 @@ initialized_state.step.sharding You can use [`jax.tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html) to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays. 
-```{code-cell} ipython3 +```{code-cell} :id: 29b3dae156a2 diff = jax.tree_map( @@ -342,9 +352,9 @@ print(diff_array.shape) ## Compile the train step and inference -Now, you create a `jit`ted training step: +Create a `jit`ted training step as follows: -```{code-cell} ipython3 +```{code-cell} :id: 4e3cc300cfee @functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), @@ -363,7 +373,9 @@ with mesh: new_state = train_step(initialized_state, x) ``` -```{code-cell} ipython3 +```{code-cell} +:id: 91c6c2662c12 + print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) print(f'Sharding of Weight 2:') @@ -372,9 +384,9 @@ jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2' +++ {"id": "2bae79e2e71b"} -And a compiled inference step. Note that the output is also sharded along `(data, None)`. +Then, create a compiled inference step. Note that the output is also sharded along `(data, None)`. -```{code-cell} ipython3 +```{code-cell} :id: c9264a48b9ee @functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), @@ -396,7 +408,7 @@ jax.debug.visualize_array_sharding(y) If you are running on a TPU pod or a pod slice, you can use a custom `block_all` utility function, as defined below, to measure the performance: -```{code-cell} ipython3 +```{code-cell} :id: a68d7cb2eb89 %%timeit @@ -413,15 +425,15 @@ with mesh: ## Logical axis annotation -JAX auto SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`). +JAX's automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`). The `LogicalDotReluDot` and `LogicalMLP` Module definition below are similar to the Modules you created earlier, except for the following: 1. All axes are annotated with more concrete, meaningful names, such as `'embed'`, `'hidden'`, `'batch'` and `'layer'`. These names are referred to as _logical axis names_ in Flax. They make the dimensional changes inside model definitions more readable. -2. `nn.with_logical_partitioning` replaces `nn.with_partitioning`; and `nn.with_logical_constraint` replaces `jax.lax.with_sharding_constraint`, to recognize the logical axis names. +2. [`flax.linen.with_logical_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_logical_partitioning.html) replaces `flax.linen.with_partitioning`; and [`flax.linen.with_logical_constraint`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_logical_constraint.html#flax-linen-with-logical-constraint) replaces `jax.lax.with_sharding_constraint`, to recognize the logical axis names. -```{code-cell} ipython3 +```{code-cell} :id: a26f85a9e772 class LogicalDotReluDot(nn.Module): @@ -468,13 +480,13 @@ class LogicalMLP(nn.Module): +++ {"id": "0de93ec6cbd6"} -Now initiate a model and try to figure out what sharding its `state` should have. +Now, initiate a model and try to figure out what sharding its `state` should have. -To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis `'data'` or `'model'`. 
This rule is a list of (`logical_axis_name`, `device_axis_name`) tuples, and `nn.logical_to_mesh_sharding` will convert them to the sharding that the device mesh understands. +To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis `'data'` or `'model'`. This rule is a list of (`logical_axis_name`, `device_axis_name`) tuples, and [`flax.linen.logical_to_mesh_sharding`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.logical_to_mesh_sharding.html#flax-linen-logical-to-mesh-sharding) will convert them to the kind of sharding that the device mesh can understand. This allows you to change the rules and try out new partition layouts without modifying the model definition. -```{code-cell} ipython3 +```{code-cell} :id: 14db7a1e30fd # Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`. @@ -496,15 +508,15 @@ print('sharding annotations are mesh-specific: ', +++ {"id": "58475fffb2de"} -You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous ("non-logical") example. This allows you to `jax.jit` your module's `init` and `apply`, same as above. +You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous ("non-logical") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.apply) the same way as above. -```{code-cell} ipython3 +```{code-cell} :id: 589ff774bb4c state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0'] ``` -```{code-cell} ipython3 +```{code-cell} :id: 77e07a0ab309 logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), @@ -514,7 +526,9 @@ logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer) ``` -```{code-cell} ipython3 +```{code-cell} +:id: fb53bc20e0f9 + print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value) print(f'Sharding of Weight 2:') jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotR @@ -525,18 +539,18 @@ ## When to use device axis / logical axis -Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model. +Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model: -If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming. +* **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming. -On the other hand, the __logical naming__ helpers are useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model. 
+* **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the optimal partition layout for your model. -In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. When people wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful. +* **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful. +++ {"id": "576bdd5cd782"} ## Save the data -You can use [`flax.training.checkpoints`](https://flax.readthedocs.io/en/latest/_modules/flax/training/checkpoints.html) to save the cross-device array, as shown in the [Save and load checkpoints guide - Multi-host/multi-process checkpointing](https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#multi-host-multi-process-checkpointing). This is especially required if you are running on a multi-host environment (for example, a TPU pod). +To save the cross-device array, you can use [`flax.training.checkpoints`](https://flax.readthedocs.io/en/latest/_modules/flax/training/checkpoints.html), as shown in the [Save and load checkpoints guide - Multi-host/multi-process checkpointing](https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#multi-host-multi-process-checkpointing). This is especially required if you are running on a multi-host environment (for example, a TPU pod). -Keep in mind that to restore the arrays to the desired partition, you need to provide a sample `target` pytree that has the same structure and has the desired `jax.sharding.Sharding` in place for each JAX array. The sharding you use to restore the array doesn't necessarily need to be the same as the ones you used to store the array. +Keep in mind that to restore the arrays to the desired partition, you need to provide a sample `target` pytree that has the same structure and has the desired [`jax.sharding.Sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Sharding) in place for each JAX array. The sharding you use to restore the array doesn't necessarily need to be the same as the ones you used to store the array.
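+For illustration only, a minimal sketch of saving and restoring the sharded `TrainState` from this guide with `flax.training.checkpoints` might look like the code below. The checkpoint directory `/tmp/flax_ckpt` is a placeholder and the exact arguments can vary by setup, so refer to the checkpointing guide linked above for the full workflow.
+
+```{code-cell}
+# Save the sharded train state from every process (multi-host friendly).
+checkpoints.save_checkpoint_multiprocess(
+    ckpt_dir='/tmp/flax_ckpt',    # Placeholder path.
+    target=initialized_state,
+    step=0,
+    overwrite=True)
+
+# Restore by passing a `target` pytree whose arrays already carry the desired
+# `jax.sharding.Sharding`. Here, the previously initialized, correctly sharded
+# state serves as that sample target.
+restored_state = checkpoints.restore_checkpoint(
+    ckpt_dir='/tmp/flax_ckpt',
+    target=initialized_state)
+```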