Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds an option to pass node_labels to visualize_graph #8816

Merged
merged 12 commits into from
Jan 29, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for custom node labels in `visualize_graph()` ([#8816](https://github.com/pyg-team/pytorch_geometric/pull/8816))
- Added support for graph partitioning for temporal data in `torch_geometric.distributed` ([#8718](https://github.com/pyg-team/pytorch_geometric/pull/8718), [#8815](https://github.com/pyg-team/pytorch_geometric/pull/8815))
- Added `TreeGraph` and `GridMotif` generators ([#8736](https://github.com/pyg-team/pytorch_geometric/pull/8736))
- Added an example for edge-level temporal sampling on a heterogenous graph ([#8383](https://github.com/pyg-team/pytorch_geometric/pull/8383))
Expand Down
15 changes: 15 additions & 0 deletions test/visualization/test_graph_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@ def test_visualize_graph_via_graphviz(tmp_path, backend):
assert osp.exists(path)


@onlyGraphviz
@pytest.mark.parametrize('backend', [None, 'graphviz'])
def test_visualize_graph_via_graphviz_with_node_labels(tmp_path, backend):
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 3, 3, 4],
[1, 0, 2, 1, 3, 2, 4, 3],
])
edge_weight = (torch.rand(edge_index.size(1)) > 0.5).float()
node_labels = ['A', 'B', 'C', 'D', 'E']

path = osp.join(tmp_path, 'graph.pdf')
visualize_graph(edge_index, edge_weight, path, backend, node_labels)
assert osp.exists(path)


@withPackage('networkx', 'matplotlib')
@pytest.mark.parametrize('backend', [None, 'networkx'])
def test_visualize_graph_via_networkx(tmp_path, backend):
Expand Down
12 changes: 9 additions & 3 deletions torch_geometric/explain/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,12 @@ def visualize_feature_importance(

return _visualize_score(score, feat_labels, path, top_k)

def visualize_graph(self, path: Optional[str] = None,
backend: Optional[str] = None):
def visualize_graph(
self,
path: Optional[str] = None,
backend: Optional[str] = None,
node_labels: Optional[List[str]] = None,
) -> None:
r"""Visualizes the explanation graph with edge opacity corresponding to
edge importance.

Expand All @@ -246,13 +250,15 @@ def visualize_graph(self, path: Optional[str] = None,
If set to :obj:`None`, will use the most appropriate
visualization backend based on available system packages.
(default: :obj:`None`)
node_labels (list[str], optional): The labels/IDs of nodes.
(default: :obj:`None`)
"""
edge_mask = self.get('edge_mask')
if edge_mask is None:
raise ValueError(f"The attribute 'edge_mask' is not available "
f"in '{self.__class__.__name__}' "
f"(got {self.available_explanations})")
visualize_graph(self.edge_index, edge_mask, path, backend)
visualize_graph(self.edge_index, edge_mask, path, backend, node_labels)


class HeteroExplanation(HeteroData, ExplanationMixin):
Expand Down
23 changes: 18 additions & 5 deletions torch_geometric/visualization/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from math import sqrt
from typing import Any, Optional
from typing import Any, List, Optional

import torch
from torch import Tensor
Expand All @@ -26,6 +26,7 @@ def visualize_graph(
edge_weight: Optional[Tensor] = None,
path: Optional[str] = None,
backend: Optional[str] = None,
node_labels: Optional[List[str]] = None,
) -> Any:
r"""Visualizes the graph given via :obj:`edge_index` and (optional)
:obj:`edge_weight`.
Expand All @@ -41,6 +42,8 @@ def visualize_graph(
If set to :obj:`None`, will use the most appropriate
visualization backend based on available system packages.
(default: :obj:`None`)
node_labels (List[str], optional): The labels/IDs of nodes.
(default: :obj:`None`)
"""
if edge_weight is not None: # Normalize edge weights.
edge_weight = edge_weight - edge_weight.min()
Expand All @@ -58,9 +61,11 @@ def visualize_graph(
backend = 'graphviz' if has_graphviz() else 'networkx'

if backend.lower() == 'networkx':
return _visualize_graph_via_networkx(edge_index, edge_weight, path)
return _visualize_graph_via_networkx(edge_index, edge_weight, path,
node_labels)
elif backend.lower() == 'graphviz':
return _visualize_graph_via_graphviz(edge_index, edge_weight, path)
return _visualize_graph_via_graphviz(edge_index, edge_weight, path,
node_labels)

raise ValueError(f"Expected graph drawing backend to be in "
f"{BACKENDS} (got '{backend}')")
Expand All @@ -70,6 +75,7 @@ def _visualize_graph_via_graphviz(
edge_index: Tensor,
edge_weight: Tensor,
path: Optional[str] = None,
node_labels: Optional[List[str]] = None,
) -> Any:
import graphviz

Expand All @@ -78,11 +84,14 @@ def _visualize_graph_via_graphviz(
g.attr('node', shape='circle', fontsize='11pt')

for node in edge_index.view(-1).unique().tolist():
g.node(str(node))
g.node(str(node) if node_labels is None else node_labels[node])

for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):
hex_color = hex(255 - round(255 * w))[2:]
hex_color = f'{hex_color}0' if len(hex_color) == 1 else hex_color
if node_labels is not None:
src = node_labels[src]
dst = node_labels[dst]
g.edge(str(src), str(dst), color=f'#{hex_color}{hex_color}{hex_color}')

if path is not None:
Expand All @@ -98,6 +107,7 @@ def _visualize_graph_via_networkx(
edge_index: Tensor,
edge_weight: Tensor,
path: Optional[str] = None,
node_labels: Optional[List[str]] = None,
) -> Any:
import matplotlib.pyplot as plt
import networkx as nx
Expand All @@ -106,9 +116,12 @@ def _visualize_graph_via_networkx(
node_size = 800

for node in edge_index.view(-1).unique().tolist():
g.add_node(node)
g.add_node(node if node_labels is None else node_labels[node])

for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):
if node_labels is not None:
src = node_labels[src]
dst = node_labels[dst]
g.add_edge(src, dst, alpha=w)

ax = plt.gca()
Expand Down
Loading