diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f326f05a77d..e2fbd283aae0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -95,7 +95,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685)) - Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613)) - Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609)) -- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736)) +- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736)) - Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522)) - Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517)) - Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512)) diff --git a/pyproject.toml b/pyproject.toml index b9e7cf51b419..61fc4d7a1ef8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,4 +95,5 @@ exclude_lines = [ "register_parameter", "warn", "torch.cuda.is_available", + "WITH_PT2", ] diff --git a/test/utils/test_convert.py b/test/utils/test_convert.py index 8357613b2493..8a27a058c357 100644 --- a/test/utils/test_convert.py +++ b/test/utils/test_convert.py @@ -1,14 +1,17 @@ +import pytest import scipy.sparse import torch from torch_geometric.data import Data from torch_geometric.testing import withPackage from torch_geometric.utils import ( + from_cugraph, from_networkit, from_networkx, from_scipy_sparse_matrix, from_trimesh, subgraph, + to_cugraph, to_networkit, to_networkx, to_scipy_sparse_matrix, @@ -281,7 +284,7 @@ def test_from_networkx_subgraph_convert(): @withPackage('networkit') -def test_to_networkit(): +def test_to_networkit_vice_versa(): edge_index = torch.tensor([[0, 1], [1, 0]]) g = to_networkit(edge_index, directed=False) @@ -293,8 +296,67 @@ def test_to_networkit(): assert edge_weight is None +@withPackage('networkit') +@pytest.mark.parametrize('directed', [True, False]) +@pytest.mark.parametrize('num_nodes', [None, 3]) +@pytest.mark.parametrize('edge_weight', [None, torch.rand(3)]) +def test_to_networkit(directed, edge_weight, num_nodes): + import networkit + + edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]], dtype=torch.long) + g = to_networkit(edge_index, edge_weight, num_nodes, directed) + + assert isinstance(g, networkit.Graph) + assert g.isDirected() == directed + assert g.numberOfNodes() == 3 + + if edge_weight is None: + edge_weight = torch.tensor([1., 1., 1.]) + + assert g.weight(0, 1) == float(edge_weight[0]) + assert g.weight(1, 2) == float(edge_weight[2]) + + if directed: + assert g.numberOfEdges() == 3 + assert g.weight(1, 0) == float(edge_weight[1]) + else: + assert g.numberOfEdges() == 2 + + +@pytest.mark.parametrize('directed', [True, False]) +@pytest.mark.parametrize('weighted', [True, False]) +@withPackage('networkit') +def test_from_networkit(directed, weighted): + import networkit + + g = networkit.Graph(3, weighted=weighted, directed=directed) + g.addEdge(0, 1) + g.addEdge(1, 2) + if directed: + g.addEdge(1, 0) + + if weighted: + for i, (u, v) in enumerate(g.iterEdges()): + g.setWeight(u, v, i + 1) + + edge_index, edge_weight = from_networkit(g) + + if directed: + assert edge_index.tolist() == [[0, 1, 1], [1, 2, 0]] + if weighted: + assert edge_weight.tolist() == [1, 2, 3] + else: + assert edge_weight is None + else: + assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + if weighted: + assert edge_weight.tolist() == [1, 1, 2, 2] + else: + assert edge_weight is None + + @withPackage('trimesh') -def test_trimesh(): +def test_trimesh_vice_versa(): pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], dtype=torch.float) face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t() @@ -305,3 +367,108 @@ def test_trimesh(): assert pos.tolist() == data.pos.tolist() assert face.tolist() == data.face.tolist() + + +@withPackage('trimesh') +def test_to_trimesh(): + import trimesh + + pos = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]) + face = torch.tensor([[0, 1, 2], [2, 1, 3]]).t() + data = Data(pos=pos, face=face) + + obj = to_trimesh(data) + + assert isinstance(obj, trimesh.Trimesh) + assert obj.vertices.shape == (4, 3) + assert obj.faces.shape == (2, 3) + assert obj.vertices.tolist() == data.pos.tolist() + assert obj.faces.tolist() == data.face.t().contiguous().tolist() + + +@withPackage('trimesh') +def test_from_trimesh(): + import trimesh + + vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0]] + faces = [[0, 1, 2]] + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + + data = from_trimesh(mesh) + + assert data.pos.tolist() == vertices + assert data.face.t().contiguous().tolist() == faces + + +@withPackage('cudf') +@withPackage('cugraph') +@pytest.mark.parametrize('edge_weight', [None, torch.rand(4)]) +@pytest.mark.parametrize('relabel_nodes', [True, False]) +@pytest.mark.parametrize('directed', [True, False]) +def test_to_cugraph(edge_weight, directed, relabel_nodes): + import cugraph + + if directed: + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + else: + edge_index = torch.tensor([[0, 1], [1, 2]]) + + if edge_weight is not None: + edge_weight[:edge_index.size(1)] + + graph = to_cugraph(edge_index, edge_weight, relabel_nodes, directed) + assert isinstance(graph, cugraph.Graph) + assert graph.number_of_nodes() == 3 + + edge_list = graph.view_edge_list() + assert edge_list is not None + + edge_list = edge_list.sort_values(by=['src', 'dst']) + + cu_edge_index = edge_list[['src', 'dst']].to_pandas().values + assert edge_index.tolist() == cu_edge_index.T.tolist() + + if edge_weight is not None: + cu_edge_weight = edge_list['weights'].to_pandas().values + assert edge_weight.tolist() == cu_edge_weight.tolist() + + +@withPackage('cudf') +@withPackage('cugraph') +@pytest.mark.parametrize('edge_weight', [None, torch.randn(4)]) +@pytest.mark.parametrize('directed', [True, False]) +@pytest.mark.parametrize('relabel_nodes', [True, False]) +def test_from_cugraph(edge_weight, directed, relabel_nodes): + import cudf + import cugraph + from torch.utils.dlpack import to_dlpack + + if directed: + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + else: + edge_index = torch.tensor([[0, 1], [1, 2]]) + + if edge_weight is not None: + edge_weight[:edge_index.size(1)] + + G = cugraph.Graph(directed=directed) + df = cudf.from_dlpack(to_dlpack(edge_index.t())) + if edge_weight is not None: + df['2'] = cudf.from_dlpack(to_dlpack(edge_weight)) + + G.from_cudf_edgelist( + df, + source=0, + destination=1, + edge_attr='2' if edge_weight is not None else None, + renumber=relabel_nodes, + ) + + cu_edge_index, cu_edge_weight = from_cugraph(G) + + assert cu_edge_index.tolist() == edge_index.tolist() + + if edge_weight is None: + assert cu_edge_weight is None + else: + assert cu_edge_weight.tolist() == edge_weight.tolist() diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 0e8a38f83c98..a2dea0154800 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -35,7 +35,7 @@ from .convert import to_networkx, from_networkx from .convert import to_networkit, from_networkit from .convert import to_trimesh, from_trimesh -from .convert import to_cugraph +from .convert import to_cugraph, from_cugraph from .smiles import from_smiles, to_smiles from .random import (erdos_renyi_graph, stochastic_blockmodel_graph, barabasi_albert_graph) @@ -106,6 +106,7 @@ 'to_trimesh', 'from_trimesh', 'to_cugraph', + 'from_cugraph', 'from_smiles', 'to_smiles', 'erdos_renyi_graph', diff --git a/torch_geometric/utils/convert.py b/torch_geometric/utils/convert.py index fd1689a3218f..c82fb908a69c 100644 --- a/torch_geometric/utils/convert.py +++ b/torch_geometric/utils/convert.py @@ -403,35 +403,49 @@ def from_trimesh(mesh): def to_cugraph(edge_index: Tensor, edge_weight: Optional[Tensor] = None, - relabel_nodes: bool = True): + relabel_nodes: bool = True, directed: bool = True): r"""Converts a graph given by :obj:`edge_index` and optional :obj:`edge_weight` into a :obj:`cugraph` graph object. Args: + edge_index (torch.Tensor): The edge indices of the graph. + edge_weight (torch.Tensor, optional): The edge weights of the graph. + (default: :obj:`None`) relabel_nodes (bool, optional): If set to :obj:`True`, :obj:`cugraph` will remove any isolated nodes, leading to a relabeling of nodes. (default: :obj:`True`) + directed (bool, optional): If set to :obj:`False`, the graph will be + undirected. (default: :obj:`True`) """ import cudf import cugraph + g = cugraph.Graph(directed=directed) df = cudf.from_dlpack(to_dlpack(edge_index.t())) if edge_weight is not None: assert edge_weight.dim() == 1 - df[2] = cudf.from_dlpack(to_dlpack(edge_weight)) + df['2'] = cudf.from_dlpack(to_dlpack(edge_weight)) + + g.from_cudf_edgelist( + df, + source=0, + destination=1, + edge_attr='2' if edge_weight is not None else None, + renumber=relabel_nodes, + ) - return cugraph.from_cudf_edgelist( - df, source=0, destination=1, - edge_attr=2 if edge_weight is not None else None, - renumber=relabel_nodes) + return g -def from_cugraph(G) -> Tuple[Tensor, Optional[Tensor]]: +def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]: r"""Converts a :obj:`cugraph` graph object into :obj:`edge_index` and optional :obj:`edge_weight` tensors. + + Args: + g (cugraph.Graph): A :obj:`cugraph` graph object. """ - df = G.edgelist.edgelist_df + df = g.view_edge_list() src = from_dlpack(df['src'].to_dlpack()).long() dst = from_dlpack(df['dst'].to_dlpack()).long()