Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Example] Add WholeGraph to accelerate PyG dataloaders with GPUs #9714

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1240df9
Add example
chang-l Oct 17, 2024
9f170b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 17, 2024
1e2bd6f
Minor fix for typos and comments
chang-l Oct 18, 2024
fb432d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
f86e75b
Merge branch 'master' into add-uva-ddp-pyg
puririshi98 Oct 22, 2024
9ebbd19
Merge branch 'master' into add-uva-ddp-pyg
puririshi98 Oct 31, 2024
12b604b
Example reorg under NVIDIA RAPIDS folder
chang-l Nov 1, 2024
7193592
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
74d0830
Update README
chang-l Nov 4, 2024
ab52677
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 4, 2024
47ed670
Update README.md
puririshi98 Nov 6, 2024
8d18a4b
Update README.md
puririshi98 Nov 6, 2024
7eb6d3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
1a85441
Address comment
chang-l Nov 6, 2024
f7a13c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
297ddde
Update as torch now uses different fp.register (pytorch-pr-135030)
chang-l Nov 7, 2024
8111f1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
5f27215
Add download script for data prep
chang-l Nov 8, 2024
4abb0a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2024
4446583
Update README
puririshi98 Nov 8, 2024
973b6db
No need error-out for MNNVL check
chang-l Nov 12, 2024
ed0b8d8
Add cugraph README
chang-l Nov 12, 2024
aa66042
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2024
c4c7736
Update examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_g…
puririshi98 Nov 19, 2024
80b9680
Merge branch 'master' into add-uva-ddp-pyg
puririshi98 Nov 20, 2024
f46a172
Update CHANGELOG.md
puririshi98 Nov 20, 2024
86003a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions examples/distributed/wholegraph/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Using NVIDIA WholeGraph Library for Distributed Training with PyG

