Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_dgl and from_dgl conversions #7053

Merged
merged 21 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `to_dgl` and `from_dgl` conversion functions ([#7053](https://github.com/pyg-team/pytorch_geometric/pull/7053))
- Added support for `torch.jit.script` within `MessagePassing` layers without `torch_sparse` being installed ([#7061](https://github.com/pyg-team/pytorch_geometric/pull/7061), [#7062](https://github.com/pyg-team/pytorch_geometric/pull/7062))
- Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037))
- Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026))
Expand Down
98 changes: 97 additions & 1 deletion test/utils/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
import scipy.sparse
import torch

from torch_geometric.data import Data
from torch_geometric.data import Data, HeteroData
from torch_geometric.testing import withPackage
from torch_geometric.utils import (
from_cugraph,
from_dgl,
from_networkit,
from_networkx,
from_scipy_sparse_matrix,
from_trimesh,
sort_edge_index,
subgraph,
to_cugraph,
to_dgl,
to_networkit,
to_networkx,
to_scipy_sparse_matrix,
Expand Down Expand Up @@ -481,3 +483,97 @@ def test_from_cugraph(edge_weight, directed, relabel_nodes):
assert torch.allclose(edge_weight, cu_edge_weight.cpu())
else:
assert cu_edge_weight is None


@withPackage('dgl')
def test_to_dgl_graph():
    # Build a small homogeneous graph carrying node and edge features.
    node_feat = torch.randn(5, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]])
    edge_feat = torch.randn(edge_index.size(1), 2)
    data = Data(x=node_feat, edge_index=edge_index, edge_attr=edge_feat)

    g = to_dgl(data)

    # Node features must be carried over under the same key.
    assert torch.equal(data.x, g.ndata['x'])

    # Source/destination indices must match the rows of `edge_index`.
    for expected, actual in zip((edge_index[0], edge_index[1]), g.edges()):
        assert torch.equal(actual, expected)

    # Edge features must be carried over under the same key.
    assert torch.equal(data.edge_attr, g.edata['edge_attr'])


@withPackage('dgl')
def test_to_dgl_hetero_graph():
    # Two node types connected by a single (implicitly named) relation.
    data = HeteroData()
    data['v1'].x = torch.randn(4, 3)
    data['v2'].x = torch.randn(4, 3)
    data['v1', 'v2'].edge_index = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
    data['v1', 'v2'].edge_attr = torch.randn(4, 2)

    g = to_dgl(data)

    # Edge and node counts agree between both representations.
    assert data['v1', 'v2'].num_edges == g.num_edges(('v1', 'to', 'v2'))
    for node_type in ('v1', 'v2'):
        assert data[node_type].num_nodes == g.num_nodes(node_type)

    # Node features are stored per node type.
    for node_type in ('v1', 'v2'):
        assert torch.equal(data[node_type].x, g.nodes[node_type].data['x'])

    # Connectivity and edge features of the single relation are preserved.
    src, dst = g.edges()
    edge_store = data['v1', 'v2']
    assert torch.equal(src, edge_store.edge_index[0])
    assert torch.equal(dst, edge_store.edge_index[1])
    assert torch.equal(g.edata['edge_attr'], edge_store.edge_attr)


@withPackage('dgl')
@withPackage('torch_sparse')
def test_to_dgl_sparse():
    from torch_geometric.transforms import ToSparseTensor

    # Start from an edge-index based graph, then switch to the `adj_t`
    # representation via the `ToSparseTensor` transform.
    node_feat = torch.randn(5, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]])
    edge_feat = torch.randn(edge_index.size(1), 2)
    transform = ToSparseTensor()
    data = transform(
        Data(x=node_feat, edge_index=edge_index, edge_attr=edge_feat))

    g = to_dgl(data)

    assert torch.equal(data.x, g.ndata["x"])

    # Compare the COO indices of the transposed `adj_t` with DGL's edges.
    expected_row, expected_col, _ = data.adj_t.t().coo()
    actual_row, actual_col = g.edges()
    assert torch.equal(expected_row, actual_row)
    assert torch.equal(expected_col, actual_col)

    assert torch.equal(data.edge_attr, g.edata['edge_attr'])


