Skip to content

Commit

Permalink
[Code Coverage] utils/convert.py (#6653)
Browse files Browse the repository at this point in the history
Adding tests for from/to  networkit, trimesh, cugraph
Updates for cu_graph functions in convert.py
adding from_cugraph to __init__

---------

Co-authored-by: zpehlivan <zpehlivan@ina.fr>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
4 people authored Feb 19, 2023
1 parent 6ee08da commit 209268a
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685))
- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613))
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,5 @@ exclude_lines = [
"register_parameter",
"warn",
"torch.cuda.is_available",
"WITH_PT2",
]
171 changes: 169 additions & 2 deletions test/utils/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import pytest
import scipy.sparse
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withPackage
from torch_geometric.utils import (
from_cugraph,
from_networkit,
from_networkx,
from_scipy_sparse_matrix,
from_trimesh,
subgraph,
to_cugraph,
to_networkit,
to_networkx,
to_scipy_sparse_matrix,
Expand Down Expand Up @@ -281,7 +284,7 @@ def test_from_networkx_subgraph_convert():


@withPackage('networkit')
def test_to_networkit():
def test_to_networkit_vice_versa():
edge_index = torch.tensor([[0, 1], [1, 0]])

g = to_networkit(edge_index, directed=False)
Expand All @@ -293,8 +296,67 @@ def test_to_networkit():
assert edge_weight is None


@withPackage('networkit')
@pytest.mark.parametrize('directed', [True, False])
@pytest.mark.parametrize('num_nodes', [None, 3])
@pytest.mark.parametrize('edge_weight', [None, torch.rand(3)])
def test_to_networkit(directed, edge_weight, num_nodes):
import networkit

edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]], dtype=torch.long)
g = to_networkit(edge_index, edge_weight, num_nodes, directed)

assert isinstance(g, networkit.Graph)
assert g.isDirected() == directed
assert g.numberOfNodes() == 3

if edge_weight is None:
edge_weight = torch.tensor([1., 1., 1.])

assert g.weight(0, 1) == float(edge_weight[0])
assert g.weight(1, 2) == float(edge_weight[2])

if directed:
assert g.numberOfEdges() == 3
assert g.weight(1, 0) == float(edge_weight[1])
else:
assert g.numberOfEdges() == 2


@pytest.mark.parametrize('directed', [True, False])
@pytest.mark.parametrize('weighted', [True, False])
@withPackage('networkit')
def test_from_networkit(directed, weighted):
import networkit

g = networkit.Graph(3, weighted=weighted, directed=directed)
g.addEdge(0, 1)
g.addEdge(1, 2)
if directed:
g.addEdge(1, 0)

if weighted:
for i, (u, v) in enumerate(g.iterEdges()):
g.setWeight(u, v, i + 1)

edge_index, edge_weight = from_networkit(g)

if directed:
assert edge_index.tolist() == [[0, 1, 1], [1, 2, 0]]
if weighted:
assert edge_weight.tolist() == [1, 2, 3]
else:
assert edge_weight is None
else:
assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
if weighted:
assert edge_weight.tolist() == [1, 1, 2, 2]
else:
assert edge_weight is None


@withPackage('trimesh')
def test_trimesh():
def test_trimesh_vice_versa():
pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],
dtype=torch.float)
face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()
Expand All @@ -305,3 +367,108 @@ def test_trimesh():

assert pos.tolist() == data.pos.tolist()
assert face.tolist() == data.face.tolist()


@withPackage('trimesh')
def test_to_trimesh():
import trimesh

pos = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]])
face = torch.tensor([[0, 1, 2], [2, 1, 3]]).t()
data = Data(pos=pos, face=face)

obj = to_trimesh(data)

assert isinstance(obj, trimesh.Trimesh)
assert obj.vertices.shape == (4, 3)
assert obj.faces.shape == (2, 3)
assert obj.vertices.tolist() == data.pos.tolist()
assert obj.faces.tolist() == data.face.t().contiguous().tolist()


@withPackage('trimesh')
def test_from_trimesh():
import trimesh

vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0]]
faces = [[0, 1, 2]]
mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)

data = from_trimesh(mesh)

assert data.pos.tolist() == vertices
assert data.face.t().contiguous().tolist() == faces


@withPackage('cudf')
@withPackage('cugraph')
@pytest.mark.parametrize('edge_weight', [None, torch.rand(4)])
@pytest.mark.parametrize('relabel_nodes', [True, False])
@pytest.mark.parametrize('directed', [True, False])
def test_to_cugraph(edge_weight, directed, relabel_nodes):
import cugraph

if directed:
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
else:
edge_index = torch.tensor([[0, 1], [1, 2]])

