Skip to content

Commit

Permalink
Utilize cudf for relabel_nodes pathway in utils.subgraph (#7764)
Browse files Browse the repository at this point in the history
```
cd /opt/pyg; pip uninstall -y torch-geometric; rm -rf pytorch_geometric; git clone -b cudf-relabel https://github.com/pyg-team/pytorch_geometric.git; cd /opt/pyg/pytorch_geometric; pip install .; py.test -s /opt/pyg/pytorch_geometric/test/utils/test_subgraph.py -v
test/utils/test_subgraph.py::test_get_num_hops PASSED
test/utils/test_subgraph.py::test_subgraph PASSED
test/utils/test_subgraph.py::test_subgraph_large_cudf PASSED
test/utils/test_subgraph.py::test_bipartite_subgraph PASSED
test/utils/test_subgraph.py::test_k_hop_subgraph PASSED

==================================================================================== 5 passed
```

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Aug 3, 2023
1 parent 497ed91 commit 1d1583a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `disable_dynamic_shape` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246), [#7534](https://github.com/pyg-team/pytorch_geometric/pull/7534))
- Added the option to override `use_segmm` selection in `HeteroLinear` ([#7474](https://github.com/pyg-team/pytorch_geometric/pull/7474))
- Added the `MovieLens-1M` heterogeneous dataset ([#7479](https://github.com/pyg-team/pytorch_geometric/pull/7479))
- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493), [#7765](https://github.com/pyg-team/pytorch_geometric/pull/7765))
- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493), [#7764](https://github.com/pyg-team/pytorch_geometric/pull/7764) [#7765](https://github.com/pyg-team/pytorch_geometric/pull/7765))
- Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))
- Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))
- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671))
Expand Down
9 changes: 9 additions & 0 deletions test/utils/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ def test_subgraph():
assert out[1].tolist() == [7, 8, 9, 10]


@withCUDA
@withPackage('pandas')
def test_subgraph_large_index(device):
subset = torch.tensor([50_000_000], device=device)
edge_index = torch.tensor([[50_000_000], [50_000_000]], device=device)
edge_index, _ = subgraph(subset, edge_index, relabel_nodes=True)
assert edge_index.tolist() == [[0], [0]]


def test_bipartite_subgraph():
edge_index = torch.tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6],
[0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]])
Expand Down
12 changes: 8 additions & 4 deletions torch_geometric/utils/subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,20 @@ def subgraph(
else:
num_nodes = subset.size(0)
node_mask = subset
subset = node_mask.nonzero().view(-1)

edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
edge_index = edge_index[:, edge_mask]
edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

if relabel_nodes:
node_idx = torch.zeros(node_mask.size(0), dtype=torch.long,
device=device)
node_idx[subset] = torch.arange(node_mask.sum().item(), device=device)
edge_index = node_idx[edge_index]
edge_index, _ = map_index(
edge_index.view(-1),
subset,
max_index=num_nodes,
inclusive=True,
)
edge_index = edge_index.view(2, -1)

if return_edge_mask:
return edge_index, edge_attr, edge_mask
Expand Down

0 comments on commit 1d1583a

Please sign in to comment.