@withPackage('dgl')
def test_from_dgl_graph():
    import dgl

    # Homogeneous DGL graph with node features and scalar edge features.
    src_nodes, dst_nodes = [0, 0, 1, 5], [1, 2, 2, 0]
    g = dgl.graph((src_nodes, dst_nodes))
    g.ndata['x'] = torch.randn(g.num_nodes(), 3)
    g.edata['edge_attr'] = torch.randn(g.num_edges())

    data = from_dgl(g)

    # Features keep their keys and values.
    assert torch.equal(data.x, g.ndata['x'])

    # Row 0 / row 1 of `edge_index` hold the source / destination nodes.
    for idx, expected in enumerate(g.edges()):
        assert torch.equal(data.edge_index[idx], expected)

    assert torch.equal(data.edge_attr, g.edata['edge_attr'])


@withPackage('dgl')
def test_from_dgl_hetero_graph():
    import dgl

    # Bipartite heterogeneous graph: 5 'v1' nodes, 3 'v2' nodes, 7 edges.
    g = dgl.heterograph({
        ('v1', 'to', 'v2'): (
            [0, 1, 1, 2, 3, 3, 4],
            [0, 0, 1, 1, 1, 2, 2],
        )
    })
    g.nodes['v1'].data['x'] = torch.randn(5, 3)
    g.nodes['v2'].data['x'] = torch.randn(3, 3)

    data = from_dgl(g)

    # Edge and node counts agree between both representations.
    assert data['v1', 'v2'].num_edges == g.num_edges(('v1', 'to', 'v2'))
    assert data['v1'].num_nodes == g.num_nodes('v1')
    assert data['v2'].num_nodes == g.num_nodes('v2')

    # Node features are carried over per node type.
    assert torch.equal(data['v1'].x, g.nodes['v1'].data['x'])
    assert torch.equal(data['v2'].x, g.nodes['v2'].data['x'])

    # Connectivity is carried over as well (mirrors the homogeneous test).
    row, col = g.edges(etype=('v1', 'to', 'v2'))
    assert torch.equal(data['v1', 'v2'].edge_index[0], row)
    assert torch.equal(data['v1', 'v2'].edge_index[1], col)
3 changes: 3 additions & 0 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .convert import to_networkit, from_networkit
from .convert import to_trimesh, from_trimesh
from .convert import to_cugraph, from_cugraph
from .convert import to_dgl, from_dgl
from .smiles import from_smiles, to_smiles
from .random import (erdos_renyi_graph, stochastic_blockmodel_graph,
barabasi_albert_graph)
Expand Down Expand Up @@ -116,6 +117,8 @@
'from_trimesh',
'to_cugraph',
'from_cugraph',
'to_dgl',
'from_dgl',
'from_smiles',
'to_smiles',
'erdos_renyi_graph',
Expand Down
148 changes: 148 additions & 0 deletions torch_geometric/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,151 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
edge_weight = from_dlpack(df['weights'].to_dlpack())

return edge_index, edge_weight


