diff --git a/_static/img/distributed/fsdp_sharding.png b/_static/img/distributed/fsdp_sharding.png
new file mode 100755
index 0000000000..9dd1e3c111
Binary files /dev/null and b/_static/img/distributed/fsdp_sharding.png differ
diff --git a/intermediate_source/FSDP_tutorial.rst b/intermediate_source/FSDP_tutorial.rst
index 26988eda90..58fa0ca0c2 100644
--- a/intermediate_source/FSDP_tutorial.rst
+++ b/intermediate_source/FSDP_tutorial.rst
@@ -46,6 +46,15 @@ At a high level FSDP works as follow:
 * Run reduce_scatter to sync gradients
 * Discard parameters.
 
+One way to view FSDP's sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank possesses a shard of the gradients. Then it updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards.
+
+.. figure:: /_static/img/distributed/fsdp_sharding.png
+   :width: 100%
+   :align: center
+   :alt: FSDP allreduce
+
+   FSDP Allreduce
+
 How to use FSDP
 --------------
 Here we use a toy model to run training on the MNIST dataset for demonstration purposes. The APIs and logic can be applied to training larger models as well.
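The decomposition described in the added paragraph can be checked numerically. Below is a minimal single-process sketch: it simulates several ranks with plain tensors and verifies that DDP's all-reduce followed by a full update produces the same parameters as FSDP's reduce-scatter, sharded update, and all-gather. The world size, tensor sizes, and plain SGD step are illustrative assumptions, not FSDP internals.

.. code-block:: python

    # Single-process simulation of the reduce-scatter / all-gather decomposition.
    # Assumed values (world_size, numel, lr) are illustrative only.
    import torch

    world_size = 4          # number of simulated ranks
    numel = 8               # flattened parameter size (divisible by world_size)
    lr = 0.1

    torch.manual_seed(0)
    params = torch.randn(numel)                                     # identical replica on every rank
    local_grads = [torch.randn(numel) for _ in range(world_size)]   # per-rank local gradients

    # DDP-style update: all-reduce (average) the full gradient, then update all parameters.
    avg_grad = torch.stack(local_grads).mean(dim=0)
    ddp_params = params - lr * avg_grad

    # FSDP-style update: reduce-scatter, sharded optimizer step, all-gather.
    shard_size = numel // world_size
    # Reduce-scatter: rank r keeps only the r-th shard of the averaged gradient.
    grad_shards = [avg_grad[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]
    # Sharded optimizer step: each rank updates only its own parameter shard.
    param_shards = [
        params[r * shard_size:(r + 1) * shard_size] - lr * grad_shards[r]
        for r in range(world_size)
    ]
    # All-gather: concatenate the updated shards to rebuild the full parameters.
    fsdp_params = torch.cat(param_shards)

    # Both strategies yield the same updated parameters.
    assert torch.allclose(ddp_params, fsdp_params)
    print("DDP and FSDP updates match:", torch.allclose(ddp_params, fsdp_params))

In an actual multi-process run, the per-rank slicing and concatenation above are performed by collectives such as ``torch.distributed.reduce_scatter_tensor`` and ``torch.distributed.all_gather_into_tensor``.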