From 5afd0752d1823b6b7ec63b94242df834e48fb68b Mon Sep 17 00:00:00 2001 From: Guy Aglionby Date: Sun, 19 Nov 2023 08:35:31 +0000 Subject: [PATCH] Fix `GraphMaskExplainer` for deep GNNs (#8401) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The current GraphMask explainer gave me an error for a model with > 2 layers. If you take the GCN Node Classification task in the example file, and modify the GNN https://github.com/pyg-team/pytorch_geometric/blob/cf24b4bcb4e825537ba08d8fc5f31073e2cd84c7/examples/explain/graphmask_explainer.py#L19-L29 to ```python class GCN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_features, 16) self.conv2 = GCNConv(16, 16) self.conv3 = GCNConv(16, dataset.num_classes) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index).relu() x = F.dropout(x, training=self.training) x = self.conv3(x, edge_index) return F.log_softmax(x, dim=1) ``` And run the example, I get the following output: ``` Train explainer for node(s) tensor([5]) with layer 2: 100%|████████████████████████████████| 1/1 [00:01<00:00, 1.14s/it] Train explainer for node(s) tensor([5]) with layer 1: 100%|████████████████████████████████| 1/1 [00:01<00:00, 1.15s/it] Train explainer for node(s) tensor([5]) with layer 0: 100%|████████████████████████████████| 1/1 [00:01<00:00, 1.34s/it] Explain: 67%|███████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 12.54it/s]Traceback (most recent call last): File "/local/scratch/ga384/pyg/ex.py", line 101, in explanation = explainer(data.x, data.edge_index, index=node_index) File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/explainer.py", line 204, in __call__ explanation = self.algorithm( File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/algorithm/graphmask_explainer.py", line 133, in forward edge_mask = self._explain(model, index=index) File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/algorithm/graphmask_explainer.py", line 526, in _explain sampling_weights = F.pad( RuntimeError: Padding length too large ``` I don't think the padding is necessary here: it looks like `edge_weight` is just accumulating the results per layer, given this line: https://github.com/pyg-team/pytorch_geometric/blob/cf24b4bcb4e825537ba08d8fc5f31073e2cd84c7/torch_geometric/explain/algorithm/graphmask_explainer.py#L541-L542 This PR therefore removes the padding. With that change, my GCN works. --------- Co-authored-by: rusty1s --- CHANGELOG.md | 1 + torch_geometric/explain/algorithm/graphmask_explainer.py | 6 ------ 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ecc97adb613..8ac45e0b0de5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `GraphMaskExplainer` for GNNs with more than two layers ([#8401](https://github.com/pyg-team/pytorch_geometric/pull/8401)) - Breaking Change: Properly initialize modules in `GATConv` depending on whether the input is bipartite or non-bipartite ([#8397](https://github.com/pyg-team/pytorch_geometric/pull/8397)) - Fixed `input_id` computation in `NeighborLoader` in case a `mask` is given ([#8312](https://github.com/pyg-team/pytorch_geometric/pull/8312)) - Respect current device when deep-copying `Linear` layers ([#8311](https://github.com/pyg-team/pytorch_geometric/pull/8311)) diff --git a/torch_geometric/explain/algorithm/graphmask_explainer.py b/torch_geometric/explain/algorithm/graphmask_explainer.py index 912026ce6c56..089fe7659a6e 100644 --- a/torch_geometric/explain/algorithm/graphmask_explainer.py +++ b/torch_geometric/explain/algorithm/graphmask_explainer.py @@ -526,12 +526,6 @@ def _explain( if i == 0: edge_weight = sampling_weights else: - if edge_weight.size(-1) != sampling_weights.size(-1): - sampling_weights = F.pad( - input=sampling_weights, - pad=(0, edge_weight.size(-1) - - sampling_weights.size(-1), 0, 0), - mode='constant', value=0) edge_weight = torch.cat((edge_weight, sampling_weights), 0) if self.log: pbar.update(1)