Skip to content

Commit

Permalink
[FEA] cuGraph GNN NCCL-only Setup and Distributed Sampling (#4278)
Browse files Browse the repository at this point in the history
* Adds the ability to run `pylibcugraph` without UCX/dask within PyTorch DDP.
* Adds the new distributed sampler which uses the new nccl+ddp path to perform bulk sampling.

Closes #4200 
Closes #4201 
Closes #4246 
Closes #3851

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)
  - Rick Ratzel (https://github.com/rlratzel)
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Jake Awe (https://github.com/AyodeAwe)
  - Joseph Nke (https://github.com/jnke2016)

URL: #4278
  • Loading branch information
alexbarghi-nv authored Apr 15, 2024
1 parent 80d0ecb commit 5c7cb2b
Show file tree
Hide file tree
Showing 14 changed files with 1,398 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ci/run_cugraph_pyg_pytests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ pytest --cache-clear --ignore=tests/mg "$@" .
# Test examples
for e in "$(pwd)"/examples/*.py; do
rapids-logger "running example $e"
python $e
(yes || true) | python $e
done
2 changes: 1 addition & 1 deletion ci/test_wheel_cugraph-pyg.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ python -m pytest \
# Test examples
for e in "$(pwd)"/examples/*.py; do
rapids-logger "running example $e"
python $e
(yes || true) | python $e
done
popd
112 changes: 112 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This example shows how to use cuGraph nccl-only comms, pylibcuGraph,
# and PyTorch DDP to run a multi-GPU sampling workflow. Most users of the
# GNN packages will not interact with cuGraph directly. This example
# is intented for users who want to extend cuGraph within a DDP workflow.

import os
import re
import tempfile

import numpy as np
import torch
import torch.multiprocessing as tmp
import torch.distributed as dist

import cudf

from cugraph.gnn import (
cugraph_comms_init,
cugraph_comms_shutdown,
cugraph_comms_create_unique_id,
cugraph_comms_get_raft_handle,
DistSampleWriter,
UniformNeighborSampler,
)

from pylibcugraph import MGGraph, ResourceHandle, GraphProperties

from ogb.nodeproppred import NodePropPredDataset


def init_pytorch(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("nccl", rank=rank, world_size=world_size)


def sample(rank: int, world_size: int, uid, edgelist, directory):
init_pytorch(rank, world_size)

device = rank
cugraph_comms_init(rank, world_size, uid, device)

print(f"rank {rank} initialized cugraph")

src = cudf.Series(np.array_split(edgelist[0], world_size)[rank])
dst = cudf.Series(np.array_split(edgelist[1], world_size)[rank])

seeds_per_rank = 50
seeds = cudf.Series(np.arange(rank * seeds_per_rank, (rank + 1) * seeds_per_rank))
handle = ResourceHandle(cugraph_comms_get_raft_handle().getHandle())

print("constructing graph")
G = MGGraph(
handle,
GraphProperties(is_multigraph=True, is_symmetric=False),
[src],
[dst],
)
print("graph constructed")

sample_writer = DistSampleWriter(directory=directory, batches_per_partition=2)
sampler = UniformNeighborSampler(
G,
sample_writer,
fanout=[5, 5],
)

sampler.sample_from_nodes(seeds, batch_size=16, random_state=62)

dist.barrier()
cugraph_comms_shutdown()
print(f"rank {rank} shut down cugraph")


def main():
world_size = torch.cuda.device_count()
uid = cugraph_comms_create_unique_id()

dataset = NodePropPredDataset("ogbn-products")
el = dataset[0][0]["edge_index"].astype("int64")

with tempfile.TemporaryDirectory() as directory:
tmp.spawn(
sample,
args=(world_size, uid, el, "."),
nprocs=world_size,
)

print("Printing samples...")
for file in os.listdir(directory):
m = re.match(r"batch=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet", file)
rank, start, _, end = int(m[1]), int(m[2]), int(m[3]), int(m[4])
print(f"File: {file} (batches {start} to {end} for rank {rank})")
print(cudf.read_parquet(os.path.join(directory, file)))
print("\n")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This example shows how to use cuGraph nccl-only comms, pylibcuGraph,
# and PyTorch to run a single-GPU sampling workflow. Most users of the
# GNN packages will not interact with cuGraph directly. This example
# is intented for users who want to extend cuGraph within a PyTorch workflow.

import os
import re
import tempfile

import numpy as np

import cudf

from cugraph.gnn import (
DistSampleWriter,
UniformNeighborSampler,
)

from pylibcugraph import SGGraph, ResourceHandle, GraphProperties

from ogb.nodeproppred import NodePropPredDataset


def sample(edgelist, directory):
src = cudf.Series(edgelist[0])
dst = cudf.Series(edgelist[1])

seeds_per_rank = 50
seeds = cudf.Series(np.arange(0, seeds_per_rank))

print("constructing graph")
G = SGGraph(
ResourceHandle(),
GraphProperties(is_multigraph=True, is_symmetric=False),
src,
dst,
)
print("graph constructed")

sample_writer = DistSampleWriter(directory=directory, batches_per_partition=2)
sampler = UniformNeighborSampler(
G,
sample_writer,
fanout=[5, 5],
)

sampler.sample_from_nodes(seeds, batch_size=16, random_state=62)


def main():
dataset = NodePropPredDataset("ogbn-products")
el = dataset[0][0]["edge_index"].astype("int64")

with tempfile.TemporaryDirectory() as directory:
sample(el, directory)

print("Printing samples...")
for file in os.listdir(directory):
m = re.match(r"batch=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet", file)
rank, start, _, end = int(m[1]), int(m[2]), int(m[3]), int(m[4])
print(f"File: {file} (batches {start} to {end} for rank {rank})")
print(cudf.read_parquet(os.path.join(directory, file)))
print("\n")


if __name__ == "__main__":
main()
100 changes: 100 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_mg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This example shows how to use cuGraph nccl-only comms, pylibcuGraph,
# and PyTorch DDP to run a multi-GPU workflow. Most users of the
# GNN packages will not interact with cuGraph directly. This example
# is intented for users who want to extend cuGraph within a DDP workflow.

import os

import pandas
import numpy as np
import torch
import torch.multiprocessing as tmp
import torch.distributed as dist

import cudf

from cugraph.gnn import (
cugraph_comms_init,
cugraph_comms_shutdown,
cugraph_comms_create_unique_id,
cugraph_comms_get_raft_handle,
)

from pylibcugraph import MGGraph, ResourceHandle, GraphProperties, degrees

from ogb.nodeproppred import NodePropPredDataset


def init_pytorch(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("nccl", rank=rank, world_size=world_size)


def calc_degree(rank: int, world_size: int, uid, edgelist):
init_pytorch(rank, world_size)

device = rank
cugraph_comms_init(rank, world_size, uid, device)

print(f"rank {rank} initialized cugraph")

src = cudf.Series(np.array_split(edgelist[0], world_size)[rank])
dst = cudf.Series(np.array_split(edgelist[1], world_size)[rank])

seeds = cudf.Series(np.arange(rank * 50, (rank + 1) * 50))
handle = ResourceHandle(cugraph_comms_get_raft_handle().getHandle())

print("constructing graph")
G = MGGraph(
handle,
GraphProperties(is_multigraph=True, is_symmetric=False),
[src],
[dst],
)
print("graph constructed")

print("calculating degrees")
vertices, in_deg, out_deg = degrees(handle, G, seeds, do_expensive_check=False)
print("degrees calculated")

print("constructing dataframe")
df = pandas.DataFrame(
{"v": vertices.get(), "in": in_deg.get(), "out": out_deg.get()}
)
print(df)

dist.barrier()
cugraph_comms_shutdown()
print(f"rank {rank} shut down cugraph")


def main():
world_size = torch.cuda.device_count()
uid = cugraph_comms_create_unique_id()

dataset = NodePropPredDataset("ogbn-products")
el = dataset[0][0]["edge_index"].astype("int64")

tmp.spawn(
calc_degree,
args=(world_size, uid, el),
nprocs=world_size,
)


if __name__ == "__main__":
main()
66 changes: 66 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_sg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This example shows how to use cuGraph and pylibcuGraph to run a
# single-GPU workflow. Most users of the GNN packages will not interact
# with cuGraph directly. This example is intented for users who want
# to extend cuGraph within a PyTorch workflow.

import pandas
import numpy as np

import cudf

from pylibcugraph import SGGraph, ResourceHandle, GraphProperties, degrees

from ogb.nodeproppred import NodePropPredDataset


def calc_degree(edgelist):
src = cudf.Series(edgelist[0])
dst = cudf.Series(edgelist[1])

seeds = cudf.Series(np.arange(256))

print("constructing graph")
G = SGGraph(
ResourceHandle(),
GraphProperties(is_multigraph=True, is_symmetric=False),
src,
dst,
)
print("graph constructed")

print("calculating degrees")
vertices, in_deg, out_deg = degrees(
ResourceHandle(), G, seeds, do_expensive_check=False
)
print("degrees calculated")

print("constructing dataframe")
df = pandas.DataFrame(
{"v": vertices.get(), "in": in_deg.get(), "out": out_deg.get()}
)
print(df)

print("done")


def main():
dataset = NodePropPredDataset("ogbn-products")
el = dataset[0][0]["edge_index"].astype("int64")
calc_degree(el)


if __name__ == "__main__":
main()
13 changes: 12 additions & 1 deletion python/cugraph/cugraph/gnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -13,3 +13,14 @@

from .feature_storage.feat_storage import FeatureStore
from .data_loading.bulk_sampler import BulkSampler
from .data_loading.dist_sampler import (
DistSampler,
DistSampleWriter,
UniformNeighborSampler,
)
from .comms.cugraph_nccl_comms import (
cugraph_comms_init,
cugraph_comms_shutdown,
cugraph_comms_create_unique_id,
cugraph_comms_get_raft_handle,
)
Loading

0 comments on commit 5c7cb2b

Please sign in to comment.