Skip to content

Commit

Permalink
Fix to_csc in cugraph models (#6972)
Browse files Browse the repository at this point in the history
This PR:
- fixes the bug in `to_csc()` when `size` is not given, since PyG uses
the same subgraph over all layers.
- absorbs `num_src_nodes` into the `csc` tuple to make `get_cugraph` more
concise
- misc: no longer allows an identity node feature tensor (by setting
`x=None`) in RGCN. It was an oversight on my part not to remove it in the
previous PR, since it is memory-inefficient

CC: @MatthiasKohl, @puririshi98

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
tingyu66 and rusty1s authored Mar 20, 2023
1 parent d2bbcfd commit 9f9fd65
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 41 deletions.
66 changes: 37 additions & 29 deletions torch_geometric/nn/conv/cugraph/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Any, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -37,7 +38,8 @@ def to_csc(
edge_index: Tensor,
size: Optional[Tuple[int, int]] = None,
edge_attr: Optional[Tensor] = None,
) -> Union[Tuple[Tensor, Tensor], Tuple[Tuple[Tensor, Tensor], Tensor]]:
) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tuple[Tensor, Tensor, int],
Tensor]]:
r"""Returns a CSC representation of an :obj:`edge_index` tensor to be
used as input to a :class:`CuGraphModule`.
Expand All @@ -48,40 +50,46 @@ def to_csc(
edge_attr (torch.Tensor, optional): The edge features.
(default: :obj:`None`)
"""
if size is None:
warnings.warn(f"Inferring the graph size from 'edge_index' causes "
f"a decline in performance and does not work for "
f"bipartite graphs. To suppress this warning, pass "
f"the 'size' explicitly in '{__name__}.to_csc()'.")
num_src_nodes = num_dst_nodes = int(edge_index.max()) + 1
else:
num_src_nodes, num_dst_nodes = size

row, col = edge_index
num_target_nodes = size[1] if size is not None else int(col.max() + 1)
col, perm = index_sort(col, max_value=num_target_nodes)
col, perm = index_sort(col, max_value=num_dst_nodes)
row = row[perm]

colptr = index2ptr(col, num_target_nodes)
colptr = index2ptr(col, num_dst_nodes)

if edge_attr is not None:
return (row, colptr), edge_attr[perm]
return (row, colptr, num_src_nodes), edge_attr[perm]

return row, colptr
return row, colptr, num_src_nodes

def get_cugraph(
self,
num_src_nodes: int,
csc: Tuple[Tensor, Tensor],
csc: Tuple[Tensor, Tensor, int],
max_num_neighbors: Optional[int] = None,
) -> Any:
r"""Constructs a :obj:`cugraph` graph object from CSC representation.
Supports both bipartite and non-bipartite graphs.
Args:
num_src_nodes (int): The number of source nodes.
csc ((torch.Tensor, torch.Tensor)): A tuple containing the CSC
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr)`. Use the :meth:`CuGraphModule.to_csc`
method to convert an :obj:`edge_index` representation to the
desired format.
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`CuGraphModule.to_csc` method to convert an
:obj:`edge_index` representation to the desired format.
max_num_neighbors (int, optional): The maximum number of neighbors
of a target node. It is only effective when operating in a
bipartite graph. When not given, will be computed on-the-fly,
leading to slightly worse performance. (default: :obj:`None`)
"""
row, colptr = csc
row, colptr, num_src_nodes = csc

