diff --git a/torch_geometric/utils/convert.py b/torch_geometric/utils/convert.py index a80791a4b575..3186ddb1dbc5 100644 --- a/torch_geometric/utils/convert.py +++ b/torch_geometric/utils/convert.py @@ -48,7 +48,7 @@ def from_scipy_sparse_matrix(A): return edge_index, edge_weight -def to_networkx(data, node_attrs=None, edge_attrs=None, +def to_networkx(data, node_attrs=None, edge_attrs=None, graph_attrs=None, to_undirected: Union[bool, str] = False, remove_self_loops: bool = False): r"""Converts a :class:`torch_geometric.data.Data` instance to a @@ -61,6 +61,8 @@ def to_networkx(data, node_attrs=None, edge_attrs=None, copied. (default: :obj:`None`) edge_attrs (iterable of str, optional): The edge attributes to be copied. (default: :obj:`None`) + graph_attrs (iterable of str, optional): The graph attributes to be + copied. (default: :obj:`None`) to_undirected (bool or str, optional): If set to :obj:`True` or "upper", will return a :obj:`networkx.Graph` instead of a :obj:`networkx.DiGraph`. The undirected graph will correspond to @@ -80,10 +82,12 @@ def to_networkx(data, node_attrs=None, edge_attrs=None, G.add_nodes_from(range(data.num_nodes)) - node_attrs, edge_attrs = node_attrs or [], edge_attrs or [] + node_attrs = node_attrs or [] + edge_attrs = edge_attrs or [] + graph_attrs = graph_attrs or [] values = {} - for key, value in data(*(node_attrs + edge_attrs)): + for key, value in data(*(node_attrs + edge_attrs + graph_attrs)): if torch.is_tensor(value): value = value if value.dim() <= 1 else value.squeeze(-1) values[key] = value.tolist() @@ -113,6 +117,9 @@ def to_networkx(data, node_attrs=None, edge_attrs=None, for i, feat_dict in G.nodes(data=True): feat_dict.update({key: values[key][i]}) + for key in graph_attrs: + G.graph[key] = values[key] + return G @@ -173,6 +180,10 @@ def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None, key = f'edge_{key}' if key in node_attrs else key data[str(key)].append(value) + for key, value in G.graph.items(): + key = f'graph_{key}' if key in node_attrs else key + data[str(key)] = value + for key, value in data.items(): try: data[key] = torch.tensor(value)