if edge_weight is not None:
edge_weight[:edge_index.size(1)]

graph = to_cugraph(edge_index, edge_weight, relabel_nodes, directed)
assert isinstance(graph, cugraph.Graph)
assert graph.number_of_nodes() == 3

edge_list = graph.view_edge_list()
assert edge_list is not None

edge_list = edge_list.sort_values(by=['src', 'dst'])

cu_edge_index = edge_list[['src', 'dst']].to_pandas().values
assert edge_index.tolist() == cu_edge_index.T.tolist()

if edge_weight is not None:
cu_edge_weight = edge_list['weights'].to_pandas().values
assert edge_weight.tolist() == cu_edge_weight.tolist()


@withPackage('cudf')
@withPackage('cugraph')
@pytest.mark.parametrize('edge_weight', [None, torch.randn(4)])
@pytest.mark.parametrize('directed', [True, False])
@pytest.mark.parametrize('relabel_nodes', [True, False])
def test_from_cugraph(edge_weight, directed, relabel_nodes):
import cudf
import cugraph
from torch.utils.dlpack import to_dlpack

if directed:
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
else:
edge_index = torch.tensor([[0, 1], [1, 2]])

if edge_weight is not None:
edge_weight[:edge_index.size(1)]

G = cugraph.Graph(directed=directed)
df = cudf.from_dlpack(to_dlpack(edge_index.t()))
if edge_weight is not None:
df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))

G.from_cudf_edgelist(
df,
source=0,
destination=1,
edge_attr='2' if edge_weight is not None else None,
renumber=relabel_nodes,
)

cu_edge_index, cu_edge_weight = from_cugraph(G)

assert cu_edge_index.tolist() == edge_index.tolist()

if edge_weight is None:
assert cu_edge_weight is None
else:
assert cu_edge_weight.tolist() == edge_weight.tolist()
3 changes: 2 additions & 1 deletion torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .convert import to_networkx, from_networkx
from .convert import to_networkit, from_networkit
from .convert import to_trimesh, from_trimesh
from .convert import to_cugraph
from .convert import to_cugraph, from_cugraph
from .smiles import from_smiles, to_smiles
from .random import (erdos_renyi_graph, stochastic_blockmodel_graph,
barabasi_albert_graph)
Expand Down Expand Up @@ -106,6 +106,7 @@
'to_trimesh',
'from_trimesh',
'to_cugraph',
'from_cugraph',
'from_smiles',
'to_smiles',
'erdos_renyi_graph',
Expand Down
30 changes: 22 additions & 8 deletions torch_geometric/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,35 +403,49 @@ def from_trimesh(mesh):


def to_cugraph(edge_index: Tensor, edge_weight: Optional[Tensor] = None,
relabel_nodes: bool = True):
relabel_nodes: bool = True, directed: bool = True):
r"""Converts a graph given by :obj:`edge_index` and optional
:obj:`edge_weight` into a :obj:`cugraph` graph object.
Args:
edge_index (torch.Tensor): The edge indices of the graph.
edge_weight (torch.Tensor, optional): The edge weights of the graph.
(default: :obj:`None`)
relabel_nodes (bool, optional): If set to :obj:`True`,
:obj:`cugraph` will remove any isolated nodes, leading to a
relabeling of nodes. (default: :obj:`True`)
directed (bool, optional): If set to :obj:`False`, the graph will be
undirected. (default: :obj:`True`)
"""
import cudf
import cugraph

g = cugraph.Graph(directed=directed)
df = cudf.from_dlpack(to_dlpack(edge_index.t()))

if edge_weight is not None:
assert edge_weight.dim() == 1
df[2] = cudf.from_dlpack(to_dlpack(edge_weight))
df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))

g.from_cudf_edgelist(
df,
source=0,
destination=1,
edge_attr='2' if edge_weight is not None else None,
renumber=relabel_nodes,
)

return cugraph.from_cudf_edgelist(
df, source=0, destination=1,
edge_attr=2 if edge_weight is not None else None,
renumber=relabel_nodes)
return g


def from_cugraph(G) -> Tuple[Tensor, Optional[Tensor]]:
def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
r"""Converts a :obj:`cugraph` graph object into :obj:`edge_index` and
optional :obj:`edge_weight` tensors.
Args:
g (cugraph.Graph): A :obj:`cugraph` graph object.
"""
df = G.edgelist.edgelist_df
df = g.view_edge_list()

src = from_dlpack(df['src'].to_dlpack()).long()
dst = from_dlpack(df['dst'].to_dlpack()).long()
Expand Down

0 comments on commit 209268a

Please sign in to comment.