diff --git a/docs/pjrt.md b/docs/pjrt.md
index 265c0abcce5..2cdf6be2119 100644
--- a/docs/pjrt.md
+++ b/docs/pjrt.md
@@ -194,6 +194,8 @@ for more information.
 
 *Warning: GPU support is still highly experimental!*
 
+### Single-node GPU training
+
 To use GPUs with PJRT, simply set `PJRT_DEVICE=GPU` and configure
 `GPU_NUM_DEVICES` to the number of devices on the host. For example:
 
@@ -201,8 +203,87 @@
 PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
 ```
 
-Currently, only a single host is supported, and multi-host GPU cluster support
-will be added in an future release.
+You can also use `torchrun` to launch single-node multi-GPU training. For example:
+
+```
+PJRT_DEVICE=GPU torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
+```
+
+In the above example, `--nnodes` is the number of machines (physical machines or VMs) to use (1 here, since this is single-node training), and `--nproc-per-node` is the number of GPU devices to use.
+
+### Multi-node GPU training
+
+**Note that this feature only works with CUDA 12+.** Similar to multi-node training in PyTorch, you can run the command below:
+
+```
+PJRT_DEVICE=GPU torchrun \
+--nnodes=${NUMBER_GPU_VM} \
+--node_rank=${CURRENT_NODE_RANK} \
+--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
+--rdzv_endpoint="<host>:<port>" multinode_training.py
+```
+
+- `--nnodes`: the number of GPU machines to use.
+- `--node_rank`: the index of the current GPU machine. The value can be 0, 1, ..., ${NUMBER_GPU_VM}-1.
+- `--nproc_per_node`: the number of GPU devices to use on the current machine.
+- `--rdzv_endpoint`: the endpoint of the GPU machine with `node_rank==0`, in the form `host:port`. The `host` is the machine's internal IP address, and the port can be any available port on that machine.
+
+For example, to train on 2 GPU machines, machine_0 and machine_1, run the following on the first machine (machine_0):
+
+```
+PJRT_DEVICE=GPU torchrun \
+--nnodes=2 \
+--node_rank=0 \
+--nproc_per_node=4 \
+--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
+```
+
+On the second GPU machine (machine_1), run:
+
+```
+PJRT_DEVICE=GPU torchrun \
+--nnodes=2 \
+--node_rank=1 \
+--nproc_per_node=4 \
+--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
+```
+
+The only differences between the two commands are `--node_rank` and, potentially, `--nproc_per_node` if you want to use a different number of GPU devices on each machine; everything else is identical.
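+
+If you do not already have a training script, the following is a minimal sketch of what the placeholder `multinode_training.py` above could contain. It is only an illustration, not one of the test scripts used in the commands above: it assumes the `xla` `torch.distributed` backend and the `xla://` init method provided by `torch_xla.distributed.xla_backend`, and the model, data, and hyperparameters are placeholders.
+
+```python
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_backend  # registers the 'xla' backend and the xla:// init method
+
+
+def main():
+  # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, etc.; the xla:// init method reads them.
+  dist.init_process_group('xla', init_method='xla://')
+
+  device = xm.xla_device()  # the XLA:GPU device assigned to this process
+  model = nn.Linear(128, 10).to(device)
+  ddp_model = DDP(model, gradient_as_bucket_view=True)
+  optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
+
+  for step in range(10):
+    data = torch.randn(64, 128, device=device)  # placeholder input batch
+    target = torch.randint(0, 10, (64,), device=device)
+    optimizer.zero_grad()
+    loss = nn.functional.cross_entropy(ddp_model(data), target)
+    loss.backward()
+    optimizer.step()
+    xm.mark_step()  # materialize the lazily traced graph for this step
+
+
+if __name__ == '__main__':
+  main()
+```
 
 ## Differences from XRT
 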