if not row.is_cuda:
raise RuntimeError(f"'{self.__class__.__name__}' requires GPU-"
Expand All @@ -100,8 +108,7 @@ def get_cugraph(

def get_typed_cugraph(
self,
num_src_nodes: int,
csc: Tuple[Tensor, Tensor],
csc: Tuple[Tensor, Tensor, int],
edge_type: Tensor,
num_edge_types: Optional[int] = None,
max_num_neighbors: Optional[int] = None,
Expand All @@ -111,11 +118,11 @@ def get_typed_cugraph(
Supports both bipartite and non-bipartite graphs.
Args:
num_src_nodes (int): The number of source nodes.
csc ((torch.Tensor, torch.Tensor)): A tuple containing the CSC
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr)`. Use the :meth:`to_csc` method to convert
an :obj:`edge_index` representation to the desired format.
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`CuGraphModule.to_csc` method to convert an
:obj:`edge_index` representation to the desired format.
edge_type (torch.Tensor): The edge type.
num_edge_types (int, optional): The maximum number of edge types.
When not given, will be computed on-the-fly, leading to
Expand All @@ -128,7 +135,7 @@ def get_typed_cugraph(
if num_edge_types is None:
num_edge_types = int(edge_type.max()) + 1

row, colptr = csc
row, colptr, num_src_nodes = csc
edge_type = edge_type.int()

if num_src_nodes != colptr.numel() - 1: # Bipartite graph:
Expand All @@ -142,25 +149,26 @@ def get_typed_cugraph(
n_edge_types=num_edge_types,
out_node_types=None, in_node_types=None,
edge_types=edge_type)
else:
return make_fg_csr_hg(colptr, row, n_node_types=0,
n_edge_types=num_edge_types, node_types=None,
edge_types=edge_type)

return make_fg_csr_hg(colptr, row, n_node_types=0,
n_edge_types=num_edge_types, node_types=None,
edge_types=edge_type)

def forward(
self,
x: Tensor,
csc: Tuple[Tensor, Tensor],
csc: Tuple[Tensor, Tensor, int],
max_num_neighbors: Optional[int] = None,
) -> Tensor:
r"""Runs the forward pass of the module.
Args:
x (torch.Tensor): The node features.
csc ((torch.Tensor, torch.Tensor)): A tuple containing the CSC
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr)`. Use the :meth:`to_csc` method to convert
an :obj:`edge_index` representation to the desired format.
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`CuGraphModule.to_csc` method to convert an
:obj:`edge_index` representation to the desired format.
max_num_neighbors (int, optional): The maximum number of neighbors
of a target node. It is only effective when operating in a
bipartite graph. When not given, the value will be computed
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/conv/cugraph/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def reset_parameters(self):
def forward(
self,
x: Tensor,
csc: Tuple[Tensor, Tensor],
csc: Tuple[Tensor, Tensor, int],
max_num_neighbors: Optional[int] = None,
) -> Tensor:
graph = self.get_cugraph(x.size(0), csc, max_num_neighbors)
graph = self.get_cugraph(csc, max_num_neighbors)

x = self.lin(x)
out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
Expand Down
12 changes: 4 additions & 8 deletions torch_geometric/nn/conv/cugraph/rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from torch_geometric.nn.conv.cugraph import CuGraphModule
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import OptTensor

try:
from pylibcugraphops.torch.autograd import \
Expand Down Expand Up @@ -71,8 +70,8 @@ def reset_parameters(self):

def forward(
self,
x: OptTensor,
csc: Tuple[Tensor, Tensor],
x: Tensor,
csc: Tuple[Tensor, Tensor, int],
edge_type: Tensor,
max_num_neighbors: Optional[int] = None,
) -> Tensor:
Expand All @@ -91,11 +90,8 @@ def forward(
on-the-fly, leading to slightly worse performance.
(default: :obj:`None`)
"""
if x is None:
x = torch.eye(self.in_channels, device=edge_type.device)

graph = self.get_typed_cugraph(x.size(0), csc, edge_type,
self.num_relations, max_num_neighbors)
graph = self.get_typed_cugraph(csc, edge_type, self.num_relations,
max_num_neighbors)

out = RGCNConvAgg(x, self.comp, graph, concat_own=self.root_weight,
norm_by_out_degree=bool(self.aggr == 'mean'))
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/conv/cugraph/sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def reset_parameters(self):
def forward(
self,
x: Tensor,
csc: Tuple[Tensor, Tensor],
csc: Tuple[Tensor, Tensor, int],
max_num_neighbors: Optional[int] = None,
) -> Tensor:
graph = self.get_cugraph(x.size(0), csc, max_num_neighbors)
graph = self.get_cugraph(csc, max_num_neighbors)

if self.project:
x = self.pre_lin(x).relu()
Expand Down

0 comments on commit 9f9fd65

Please sign in to comment.