**[RAPIDS WholeGraph](https://github.com/rapidsai/wholegraph)**
NVIDIA WholeGraph is designed to optimize the training of Graph Neural Networks (GNNs) that are often constrained by data loading operations. It provides an underlying storage structure, called WholeMemory, which efficiently manages data storage/communication across disk, RAM, and device memory by leveraging NVIDIA GPUs and communication libraries like NCCL/NVSHMEM.

WholeGraph is a low-level graph storage library, integrated into and able to work alongside cuGraph, that directly provides an efficient feature and graph store with associated primitive operations (e.g., GPU-accelerated fast embedding retrieval and graph sampling). It is specifically optimized for NVLink systems, including DGX, MGX, and GH/GB200 machine or clusters.

This example demonstrates how to use WholeGraph to easily distribute the graph and feature store to pinned-host memory for fast GPU UVA access (see the DistTensor class), eliminating the need for manual graph partitioning or any custom third-party launch scripts. WholeGraph seamlessly integrates with PyTorch's Distributed Data Parallel (DDP) setup and works with standard distributed job launchers such as torchrun, mpirun, or srun.

## Requirements

- **PyTorch**: `>= 2.0`
- **PyTorch Geometric**: `>= 2.0.0`
- **WholeGraph**: `>= 24.02`
- **NVIDIA GPU(s)**

## Environment Setup

```bash
pip install pylibwholegraph-cu12
```

## Sinlge/Multi-GPU Run

Using PyTorch torchrun elastic launcher:
```
torchrun papers100m_dist_wholegraph_nc.py
```
or, using multi-GPUs if applicable:
```
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> papers100m_dist_wholegraph_nc.py
```

## Distributed (multi-node) Run

For example, let's use the slurm launcher here:

```
srun -N<num_nodes> --ntasks-per-node=<ngpu_per_node> python papers100m_dist_wholegraph_nc.py
```

Note the above command line setting is simplified for demonstration purposes. For more details, please refer to this [sbatch script](https://github.com/chang-l/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling_multinode.sbatch), as cluster setups may vary.


## Benchmark Run

The benchmark script is similar to the above example but includes a `--mode` command-line argument, allowing users to easily compare PyG's native features/graph store (`torch_geometric.data.Data` and `torch_geometric.data.HeteroData`) with the WholeMemory-based feature store and graph store, shown in this example. It performs a node classification task on the `ogbn-products` dataset.

### PyG baseline
```
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode baseline
```

### WholeGraph FeatureStore integration (UVA for feature store access)
```
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode UVA-features
```

### WholeGraph FeatureStore + GraphStore (UVA for feature and graph store access)
```
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode UVA
```
231 changes: 231 additions & 0 deletions examples/distributed/wholegraph/benchmark_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""Multi-node multi-GPU example on ogbn-papers100m.

Example way to run using srun:
srun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> \
--container-name=cont --container-image=<image_url> \
--container-mounts=/ogb-papers100m/:/workspace/dataset
python3 path_to_script.py
"""
import argparse
import os
import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from feature_store import WholeGraphFeatureStore
from graph_store import WholeGraphGraphStore
from nv_distributed_graph import dist_shmem
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Accuracy

from torch_geometric.loader import NeighborLoader, NodeLoader
from torch_geometric.nn import GCN
from torch_geometric.sampler import BaseSampler


class WholeGraphSampler(BaseSampler):
r"""A naive sampler class for WholeGraph graph storage that only supports uniform node-based sampling on homogeneous graph.
"""
from torch_geometric.sampler import NodeSamplerInput, SamplerOutput

def __init__(
self,
graph: WholeGraphGraphStore,
num_neighbors,
):
import pylibwholegraph.torch as wgth

self.num_neighbors = num_neighbors
self.wg_sampler = wgth.GraphStructure()
row_indx, col_ptrs, _ = graph.csc()
self.wg_sampler.set_csr_graph(col_ptrs._tensor, row_indx._tensor)

def sample_from_nodes(self, inputs: NodeSamplerInput) -> SamplerOutput:
r"""Sample subgraphs from the given nodes based on uniform node-based sampling.
"""
seed = inputs.node.cuda(
non_blocking=True) # WholeGraph Sampler needs all seeds on device
WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement(
seed, self.num_neighbors, None)
out = WholeGraphGraphStore.create_pyg_subgraph(WG_SampleOutput)
out.metadata = (inputs.input_id, inputs.time)
return out


def run(world_size, rank, local_rank, device, mode):
wall_clock_start = time.perf_counter()

# Will query the runtime environment for `MASTER_ADDR` and `MASTER_PORT`.
# Make sure, those are set!
dist.init_process_group('nccl', world_size=world_size, rank=rank)
dist_shmem.init_process_group_per_node()

# Load the dataset in the local root process and share it with local ranks
if dist_shmem.get_local_rank() == 0:
dataset = PygNodePropPredDataset(name='ogbn-products',
root='/workspace')
else:
dataset = None
dataset = dist_shmem.to_shmem(dataset) # move dataset to shmem

split_idx = dataset.get_idx_split()
split_idx['train'] = split_idx['train'].split(
split_idx['train'].size(0) // world_size, dim=0)[rank].clone()
split_idx['valid'] = split_idx['valid'].split(
split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()
split_idx['test'] = split_idx['test'].split(
split_idx['test'].size(0) // world_size, dim=0)[rank].clone()
data = dataset[0]
num_features = dataset.num_features
num_classes = dataset.num_classes

if mode == 'baseline':
data = data
kwargs = dict(
data=data,
batch_size=1024,
num_neighbors=[30, 30],
num_workers=4,
)
train_loader = NeighborLoader(
input_nodes=split_idx['train'],
shuffle=True,
drop_last=True,
**kwargs,
)
val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)

elif mode == 'UVA-features':
feature_store = WholeGraphFeatureStore(pyg_data=data)
graph_store = WholeGraphGraphStore(pyg_data=data, format='pyg')
data = (feature_store, graph_store)
kwargs = dict(
data=data,
batch_size=1024,
num_neighbors=[30, 30],
num_workers=4,
filter_per_worker=
False, # WholeGraph feature fetching is not fork-safe
)
train_loader = NeighborLoader(
input_nodes=split_idx['train'],
shuffle=True,
drop_last=True,
**kwargs,
)
val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)

elif mode == 'UVA':
feature_store = WholeGraphFeatureStore(pyg_data=data)
graph_store = WholeGraphGraphStore(pyg_data=data)
data = (feature_store, graph_store)
kwargs = dict(
data=data,
batch_size=1024,
num_workers=0, # with wholegraph sampler you don't need workers
filter_per_worker=
False, # WholeGraph feature fetching is not fork-safe
)
node_sampler = WholeGraphSampler(
graph_store,
num_neighbors=[30, 30],
)
train_loader = NodeLoader(
input_nodes=split_idx['train'],
node_sampler=node_sampler,
shuffle=True,
drop_last=True,
**kwargs,
)
val_loader = NodeLoader(input_nodes=split_idx['valid'],
node_sampler=node_sampler, **kwargs)
test_loader = NodeLoader(input_nodes=split_idx['test'],
node_sampler=node_sampler, **kwargs)

eval_steps = 1000
model = GCN(num_features, 256, 2, num_classes)
acc = Accuracy(task="multiclass", num_classes=num_classes).to(device)
model = DistributedDataParallel(model.to(device), device_ids=[local_rank])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
weight_decay=5e-4)

if rank == 0:
prep_time = round(time.perf_counter() - wall_clock_start, 2)
print("Total time before training begins (prep_time)=", prep_time,
"seconds")
print("Beginning training...")

for epoch in range(1, 21):
dist.barrier()
start = time.time()
model.train()
for i, batch in enumerate(train_loader):
batch = batch.to(device)
optimizer.zero_grad()
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
out = model(batch.x, batch.edge_index)[:batch.batch_size]
loss = F.cross_entropy(out, y)
loss.backward()
optimizer.step()
if rank == 0 and i % 100 == 0:
print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')

# Profile run:
# We synchronize before barrier to flush GPU OPs first,
# then adding barrier to sync CPUs to find max train time among all ranks.
torch.cuda.synchronize()
dist.barrier()
epoch_end = time.time()

@torch.no_grad()
def test(loader: NodeLoader, num_steps: Optional[int] = None):
model.eval()
for j, batch in enumerate(loader):
if num_steps is not None and j >= num_steps:
break
batch = batch.to(device)
out = model(batch.x, batch.edge_index)[:batch.batch_size]
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
acc(out, y)
acc_sum = acc.compute()
return acc_sum

eval_acc = test(val_loader, num_steps=eval_steps)
if rank == 0:
print(f"Val Accuracy: {eval_acc:.4f}%", )
print(f"Epoch {epoch:05d} | "
f"Accuracy {eval_acc:.4f} | "
f"Time {epoch_end - start:.2f}")

acc.reset()
dist.barrier()

test_acc = test(test_loader)
if rank == 0:
print(f"Test Accuracy: {test_acc:.4f}%", )
dist.destroy_process_group() if dist.is_initialized() else None


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='baseline',
choices=['baseline', 'UVA-features', 'UVA'])
args = parser.parse_args()

# Get the world size from the WORLD_SIZE variable or directly from SLURM:
world_size = int(
os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS')))
# Likewise for RANK and LOCAL_RANK:
rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID')))
local_rank = int(
os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID')))

assert torch.cuda.is_available()
device = torch.device(local_rank)
torch.cuda.set_device(device)
run(world_size, rank, local_rank, device, args.mode)
Loading
Loading