diff --git a/examples/explain/gnn_explainer.py b/examples/explain/gnn_explainer.py index 22ffa2462778..3e28bef2fafc 100644 --- a/examples/explain/gnn_explainer.py +++ b/examples/explain/gnn_explainer.py @@ -60,5 +60,6 @@ def forward(self, x, edge_index): print(f"Feature importance plot has been saved to '{path}'") path = 'subgraph.pdf' -explanation.visualize_graph(path) +explanation.visualize_graph(path, node_label=None, color_dict=None, + target_node=node_index, draw_node_idx=True) print(f"Subgraph visualization plot has been saved to '{path}'") diff --git a/examples/explain/gnn_explainer_ba_shapes.py b/examples/explain/gnn_explainer_ba_shapes.py index d9c82a3cfae5..6913493223b7 100644 --- a/examples/explain/gnn_explainer_ba_shapes.py +++ b/examples/explain/gnn_explainer_ba_shapes.py @@ -80,6 +80,9 @@ def test(): # Explanation ROC AUC over all test nodes: targets, preds = [], [] + # The color for each node class label. For BAshape dataset, according to GNNExplainer Fig.3 + color_dict = {0: 'orange', 1: 'green', 2: 'green', 3: 'green'} + node_indices = range(400, data.num_nodes, 5) for node_index in tqdm(node_indices, leave=False, desc='Train Explainer'): target = data.y if explanation_type == 'phenomenon' else None @@ -94,3 +97,14 @@ def test(): auc = roc_auc_score(torch.cat(targets), torch.cat(preds)) print(f'Mean ROC AUC (explanation type {explanation_type:10}): {auc:.4f}') + +node_index = 500 +explanation = explainer(data.x, data.edge_index, index=node_index, + target=data.y) +explanation.visualize_graph( + f"GNNExplainer_BAshapes_{node_index}.png", + node_label=data.y, + color_dict=color_dict, + target_node=node_index, + draw_node_idx=False, +) diff --git a/torch_geometric/explain/explanation.py b/torch_geometric/explain/explanation.py index 8897a32d166a..1aaac9281fa5 100644 --- a/torch_geometric/explain/explanation.py +++ b/torch_geometric/explain/explanation.py @@ -233,7 +233,7 @@ def visualize_feature_importance( return _visualize_score(score, feat_labels, path, top_k) def visualize_graph(self, path: Optional[str] = None, - backend: Optional[str] = None): + backend: Optional[str] = None, **kwargs): r"""Visualizes the explanation graph with edge opacity corresponding to edge importance. @@ -246,13 +246,14 @@ def visualize_graph(self, path: Optional[str] = None, If set to :obj:`None`, will use the most appropriate visualization backend based on available system packages. (default: :obj:`None`) + **kwargs: include """ edge_mask = self.get('edge_mask') if edge_mask is None: raise ValueError(f"The attribute 'edge_mask' is not available " f"in '{self.__class__.__name__}' " f"(got {self.available_explanations})") - visualize_graph(self.edge_index, edge_mask, path, backend) + visualize_graph(self.edge_index, edge_mask, path, backend, **kwargs) class HeteroExplanation(HeteroData, ExplanationMixin): diff --git a/torch_geometric/visualization/graph.py b/torch_geometric/visualization/graph.py index 9225184df3ff..85063bb375e9 100644 --- a/torch_geometric/visualization/graph.py +++ b/torch_geometric/visualization/graph.py @@ -26,6 +26,7 @@ def visualize_graph( edge_weight: Optional[Tensor] = None, path: Optional[str] = None, backend: Optional[str] = None, + **kwargs, ) -> Any: r"""Visualizes the graph given via :obj:`edge_index` and (optional) :obj:`edge_weight`. @@ -58,7 +59,8 @@ def visualize_graph( backend = 'graphviz' if has_graphviz() else 'networkx' if backend.lower() == 'networkx': - return _visualize_graph_via_networkx(edge_index, edge_weight, path) + return _visualize_graph_via_networkx(edge_index, edge_weight, path, + **kwargs) elif backend.lower() == 'graphviz': return _visualize_graph_via_graphviz(edge_index, edge_weight, path) @@ -98,10 +100,15 @@ def _visualize_graph_via_networkx( edge_index: Tensor, edge_weight: Tensor, path: Optional[str] = None, + **kwargs, ) -> Any: import matplotlib.pyplot as plt import networkx as nx + node_label = kwargs['node_label'] + color_dict = kwargs['color_dict'] + target_node = kwargs['target_node'] + g = nx.DiGraph() node_size = 800 @@ -127,10 +134,26 @@ def _visualize_graph_via_networkx( ), ) + node_color = ['white'] * len(g.nodes) + if node_label != None: + assert color_dict != None + node_color = [] + for i, node_id in enumerate(list(g.nodes)): + node_color.append(color_dict[int(node_label[node_id])]) + + if target_node != None: + for i, node_id in enumerate(list(g.nodes)): + if node_id == target_node: + print("kylin") + node_color[i] = 'red' + break + nodes = nx.draw_networkx_nodes(g, pos, node_size=node_size, - node_color='white', margins=0.1) + node_color=node_color, margins=0.1) nodes.set_edgecolor('black') - nx.draw_networkx_labels(g, pos, font_size=10) + + if kwargs['draw_node_idx'] == True: + nx.draw_networkx_labels(g, pos, font_size=10) if path is not None: plt.savefig(path)