Multi-Node-Multi-GPU Tutorial #8071

Merged: 44 commits merged on Oct 24, 2023

Commits
e4687c1
wip
puririshi98 Sep 22, 2023
ab1a24a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2023
d71728e
wip
puririshi98 Sep 22, 2023
090aa6d
Merge branch 'multinode_multigpu_tutorial' of https://github.com/pyg-…
puririshi98 Sep 22, 2023
fd281ad
wip
puririshi98 Sep 22, 2023
9331b92
wip
puririshi98 Sep 25, 2023
99c2aa6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2023
671b79c
wip
puririshi98 Sep 25, 2023
b4ee6a5
draft done
puririshi98 Sep 25, 2023
4b0d4f3
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Sep 25, 2023
4ac34c3
draft done
puririshi98 Sep 25, 2023
d16f7a7
draft done
puririshi98 Sep 25, 2023
ea24757
draft done
puririshi98 Sep 26, 2023
cef4340
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
900a6a1
draft done
puririshi98 Sep 26, 2023
bf88252
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
d5342f1
draft done
puririshi98 Sep 26, 2023
7b03492
Merge branch 'multinode_multigpu_tutorial' of https://github.com/pyg-…
puririshi98 Sep 26, 2023
05bd132
draft done
puririshi98 Sep 26, 2023
2965941
draft done
puririshi98 Sep 26, 2023
a99f7b2
draft done
puririshi98 Sep 26, 2023
b24651d
draft done
puririshi98 Sep 26, 2023
73585d4
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Sep 27, 2023
6ef7dd5
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Sep 28, 2023
1c18d3e
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 2, 2023
ccd604a
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 6, 2023
290ea7b
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 6, 2023
2d29f10
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 6, 2023
08fec75
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 8, 2023
73fdbdf
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 9, 2023
137f730
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 9, 2023
e392a65
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 9, 2023
1355f72
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 18, 2023
8931e4d
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 18, 2023
5527236
Update CHANGELOG.md
akihironitta Oct 18, 2023
fe11dcc
Update docs/source/tutorial/multi_gpu_vanilla.rst
akihironitta Oct 18, 2023
60a39ff
addressing akihiro's comments
puririshi98 Oct 18, 2023
561640d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2023
e19c2c5
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 20, 2023
731454c
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 23, 2023
b7fec85
Merge branch 'master' into multinode_multigpu_tutorial
puririshi98 Oct 23, 2023
cf196b6
Merge branch 'master' into multinode_multigpu_tutorial
rusty1s Oct 24, 2023
9f80b8a
Merge branch 'master' into multinode_multigpu_tutorial
rusty1s Oct 24, 2023
b9f0512
update
rusty1s Oct 24, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.4.0] - 2023-MM-DD

### Added

- Added a tutorial for multi-node-multi-GPU training with pure PyTorch ([#8071](https://github.com/pyg-team/pytorch_geometric/pull/8071))
- Added `OnDiskDataset` interface ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066))
- Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938))
- Added a tutorial for multi-GPU training with pure PyTorch ([#7894](https://github.com/pyg-team/pytorch_geometric/pull/7894))
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -31,6 +31,7 @@ In addition, it consists of easy-to-use mini-batch loaders for operating on many
tutorial/dataset
tutorial/application
tutorial/multi_gpu
tutorial/multi_node_multi_gpu_tutorial

.. toctree::
:maxdepth: 1
4 changes: 2 additions & 2 deletions docs/source/tutorial/multi_gpu_vanilla.rst
@@ -1,4 +1,4 @@
Multi-GPU Training in Pure PyTorch
Multi-GPU GNN Training
==================================

For many large scale, real-world datasets, it may be necessary to scale up training across multiple GPUs.
@@ -171,5 +171,5 @@ After finishing training, we can clean up processes and destroy the process group
dist.destroy_process_group()

And that's it.
Putting it all together gives a working multi-GPU example that follows a similar training flow than single GPU training.
Putting it all together gives a working multi-GPU example that follows a training flow that is similar to single GPU training.
You can run the shown tutorial by yourself by looking at `examples/multi_gpu/distributed_sampling.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling.py>`_.
252 changes: 252 additions & 0 deletions docs/source/tutorial/multi_node_multi_gpu_tutorial.rst
@@ -0,0 +1,252 @@
Multi-Node-Multi-GPU GNN Training
==================================

Before doing this tutorial, we recommend working through the single-node multi-GPU tutorial (:doc:`multi_gpu_vanilla`) as a warm up.
Our first step is to understand the basic structure of a multi-node-multi-GPU example.


.. code-block:: python

import argparse
import time
import warnings

import torch
import torch.distributed as dist

from torch_geometric.datasets import FakeDataset
from torch_geometric.nn.models import GCN

warnings.filterwarnings("ignore")

_LOCAL_PROCESS_GROUP = None


def create_local_process_group(num_workers_per_node):
    ...


def get_local_process_group():
    ...


def run(device, data, world_size, model, epochs, batch_size, fan_out,
        split_idx, num_classes):
    ...


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--hidden_channels', type=int, default=64)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--fan_out', type=int, default=50)
    parser.add_argument(
        "--ngpu_per_node",
        type=int,
        default=1,
        help="number of GPUs per node for multi-GPU training",
    )
    args = parser.parse_args()

    # Set up multi-node communication via the NCCL backend:
    torch.distributed.init_process_group("nccl")
    nprocs = dist.get_world_size()
    create_local_process_group(args.ngpu_per_node)
    local_group = get_local_process_group()
    device_id = dist.get_rank(
        group=local_group) if dist.is_initialized() else 0
    torch.cuda.set_device(device_id)
    device = torch.device(device_id)

    dataset = FakeDataset(avg_num_nodes=100000)
    data = dataset.data
    num_nodes = data.num_nodes
    rand_id = torch.randperm(num_nodes)

    # 60/20/20 train/validation/test split:
    split_idx = {
        'train': rand_id[:int(0.6 * num_nodes)],
        'valid': rand_id[int(0.6 * num_nodes):int(0.8 * num_nodes)],
        'test': rand_id[int(0.8 * num_nodes):],
    }

    model = GCN(dataset.num_features, args.hidden_channels, 2,
                dataset.num_classes)
    run(device, data, nprocs, model, args.epochs, args.batch_size,
        args.fan_out, split_idx, dataset.num_classes)


Similarly to the warm-up example, we define a :meth:`run` function. However, in this case we initialize :obj:`torch.distributed` with the NVIDIA NCCL backend instead of relying on :obj:`torch.multiprocessing`. Because we run on multiple nodes, we set up a local process group for each node and use :obj:`args.ngpu_per_node` GPUs per node. We then select the CUDA device that each process within its process group will use. The remaining steps are fairly standard :pyg:`PyG` and :pytorch:`PyTorch` usage: we load our (synthetic) dataset, set up a 60/20/20 train/validation/test split, define our :class:`~torch_geometric.nn.models.GCN` model, and finally call the :meth:`run` function.

Before we look into how our :meth:`run` function should be defined, we need to understand how we create and access our local process groups.


.. code-block:: python

def create_local_process_group(num_workers_per_node):
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    assert world_size % num_workers_per_node == 0

    num_nodes = world_size // num_workers_per_node
    node_rank = rank // num_workers_per_node
    for i in range(num_nodes):
        ranks_on_i = list(
            range(i * num_workers_per_node, (i + 1) * num_workers_per_node))
        pg = dist.new_group(ranks_on_i)
        if i == node_rank:
            _LOCAL_PROCESS_GROUP = pg


def get_local_process_group():
    assert _LOCAL_PROCESS_GROUP is not None
    return _LOCAL_PROCESS_GROUP

To create our local process groups, we split the sequential global ranks into chunks of :obj:`num_workers_per_node` and create a :meth:`torch.distributed.new_group` for each chunk. Each process then stores the group corresponding to its own node in a global variable, which we access via :meth:`get_local_process_group`.
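
To make the grouping concrete, here is a minimal standalone sketch (not part of the tutorial code) of which global ranks end up in each node-local group, assuming a hypothetical cluster of two nodes with two GPUs each:

.. code-block:: python

    world_size = 4  # Hypothetical setup: 2 nodes x 2 GPUs.
    num_workers_per_node = 2
    num_nodes = world_size // num_workers_per_node

    for i in range(num_nodes):
        ranks_on_i = list(
            range(i * num_workers_per_node, (i + 1) * num_workers_per_node))
        print(f'node {i} -> global ranks {ranks_on_i}')
    # node 0 -> global ranks [0, 1]
    # node 1 -> global ranks [2, 3]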

The final step of coding is to define our :meth:`run` function:

.. code-block:: python

from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Accuracy

import torch.nn.functional as F

from torch_geometric.loader import NeighborLoader


def run(device, data, world_size, model, epochs, batch_size, fan_out,
        split_idx, num_classes):
    local_group = get_local_process_group()
    loc_id = dist.get_rank(group=local_group)
    rank = torch.distributed.get_rank()
    if rank == 0:
        print("Data =", data)
        print('Using', world_size, 'GPUs...')

    # Each rank only trains on its own chunk of the training indices:
    split_idx['train'] = split_idx['train'].split(
        split_idx['train'].size(0) // world_size, dim=0)[rank].clone()

    model = model.to(device)
    model = DistributedDataParallel(model, device_ids=[loc_id])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                                 weight_decay=0.0005)
    acc = Accuracy(task="multiclass", num_classes=num_classes).to(device)

    train_loader = NeighborLoader(data, num_neighbors=[fan_out, fan_out],
                                  input_nodes=split_idx['train'],
                                  batch_size=batch_size)
    if rank == 0:  # Evaluation and testing only happen on rank 0:
        eval_loader = NeighborLoader(data, num_neighbors=[fan_out, fan_out],
                                     input_nodes=split_idx['valid'],
                                     batch_size=batch_size)
        test_loader = NeighborLoader(data, num_neighbors=[fan_out, fan_out],
                                     input_nodes=split_idx['test'],
                                     batch_size=batch_size)
    eval_steps = 100

    if rank == 0:
        print("Beginning training...")
    for epoch in range(epochs):
        for i, batch in enumerate(train_loader):
            if i == 10:  # Start timing after a warm-up of 10 iterations:
                start = time.time()
            batch = batch.to(device)
            batch.y = batch.y.to(torch.long)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size])
            loss.backward()
            optimizer.step()
            if rank == 0 and i % 10 == 0:
                print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss:.4f}")

        if rank == 0:
            print("Average Training Iteration Time:",
                  (time.time() - start) / (i - 10), "s/iter")
            acc_sum = 0.0
            with torch.no_grad():
                for i, batch in enumerate(eval_loader):
                    if i >= eval_steps:
                        break
                    if i == 10:
                        start = time.time()
                    batch = batch.to(device)
                    batch.y = batch.y.to(torch.long)
                    out = model(batch.x, batch.edge_index)
                    acc_sum += acc(out[:batch_size].softmax(dim=-1),
                                   batch.y[:batch_size])
            # Expect poor validation/test accuracy since the data is random:
            print(f"Validation Accuracy: {acc_sum / i * 100.0:.4f}%")
            print("Average Inference Iteration Time:",
                  (time.time() - start) / (i - 10), "s/iter")

    if rank == 0:
        acc_sum = 0.0
        with torch.no_grad():
            for i, batch in enumerate(test_loader):
                batch = batch.to(device)
                batch.y = batch.y.to(torch.long)
                out = model(batch.x, batch.edge_index)
                acc_sum += acc(out[:batch_size].softmax(dim=-1),
                               batch.y[:batch_size])
        print(f"Test Accuracy: {acc_sum / i * 100.0:.4f}%")

Our :meth:`run` function is very similar to that of our warm-up example, except for the beginning. In this tutorial, our distributed groups have already been initialized, so we only need to assign :obj:`loc_id`, the local GPU ID of each process on its node, as well as the global :obj:`rank` of each process. As an example, consider a scenario where we use 3 nodes with 8 GPUs each. The process driving the 7th GPU on the 3rd node, i.e. the 23rd GPU in our system, has rank :obj:`22`, while its :obj:`loc_id` is :obj:`6`.
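
The same arithmetic can be written out explicitly. The following tiny sketch uses made-up numbers and is not part of the tutorial code:

.. code-block:: python

    # Hypothetical setup: 3 nodes with 8 GPUs each.
    num_gpus_per_node = 8
    rank = 22                              # Global rank of the 23rd GPU.
    node_rank = rank // num_gpus_per_node  # -> 2 (the 3rd node)
    loc_id = rank % num_gpus_per_node      # -> 6 (the 7th GPU on that node)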

After that, the flow is very similar to our warm-up example:

1. We move our :class:`~torch_geometric.nn.models.GCN` model to :obj:`device` and wrap it inside :class:`~torch.nn.parallel.DistributedDataParallel`, passing :obj:`loc_id` to its :obj:`device_ids` parameter.
2. We then set up our optimizer and accuracy metric for evaluation and testing.
3. We split the training indices into :obj:`world_size` many chunks, one for each GPU, and initialize the :class:`~torch_geometric.loader.NeighborLoader` class to only operate on its specific chunk of the training set (see the sketch after this list).
4. We create additional :class:`~torch_geometric.loader.NeighborLoader` instances for evaluation and testing. Again, for simplicity, we only do this on rank :obj:`0`.
5. Finally, we follow a similar training and evaluation loop as in our warm-up example.
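
For step 3, here is a minimal standalone sketch (with made-up sizes, not part of the tutorial code) of how the training indices get divided across ranks:

.. code-block:: python

    import torch

    # Hypothetical illustration: 8 training indices split across 4 ranks.
    world_size = 4
    train_idx = torch.randperm(8)
    chunks = train_idx.split(train_idx.size(0) // world_size, dim=0)
    for rank, chunk in enumerate(chunks[:world_size]):
        print(f'rank {rank} trains on indices {chunk.tolist()}')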

And that's all the coding.

Putting it all together gives a working multi-node-multi-GPU example that follows a training flow that is similar to single GPU training.
You can run the shown tutorial by yourself by looking at `examples/multi_gpu/multi_node_multi_gpu_synthetic.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/multi_node_multi_gpu_synthetic.py>`_.

However, to run the example, you need to use Slurm on a cluster with NVIDIA's pyxis plugin enabled. Here's how:

Step 1:

In your Slurm login terminal, run:

.. code-block:: bash

srun --overlap -A <slurm_access_group> -p interactive -J <experiment-name> -N <num_nodes> -t 00:30:00 --pty bash

This will allocate :obj:`num_nodes` nodes for 30 minutes. The :obj:`-A` and :obj:`-J` arguments may be required on your cluster; speak with your cluster management team for more information on how to use these parameters.
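
For example, a hypothetical invocation (all values below are placeholders; adapt them to your cluster) might look like:

.. code-block:: bash

    srun --overlap -A my_account -p interactive -J pyg-multinode -N 2 -t 00:30:00 --pty bash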


Step 2:

Open another Slurm login terminal and run:

.. code-block:: bash

squeue -u <slurm-unix-account-id>
export jobid=<JOBID from SQUEUE>

In this step, we save the job ID of our Slurm job from step 1.

Step 3:

Now we are going to pull a container with a functional PyG and CUDA environment onto each node.

.. code-block:: bash

srun -l -N<num_nodes> --ntasks-per-node=1 --overlap --jobid=$jobid \
--container-image=<image_url> --container-name=cont \
--container-mounts=<data-directory>/ogb-papers100m/:/workspace/dataset true

NVIDIA recommends using the NVIDIA PyG container, which is updated each month with the latest from NVIDIA and :pyg:`PyG`. Sign up for early access at `developer.nvidia.com/pyg-container-early-access <https://developer.nvidia.com/pyg-container-early-access>`_. General availability on `NVIDIA NGC <https://www.ngc.nvidia.com/>`_ is set for the end of 2023. Alternatively, see `docker.com <https://www.docker.com/>`_ for information on creating your own container.

Step 4:

Once your container is loaded, simply run:

.. code-block:: bash

srun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> --overlap --jobid=$jobid \
--container-name=cont \
python3 pyg_multinode_tutorial.py --ngpu_per_node <>

Give it a try!