Skip to content

Commit

Permalink
[Type Hints] nn.GNNExplainer (#5716)
Browse files Browse the repository at this point in the history
@rusty1s Please note that I am using `torch.jit.export()` instead of
`torch.jit.script()` for the tests. Can you please confirm if this is
ok?

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
sebastian-montero and rusty1s authored Oct 19, 2022
1 parent 099430c commit 396f183
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
42 changes: 41 additions & 1 deletion test/nn/models/test_gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn import Linear

from torch_geometric.nn import GATConv, GCNConv, GNNExplainer, global_add_pool
from torch_geometric.testing import withPackage
from torch_geometric.testing import is_full_test, withPackage


class GCN(torch.nn.Module):
Expand Down Expand Up @@ -94,6 +94,21 @@ def test_gnn_explainer_explain_node(model, return_type, allow_edge_mask,
assert edge_mask[:8].tolist() == [1.] * 8
assert edge_mask[8:].tolist() == [0.] * 6

if is_full_test():
jit = torch.jit.export(explainer)

node_feat_mask, edge_mask = jit.explain_node(2, x, edge_index)

if feat_mask_type == 'individual_feature':
assert node_feat_mask.size() == x.size()
elif feat_mask_type == 'scalar':
assert node_feat_mask.size() == (x.size(0), )
else:
assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.min() >= 0 and edge_mask.max() <= 1


@withPackage('matplotlib')
@pytest.mark.parametrize('allow_edge_mask', [True, False])
Expand Down Expand Up @@ -132,6 +147,21 @@ def test_gnn_explainer_explain_graph(model, return_type, allow_edge_mask,
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.max() <= 1 and edge_mask.min() >= 0

if is_full_test():
jit = torch.jit.export(explainer)

node_feat_mask, edge_mask = jit.explain_graph(x, edge_index,
edge_attr=edge_attr)
if feat_mask_type == 'individual_feature':
assert node_feat_mask.size() == x.size()
elif feat_mask_type == 'scalar':
assert node_feat_mask.size() == (x.size(0), )
else:
assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.max() <= 1 and edge_mask.min() >= 0


@pytest.mark.parametrize('return_type', return_types)
@pytest.mark.parametrize('model', [GAT()])
Expand All @@ -148,3 +178,13 @@ def test_gnn_explainer_with_existing_self_loops(model, return_type):
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.max() <= 1 and edge_mask.min() >= 0

if is_full_test():
jit = torch.jit.export(explainer)

node_feat_mask, edge_mask = jit.explain_node(2, x, edge_index)

assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
assert edge_mask.max() <= 1 and edge_mask.min() >= 0
51 changes: 41 additions & 10 deletions torch_geometric/nn/models/gnn_explainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from math import sqrt
from typing import Optional
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module
from tqdm import tqdm

from torch_geometric.nn.models.explainer import (
Expand Down Expand Up @@ -66,17 +68,30 @@ class GNNExplainer(Explainer):
'node_feat_ent': 0.1,
}

def __init__(self, model, epochs: int = 100, lr: float = 0.01,
num_hops: Optional[int] = None, return_type: str = 'log_prob',
feat_mask_type: str = 'feature', allow_edge_mask: bool = True,
log: bool = True, **kwargs):
def __init__(
self,
model: Module,
epochs: int = 100,
lr: float = 0.01,
num_hops: Optional[int] = None,
return_type: str = 'log_prob',
feat_mask_type: str = 'feature',
allow_edge_mask: bool = True,
log: bool = True,
**kwargs,
):
super().__init__(model, lr, epochs, num_hops, return_type, log)
assert feat_mask_type in ['feature', 'individual_feature', 'scalar']
self.allow_edge_mask = allow_edge_mask
self.feat_mask_type = feat_mask_type
self.coeffs.update(kwargs)

def _initialize_masks(self, x, edge_index, init="normal"):
def _initialize_masks(
self,
x: Tensor,
edge_index: Tensor,
init: str = "normal",
):
(N, F), E = x.size(), edge_index.size(1)
std = 0.1

Expand All @@ -97,7 +112,12 @@ def _clear_masks(self):
self.node_feat_masks = None
self.edge_mask = None

def _loss(self, log_logits, prediction, node_idx: Optional[int] = None):
def _loss(
self,
log_logits: Tensor,
prediction: Tensor,
node_idx: Optional[Tensor] = None,
) -> Tensor:
if self.return_type == 'regression':
if node_idx is not None and node_idx >= 0:
loss = torch.cdist(log_logits[node_idx], prediction[node_idx])
Expand All @@ -124,7 +144,12 @@ def _loss(self, log_logits, prediction, node_idx: Optional[int] = None):

return loss

def explain_graph(self, x, edge_index, **kwargs):
def explain_graph(
self,
x: Tensor,
edge_index: Tensor,
**kwargs,
) -> Tuple[Tensor, Tensor]:
r"""Learns and returns a node feature mask and an edge mask that play a
crucial role to explain the prediction made by the GNN for a graph.
Expand Down Expand Up @@ -183,7 +208,13 @@ def explain_graph(self, x, edge_index, **kwargs):
self._clear_masks()
return node_feat_mask, edge_mask

def explain_node(self, node_idx, x, edge_index, **kwargs):
def explain_node(
self,
node_idx: Tensor,
x: Tensor,
edge_index: Tensor,
**kwargs,
) -> Tuple[Tensor, Tensor]:
r"""Learns and returns a node feature mask and an edge mask that play a
crucial role to explain the prediction made by the GNN for node
:attr:`node_idx`.
Expand Down Expand Up @@ -261,5 +292,5 @@ def explain_node(self, node_idx, x, edge_index, **kwargs):

return node_feat_mask, edge_mask

def __repr__(self):
def __repr__(self) -> str:
return f'{self.__class__.__name__}()'

0 comments on commit 396f183

Please sign in to comment.