def to_dgl(
    data: Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData']
) -> Any:
    r"""Converts a :class:`torch_geometric.data.Data` or
    :class:`torch_geometric.data.HeteroData` instance to a :obj:`dgl` graph
    object.

    Connectivity is read from :obj:`edge_index` if present and from
    :obj:`adj_t` otherwise. All remaining node-level and edge-level
    attributes are copied to the node and edge data of the :obj:`dgl` graph
    under the same keys.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
            The data object.

    Returns:
        The converted :obj:`dgl` graph (a heterograph if :obj:`data` is a
        :class:`~torch_geometric.data.HeteroData` instance).

    Raises:
        ValueError: If :obj:`data` is neither a
            :class:`~torch_geometric.data.Data` nor a
            :class:`~torch_geometric.data.HeteroData` instance.

    Example:

        >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]])
        >>> x = torch.randn(5, 3)
        >>> edge_attr = torch.randn(6, 2)
        >>> data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        >>> g = to_dgl(data)
        >>> g
        Graph(num_nodes=5, num_edges=6,
            ndata_schemes={'x': Scheme(shape=(3,))}
            edata_schemes={'edge_attr': Scheme(shape=(2, ))})

        >>> data = HeteroData()
        >>> data['paper'].x = torch.randn(5, 3)
        >>> data['author'].x = torch.ones(5, 3)
        >>> edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])
        >>> data['author', 'cites', 'paper'].edge_index = edge_index
        >>> g = to_dgl(data)
        >>> g
        Graph(num_nodes={'author': 5, 'paper': 5},
            num_edges={('author', 'cites', 'paper'): 5},
            metagraph=[('author', 'paper', 'cites')])
    """
    import dgl

    from torch_geometric.data import Data, HeteroData

    # Keys that describe connectivity rather than per-edge features.
    connectivity_keys = ('edge_index', 'adj_t')

    if isinstance(data, Data):
        if data.edge_index is not None:
            row, col = data.edge_index
        else:
            row, col, _ = data.adj_t.t().coo()

        # Pass the node count explicitly: otherwise DGL infers it from the
        # largest node index in the edges, which drops trailing isolated
        # nodes and makes the node feature assignment below fail.
        g = dgl.graph((row, col), num_nodes=data.num_nodes)

        for attr in data.node_attrs():
            g.ndata[attr] = data[attr]
        for attr in data.edge_attrs():
            if attr in connectivity_keys:
                continue
            g.edata[attr] = data[attr]

        return g

    if isinstance(data, HeteroData):
        # Map every edge type to its (source, destination) index pair.
        data_dict = {}
        for edge_type, store in data.edge_items():
            if store.get('edge_index') is not None:
                row, col = store.edge_index
            else:
                row, col, _ = store['adj_t'].t().coo()

            data_dict[edge_type] = (row, col)

        g = dgl.heterograph(data_dict)

        for node_type, store in data.node_items():
            for attr, value in store.items():
                g.nodes[node_type].data[attr] = value

        for edge_type, store in data.edge_items():
            for attr, value in store.items():
                if attr in connectivity_keys:
                    continue
                g.edges[edge_type].data[attr] = value

        return g

    raise ValueError(f"Invalid data type (got '{type(data)}')")


def from_dgl(
    g: Any,
) -> Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData']:
    r"""Converts a :obj:`dgl` graph object to a
    :class:`torch_geometric.data.Data` or
    :class:`torch_geometric.data.HeteroData` instance.

    Homogeneous graphs are converted to
    :class:`~torch_geometric.data.Data`, all other graphs to
    :class:`~torch_geometric.data.HeteroData`. Node and edge features are
    copied under the same keys they have in the :obj:`dgl` graph.

    Args:
        g (dgl.DGLGraph): The :obj:`dgl` graph object.

    Returns:
        The converted data object.

    Raises:
        ValueError: If :obj:`g` is not a :class:`dgl.DGLGraph` instance.

    Example:

        >>> g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0]))
        >>> g.ndata['x'] = torch.randn(g.num_nodes(), 3)
        >>> g.edata['edge_attr'] = torch.randn(g.num_edges(), 2)
        >>> data = from_dgl(g)
        >>> data
        Data(x=[6, 3], edge_attr=[4, 2], edge_index=[2, 4])

        >>> g = dgl.heterograph({
        ...     ('author', 'writes', 'paper'): ([0, 1, 1, 2, 3, 3, 4],
        ...                                     [0, 0, 1, 1, 1, 2, 2])})
        >>> g.nodes['author'].data['x'] = torch.randn(5, 3)
        >>> g.nodes['paper'].data['x'] = torch.randn(3, 3)
        >>> data = from_dgl(g)
        >>> data
        HeteroData(
          author={ x=[5, 3] },
          paper={ x=[3, 3] },
          (author, writes, paper)={ edge_index=[2, 7] }
        )
    """
    import dgl

    from torch_geometric.data import Data, HeteroData

    if not isinstance(g, dgl.DGLGraph):
        raise ValueError(f"Invalid data type (got '{type(g)}')")

    if g.is_homogeneous:
        data = Data()
        data.edge_index = torch.stack(g.edges(), dim=0)

        for attr, value in g.ndata.items():
            data[attr] = value
        for attr, value in g.edata.items():
            data[attr] = value

        return data

    data = HeteroData()

    for node_type in g.ntypes:
        for attr, value in g.nodes[node_type].data.items():
            data[node_type][attr] = value

    for edge_type in g.canonical_etypes:
        row, col = g.edges(form="uv", etype=edge_type)
        data[edge_type].edge_index = torch.stack([row, col], dim=0)
        # Read the actual feature tensors of this edge type.
        # `g.edge_attr_schemes(...)` only describes the features and does not
        # hold their values, so it must not be used to populate the store.
        for attr, value in g.edges[edge_type].data.items():
            data[edge_type][attr] = value

    return data