Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Mar 25, 2022
1 parent 3b8054d commit dc512ca
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions torch_geometric/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def from_scipy_sparse_matrix(A):
return edge_index, edge_weight


def to_networkx(data, node_attrs=None, edge_attrs=None,
def to_networkx(data, node_attrs=None, edge_attrs=None, graph_attrs=None,
to_undirected: Union[bool, str] = False,
remove_self_loops: bool = False):
r"""Converts a :class:`torch_geometric.data.Data` instance to a
Expand All @@ -61,6 +61,8 @@ def to_networkx(data, node_attrs=None, edge_attrs=None,
copied. (default: :obj:`None`)
edge_attrs (iterable of str, optional): The edge attributes to be
copied. (default: :obj:`None`)
graph_attrs (iterable of str, optional): The graph attributes to be
copied. (default: :obj:`None`)
to_undirected (bool or str, optional): If set to :obj:`True` or
"upper", will return a :obj:`networkx.Graph` instead of a
:obj:`networkx.DiGraph`. The undirected graph will correspond to
Expand All @@ -80,10 +82,12 @@ def to_networkx(data, node_attrs=None, edge_attrs=None,

G.add_nodes_from(range(data.num_nodes))

node_attrs, edge_attrs = node_attrs or [], edge_attrs or []
node_attrs = node_attrs or []
edge_attrs = edge_attrs or []
graph_attrs = graph_attrs or []

values = {}
for key, value in data(*(node_attrs + edge_attrs)):
for key, value in data(*(node_attrs + edge_attrs + graph_attrs)):
if torch.is_tensor(value):
value = value if value.dim() <= 1 else value.squeeze(-1)
values[key] = value.tolist()
Expand Down Expand Up @@ -113,6 +117,9 @@ def to_networkx(data, node_attrs=None, edge_attrs=None,
for i, feat_dict in G.nodes(data=True):
feat_dict.update({key: values[key][i]})

for key in graph_attrs:
G.graph[key] = values[key]

return G


Expand Down Expand Up @@ -173,6 +180,10 @@ def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None,
key = f'edge_{key}' if key in node_attrs else key
data[str(key)].append(value)

for key, value in G.graph.items():
key = f'graph_{key}' if key in node_attrs else key
data[str(key)] = value

for key, value in data.items():
try:
data[key] = torch.tensor(value)
Expand Down

0 comments on commit dc512ca

Please sign in to comment.