Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gnnexplainer graph-explanation fix. #2615

Merged
merged 2 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/nn/models/test_gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def test_graph_explainer(model):

node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
assert_edgemask_clear(model)
_, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
y=torch.tensor(2), threshold=0.8)
assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.shape[0] == edge_index.shape[1]
Expand Down
10 changes: 6 additions & 4 deletions torch_geometric/nn/models/gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
edge_index (LongTensor): The edge indices.
edge_mask (Tensor): The edge mask.
y (Tensor, optional): The ground-truth node-prediction labels used
as node colorings. (default: :obj:`None`)
as node colorings. All nodes will have the same color
if :attr:`node_idx` is :obj:`-1`.(default: :obj:`None`).
threshold (float, optional): Sets a threshold for visualizing
important edges. If set to :obj:`None`, will visualize all
edges with transparancy indicating the importance of edges.
Expand All @@ -292,9 +293,10 @@ def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
if node_idx == -1:
hard_edge_mask = torch.BoolTensor([True] * edge_index.size(1),
device=edge_mask.device)
subset = torch.arange(
edge_index.max() + 1,
device=edge_index.device if y is None else y.device)
subset = torch.arange(edge_index.max().item() + 1,
device=edge_index.device)
y = None

else:
# Only operate on a k-hop subgraph around `node_idx`.
subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
Expand Down