Skip to content

Commit

Permalink
[FEA] DGL Examples (#4583)
Browse files Browse the repository at this point in the history
Adds updated examples for the new cuGraph-DGL API. 

Also fixes a minor performance issue by allowing access to tensors through an `EmbeddingView`.

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Ralph Liu (https://github.com/nv-rliu)

Approvers:
  - Mike Sarahan (https://github.com/msarahan)
  - Tingyu Wang (https://github.com/tingyu66)
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4583
  • Loading branch information
alexbarghi-nv authored Aug 28, 2024
1 parent f7e1cea commit 4940e6d
Show file tree
Hide file tree
Showing 13 changed files with 922 additions and 279 deletions.
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies:
- numpy>=1.23,<2.0a0
- numpydoc
- nvcc_linux-64=11.8
- ogb
- openmpi
- packaging>=21
- pandas
Expand Down Expand Up @@ -74,6 +75,7 @@ dependencies:
- sphinxcontrib-websupport
- thriftpy2!=0.5.0,!=0.5.1
- torchdata
- torchmetrics
- ucx-proc=*=gpu
- ucx-py==0.40.*,>=0.0.0a0
- wget
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies:
- numba>=0.57
- numpy>=1.23,<2.0a0
- numpydoc
- ogb
- openmpi
- packaging>=21
- pandas
Expand Down Expand Up @@ -79,6 +80,7 @@ dependencies:
- sphinxcontrib-websupport
- thriftpy2!=0.5.0,!=0.5.1
- torchdata
- torchmetrics
- ucx-proc=*=gpu
- ucx-py==0.40.*,>=0.0.0a0
- wget
Expand Down
2 changes: 2 additions & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,8 @@ dependencies:
- &pytorch_unsuffixed pytorch>=2.0,<2.2.0a0
- torchdata
- pydantic
- ogb
- torchmetrics

specific:
- output_types: [requirements]
Expand Down
4 changes: 4 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def __init__(
self.__graph = graph
self.__device = device

    @property
    def _batch_size(self):
        # Read-only internal accessor for the configured batch size.
        # NOTE(review): reads the name-mangled private attribute
        # __batch_size — presumably assigned in __init__ alongside
        # __graph/__device above; confirm against the full class.
        return self.__batch_size

@property
def dataset(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def sample(

if g.is_homogeneous:
indices = torch.concat(list(indices))
ds.sample_from_nodes(indices, batch_size=batch_size)
ds.sample_from_nodes(indices.long(), batch_size=batch_size)
return HomogeneousSampleReader(
ds.get_reader(), self.output_format, self.edge_dir
)
Expand Down
20 changes: 14 additions & 6 deletions python/cugraph-dgl/cugraph_dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
HeteroNodeDataView,
HeteroEdgeView,
HeteroEdgeDataView,
EmbeddingView,
)


Expand Down Expand Up @@ -567,8 +568,8 @@ def _has_n_emb(self, ntype: str, emb_name: str) -> bool:
return (ntype, emb_name) in self.__ndata_storage

def _get_n_emb(
self, ntype: str, emb_name: str, u: Union[str, TensorType]
) -> "torch.Tensor":
self, ntype: Union[str, None], emb_name: str, u: Union[str, TensorType]
) -> Union["torch.Tensor", "EmbeddingView"]:
"""
Gets the embedding of a single node type.
Unlike DGL, this function takes the string node
Expand All @@ -583,11 +584,11 @@ def _get_n_emb(
u: Union[str, TensorType]
Nodes to get the representation of, or ALL
to get the representation of all nodes of
the given type.
the given type (returns embedding view).
Returns
-------
torch.Tensor
Union[torch.Tensor, cugraph_dgl.view.EmbeddingView]
The embedding of the given edge type with the given embedding name.
"""

Expand All @@ -598,9 +599,14 @@ def _get_n_emb(
raise ValueError("Must provide the node type for a heterogeneous graph")

if dgl.base.is_all(u):
u = torch.arange(self.num_nodes(ntype), dtype=self.idtype, device="cpu")
return EmbeddingView(
self.__ndata_storage[ntype, emb_name], self.num_nodes(ntype)
)

try:
print(
u,
)
return self.__ndata_storage[ntype, emb_name].fetch(
_cast_to_torch_tensor(u), "cuda"
)
Expand Down Expand Up @@ -644,7 +650,9 @@ def _get_e_emb(
etype = self.to_canonical_etype(etype)

if dgl.base.is_all(u):
u = torch.arange(self.num_edges(etype), dtype=self.idtype, device="cpu")
return EmbeddingView(
self.__edata_storage[etype, emb_name], self.num_edges(etype)
)

try:
return self.__edata_storage[etype, emb_name].fetch(
Expand Down
36 changes: 36 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# limitations under the License.


import warnings

from collections import defaultdict
from collections.abc import MutableMapping
from typing import Union, Dict, List, Tuple
Expand All @@ -20,11 +22,45 @@

import cugraph_dgl
from cugraph_dgl.typing import TensorType
from cugraph_dgl.utils.cugraph_conversion_utils import _cast_to_torch_tensor

torch = import_optional("torch")
dgl = import_optional("dgl")


class EmbeddingView:
    """Lazy, read-only view over a feature storage.

    Provides tensor-style indexing and a ``shape`` property without
    materializing the whole embedding table, fetching only the rows
    that are actually requested.
    """

    def __init__(self, storage: "dgl.storages.base.FeatureStorage", ld: int):
        # ld is the leading dimension (row count) reported by ``shape``;
        # the storage itself is only probed lazily.
        self.__ld = ld
        self.__storage = storage

    def __getitem__(self, u: TensorType) -> "torch.Tensor":
        """Fetch the rows selected by ``u`` onto the GPU."""
        indices = _cast_to_torch_tensor(u)
        try:
            return self.__storage.fetch(indices, "cuda")
        except RuntimeError as ex:
            # The storage may insist on device-resident indices; warn and
            # retry once with the index tensor moved to the GPU.
            warnings.warn(
                "Got error accessing data, trying again with index on device: "
                + str(ex)
            )
            return self.__storage.fetch(indices.cuda(), "cuda")

    @property
    def shape(self) -> "torch.Size":
        """Logical shape of the viewed embedding table.

        Probes a single row to discover the trailing feature dimensions,
        then substitutes the stored leading dimension for the row count.
        """
        try:
            probe = self.__storage.fetch(torch.tensor([0]), "cpu")
        except RuntimeError:
            # Storage rejected host-side access; probe on the GPU instead.
            probe = self.__storage.fetch(torch.tensor([0], device="cuda"), "cuda")
        dims = list(probe.shape)
        dims[0] = self.__ld
        return torch.Size(tuple(dims))


class HeteroEdgeDataView(MutableMapping):
"""
Duck-typed version of DGL's HeteroEdgeDataView.
Expand Down
Loading

0 comments on commit 4940e6d

Please sign in to comment.