Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type Hints] nn.GNNExplainer #5716

Merged
merged 11 commits into from
Oct 19, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
42 changes: 41 additions & 1 deletion test/nn/models/test_gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn import Linear

from torch_geometric.nn import GATConv, GCNConv, GNNExplainer, global_add_pool
from torch_geometric.testing import withPackage
from torch_geometric.testing import is_full_test, withPackage


class GCN(torch.nn.Module):
Expand Down Expand Up @@ -94,6 +94,21 @@ def test_gnn_explainer_explain_node(model, return_type, allow_edge_mask,
assert edge_mask[:8].tolist() == [1.] * 8
assert edge_mask[8:].tolist() == [0.] * 6

if is_full_test():
jit = torch.jit.export(explainer)

node_feat_mask, edge_mask = jit.explain_node(2, x, edge_index)

if feat_mask_type == 'individual_feature':
assert node_feat_mask.size() == x.size()
elif feat_mask_type == 'scalar':
assert node_feat_mask.size() == (x.size(0), )
else:
assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.min() >= 0 and edge_mask.max() <= 1


@withPackage('matplotlib')
@pytest.mark.parametrize('allow_edge_mask', [True, False])
Expand Down Expand Up @@ -132,6 +147,21 @@ def test_gnn_explainer_explain_graph(model, return_type, allow_edge_mask,
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.max() <= 1 and edge_mask.min() >= 0

if is_full_test():
jit = torch.jit.export(explainer)

node_feat_mask, edge_mask = jit.explain_graph(x, edge_index,
edge_attr=edge_attr)
if feat_mask_type == 'individual_feature':
assert node_feat_mask.size() == x.size()
elif feat_mask_type == 'scalar':
assert node_feat_mask.size() == (x.size(0), )
else:
assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.max() <= 1 and edge_mask.min() >= 0


@pytest.mark.parametrize('return_type', return_types)
@pytest.mark.parametrize('model', [GAT()])
Expand All @@ -148,3 +178,13 @@ def test_gnn_explainer_with_existing_self_loops(model, return_type):
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.max() <= 1 and edge_mask.min() >= 0

if is_full_test():
jit = torch.jit.export(explainer)

node_feat_mask, edge_mask = jit.explain_node(2, x, edge_index)

assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.max() <= 1 and edge_mask.min() >= 0
51 changes: 41 additions & 10 deletions torch_geometric/nn/models/gnn_explainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from math import sqrt
from typing import Optional
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module
from tqdm import tqdm

from torch_geometric.nn.models.explainer import (
Expand Down Expand Up @@ -66,17 +68,30 @@ class GNNExplainer(Explainer):
'node_feat_ent': 0.1,
}

def __init__(self, model, epochs: int = 100, lr: float = 0.01,
num_hops: Optional[int] = None, return_type: str = 'log_prob',
feat_mask_type: str = 'feature', allow_edge_mask: bool = True,
log: bool = True, **kwargs):
def __init__(
self,
model: Module,
epochs: int = 100,
lr: float = 0.01,
num_hops: Optional[int] = None,
return_type: str = 'log_prob',
feat_mask_type: str = 'feature',
allow_edge_mask: bool = True,
log: bool = True,
**kwargs,
):
super().__init__(model, lr, epochs, num_hops, return_type, log)
assert feat_mask_type in ['feature', 'individual_feature', 'scalar']
self.allow_edge_mask = allow_edge_mask
self.feat_mask_type = feat_mask_type
self.coeffs.update(kwargs)

def _initialize_masks(self, x, edge_index, init="normal"):
def _initialize_masks(
self,
x: Tensor,
edge_index: Tensor,
init: str = "normal",
):
(N, F), E = x.size(), edge_index.size(1)
std = 0.1

Expand All @@ -97,7 +112,12 @@ def _clear_masks(self):
self.node_feat_masks = None
self.edge_mask = None

def _loss(self, log_logits, prediction, node_idx: Optional[int] = None):
def _loss(
self,
log_logits: Tensor,
prediction: Tensor,
node_idx: Optional[Tensor] = None,
) -> Tensor:
if self.return_type == 'regression':
if node_idx is not None and node_idx >= 0:
loss = torch.cdist(log_logits[node_idx], prediction[node_idx])
Expand All @@ -124,7 +144,12 @@ def _loss(self, log_logits, prediction, node_idx: Optional[int] = None):

return loss

def explain_graph(self, x, edge_index, **kwargs):
def explain_graph(
self,
x: Tensor,
edge_index: Tensor,
**kwargs,
) -> Tuple[Tensor, Tensor]:
r"""Learns and returns a node feature mask and an edge mask that play a
crucial role to explain the prediction made by the GNN for a graph.

Expand Down Expand Up @@ -183,7 +208,13 @@ def explain_graph(self, x, edge_index, **kwargs):
self._clear_masks()
return node_feat_mask, edge_mask

def explain_node(self, node_idx, x, edge_index, **kwargs):
def explain_node(
self,
node_idx: Tensor,
x: Tensor,
edge_index: Tensor,
**kwargs,
) -> Tuple[Tensor, Tensor]:
r"""Learns and returns a node feature mask and an edge mask that play a
crucial role to explain the prediction made by the GNN for node
:attr:`node_idx`.
Expand Down Expand Up @@ -261,5 +292,5 @@ def explain_node(self, node_idx, x, edge_index, **kwargs):

return node_feat_mask, edge_mask

def __repr__(self):
def __repr__(self) -> str:
return f'{self.__class__.__name__}()'