Skip to content

Commit

Permalink
PGMExplainer (#6149)
Browse files Browse the repository at this point in the history
Implementation of the PGM explainer and examples for node and graph
explanations. For now, I have left out the PGM generation part from the
original repo (https://github.com/vunhatminh/PGMExplainer/) as there
were issues with getting it to work with the example dataset (one of
their steps seems to cause a memory explosion in pandas).

I have also combined the node perturbation functions from the original
repo into one, as there is so much overlap between them.

This sort of works, but I would appreciate feedback/sanity checks that I
have not misunderstood anything about the explainer framework.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinu Sunil <jinu.sunil@gmail.com>
  • Loading branch information
3 people authored Jan 24, 2023
1 parent 6ef6330 commit ecf4020
Show file tree
Hide file tree
Showing 7 changed files with 709 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `PGMExplainer` to `torch_geometric.contrib` ([#6149](https://github.com/pyg-team/pytorch_geometric/pull/6149))
- Added a `NumNeighbors` helper class for specifying the number of neighbors when sampling ([#6501](https://github.com/pyg-team/pytorch_geometric/pull/6501))
- Added caching to `is_node_attr()` and `is_edge_attr()` calls ([#6492](https://github.com/pyg-team/pytorch_geometric/pull/6492))
- Added `ToHeteroLinear` and `ToHeteroMessagePassing` modules to accelerate `to_hetero` functionality ([#5992](https://github.com/pyg-team/pytorch_geometric/pull/5992), [#6456](https://github.com/pyg-team/pytorch_geometric/pull/6456))
Expand Down
115 changes: 115 additions & 0 deletions examples/contrib/pgm_explainer_graph_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
This is an example of using the PGM explainer algorithm
on a graph classification task
"""
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.contrib.explain import PGMExplainer
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.explain import Explainer
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
NNConv,
global_mean_pool,
graclus,
max_pool,
max_pool_x,
)
from torch_geometric.utils import normalized_cut

# Dataset root: a `data/MNIST` directory one level above this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
# The same transform is applied to both splits and is re-applied inside
# `Net.forward` after pooling (it is passed to `max_pool`).
# NOTE(review): `T.Cartesian(cat=False)` presumably overwrites `edge_attr`
# with position-based edge features -- confirm against the PyG docs.
transform = T.Cartesian(cat=False)
# NOTE(review): the second positional argument is presumably the `train`
# flag (True -> training split, False -> test split) -- confirm against the
# `MNISTSuperpixels` signature.
train_dataset = MNISTSuperpixels(path, True, transform=transform)
test_dataset = MNISTSuperpixels(path, False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
# Batch size 1 so that the explanation loop receives one batch per graph.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# Short alias; `Net.__init__` reads `d.num_features` and `d.num_classes`.
d = train_dataset


def normalized_cut_2d(edge_index, pos):
    """Return the normalized cut of each edge, weighted by edge length.

    Args:
        edge_index: Pair of index tensors ``(source, target)`` for the edges.
        pos: Node positions, indexed by the entries of ``edge_index``.

    Returns:
        The result of ``normalized_cut`` computed from the Euclidean distance
        between the two endpoints of every edge.
    """
    source, target = edge_index
    # L2 distance between the endpoints of each edge.
    edge_length = (pos[source] - pos[target]).norm(p=2, dim=1)
    num_nodes = pos.size(0)
    return normalized_cut(edge_index, edge_length, num_nodes=num_nodes)


class Net(nn.Module):
    """Graph classifier: two ``NNConv`` layers, each followed by pooling.

    Layer sizes are read from the module-level dataset alias ``d`` and the
    first pooling step re-applies the module-level ``transform``.
    """
    def __init__(self):
        super().__init__()
        in_channels = d.num_features

        # MLP handed to the first NNConv: 2 inputs -> in_channels * 32 outputs.
        edge_net1 = nn.Sequential(
            nn.Linear(2, 25),
            nn.ReLU(),
            nn.Linear(25, in_channels * 32),
        )
        self.conv1 = NNConv(in_channels, 32, edge_net1, aggr='mean')

        # MLP handed to the second NNConv: 2 inputs -> 32 * 64 outputs.
        edge_net2 = nn.Sequential(
            nn.Linear(2, 25),
            nn.ReLU(),
            nn.Linear(25, 32 * 64),
        )
        self.conv2 = NNConv(32, 64, edge_net2, aggr='mean')

        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, d.num_classes)

    def forward(self, x, edge_index, **kwargs):
        """Compute per-graph log-probabilities.

        Args:
            x: Node features fed to the first convolution.
            edge_index: Edge indices fed to the first convolution.
            **kwargs: Must contain ``data``, the batch object providing
                ``edge_attr``, ``pos`` and ``batch``; other keys are ignored.

        Returns:
            ``log_softmax`` (over dim 1) of the final linear layer's output.
        """
        # All attribute writes below happen on this detached clone.
        graph = kwargs.get('data').detach().clone()

        hidden = F.elu(self.conv1(x, edge_index, graph.edge_attr))
        cut_weight = normalized_cut_2d(edge_index, graph.pos)
        clusters = graclus(edge_index, cut_weight, hidden.size(0))
        # Clear the old edge attributes before pooling; `transform` is handed
        # to `max_pool` and `edge_attr` is read again from its result.
        graph.edge_attr = None
        graph.x = hidden
        graph.edge_index = edge_index
        graph = max_pool(clusters, graph, transform=transform)

        graph.x = F.elu(
            self.conv2(graph.x, graph.edge_index, graph.edge_attr))
        cut_weight = normalized_cut_2d(graph.edge_index, graph.pos)
        clusters = graclus(graph.edge_index, cut_weight, graph.x.size(0))
        pooled, batch = max_pool_x(clusters, graph.x, graph.batch)

        out = global_mean_pool(pooled, batch)
        out = F.elu(self.fc1(out))
        out = F.dropout(out, training=self.training)
        return F.log_softmax(self.fc2(out), dim=1)


def train(model, dataloader):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Relies on the module-level ``device`` and ``optimizer`` that the
    ``__main__`` section creates before calling this function.

    Args:
        model: Module called as ``model(x, edge_index, data=batch)``.
        dataloader: Iterable yielding batches with ``x``, ``edge_index`` and
            ``y`` attributes and a ``to(device)`` method.
    """
    model.train()

    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        # `Net.forward` takes `(x, edge_index, **kwargs)` and reads the batch
        # from `kwargs['data']`, so the edge index must be passed explicitly
        # and the batch itself by keyword.
        out = model(data.x, data.edge_index, data=data)
        F.nll_loss(out, data.y).backward()
        optimizer.step()


if __name__ == "__main__":
    # Number of test graphs to explain before stopping.
    num_explained_graphs = 3

    # `device` and `optimizer` are read as globals by `train`.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'current device: {device}')
    model = Net().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(2):
        train(model, train_loader)

    explainer = Explainer(
        model=model,
        # The keyword is `feature_index`, matching the call in
        # test/contrib/explain/test_pgm_explainer.py.
        algorithm=PGMExplainer(feature_index=[0], perturbation_mode="mean"),
        explanation_type='phenomenon',
        node_mask_type="object",
        model_config=dict(
            mode="multiclass_classification",
            task_level="graph",
            return_type="raw",
        ),
    )

    for count, explain_dataset in enumerate(test_loader, start=1):
        explain_dataset = explain_dataset.to(device)
        # `edge_attr` and `data` are extra keyword arguments; `Net.forward`
        # reads the batch from `data`.
        explanation = explainer(
            x=explain_dataset.x,
            edge_index=explain_dataset.edge_index,
            target=explain_dataset.y,
            edge_attr=explain_dataset.edge_attr,
            data=explain_dataset,
        )
        for k in explanation.available_explanations:
            print(explanation[k])
        if count >= num_explained_graphs:
            break
66 changes: 66 additions & 0 deletions examples/contrib/pgm_explainer_node_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
This is an example of using the PGM explainer algorithm
on a node classification task
"""
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.contrib.explain import PGMExplainer
from torch_geometric.datasets import Planetoid
from torch_geometric.explain import Explainer, ModelConfig
from torch_geometric.nn import GCNConv

# `dataset` first holds the dataset name; it is rebound to the loaded
# dataset object two statements below.
dataset = 'Cora'
# Dataset root: a `data/Planetoid` directory one level above this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
# NOTE(review): `T.GCNNorm()` presumably precomputes the `edge_weight` that
# the model consumes (the convolutions are built with `normalize=False`) --
# confirm against the PyG transform docs.
transform = T.Compose([T.GCNNorm(), T.NormalizeFeatures()])
dataset = Planetoid(path, dataset, transform=transform)
# The first entry of the dataset is the graph used for training and
# explanation below.
data = dataset[0]


class Net(torch.nn.Module):
    """Two-layer GCN whose sizes come from the module-level ``dataset``.

    Both convolutions are created with ``normalize=False`` and receive the
    edge weights explicitly in ``forward``.
    """
    def __init__(self):
        super().__init__()
        hidden_channels = 16
        self.conv1 = GCNConv(dataset.num_features, hidden_channels,
                             normalize=False)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes,
                             normalize=False)

    def forward(self, x, edge_index, edge_weight):
        """Return per-node log-probabilities (``log_softmax`` over dim 1)."""
        hidden = self.conv1(x, edge_index, edge_weight)
        hidden = F.dropout(F.relu(hidden), training=self.training)
        logits = self.conv2(hidden, edge_index, edge_weight)
        return F.log_softmax(logits, dim=1)


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net().to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                                 weight_decay=5e-4)
    x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

    # Train on the nodes selected by `data.train_mask`.
    model.train()
    for epoch in range(1, 500):
        optimizer.zero_grad()
        log_logits = model(x, edge_index, edge_weight)
        loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

    # The explanation target is the model's own prediction for every node;
    # no gradients are needed to compute it.
    model.eval()
    with torch.no_grad():
        predicted_target = model(x, edge_index, edge_weight).argmax(dim=1)

    explainer = Explainer(
        model=model,
        algorithm=PGMExplainer(),
        node_mask_type='attributes',
        explanation_type='phenomenon',
        model_config=ModelConfig(
            mode='multiclass_classification',
            task_level='node',
            return_type='raw',
        ),
    )
    # Index of the node whose prediction is explained.
    node_idx = 100
    explanation = explainer(
        x=x,
        edge_index=edge_index,
        index=node_idx,
        target=predicted_target,
        edge_weight=edge_weight,
    )
    print('significance of relevant neighbours using pgm explainer :',
          explanation.pgm_stats)
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
'matplotlib',
'scikit-image',
'pytorch-memlab',
'pgmpy',
'opt_einsum', # required for pgmpy
'statsmodels',
]

benchmark_requires = [
Expand Down
89 changes: 89 additions & 0 deletions test/contrib/explain/test_pgm_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
import torch

from torch_geometric.contrib.explain import PGMExplainer
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ModelConfig
from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.testing import withPackage


class GCN(torch.nn.Module):
    """Small two-layer GCN whose output head follows ``model_config``.

    The output width is 7 for ``'multiclass_classification'`` and 1 for any
    other mode.
    """
    def __init__(self, model_config: ModelConfig):
        super().__init__()
        self.model_config = model_config

        is_multiclass = model_config.mode.value == 'multiclass_classification'
        out_channels = 7 if is_multiclass else 1

        self.conv1 = GCNConv(3, 16)
        self.conv2 = GCNConv(16, out_channels)

    def forward(self, x, edge_index, edge_weight=None, batch=None, **kwargs):
        """Apply both convolutions, then pool/activate per ``model_config``.

        Node embeddings are summed per graph when the task level is
        ``'graph'``. The result goes through ``sigmoid`` for binary
        classification, ``log_softmax`` for multiclass classification, and is
        returned unchanged otherwise. Extra keyword arguments are ignored.
        """
        hidden = self.conv1(x, edge_index, edge_weight).relu()
        hidden = self.conv2(hidden, edge_index, edge_weight).relu()

        config = self.model_config
        if config.task_level.value == 'graph':
            hidden = global_add_pool(hidden, batch)

        mode = config.mode.value
        if mode == 'binary_classification':
            return hidden.sigmoid()
        if mode == 'multiclass_classification':
            return hidden.log_softmax(dim=-1)
        return hidden


# Shared inputs: 8 nodes with 3 random features each (matches `GCNConv(3, 16)`
# in `GCN`).
x = torch.randn(8, 3)
# Path graph 0-1-2-...-7, with every edge stored in both directions.
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
])
# NOTE(review): the test below computes its own local `target` from the model
# output and never reads `edge_label_index`, so neither of these module-level
# values is used by the visible test.
target = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
edge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]])


@withPackage('pgmpy')
@withPackage('pandas')
@pytest.mark.parametrize('node_idx', [2, 6])
@pytest.mark.parametrize('task_level, perturbation_mode', [
    ('node', 'randint'),
    ('graph', 'mean'),
    ('graph', 'max'),
    ('graph', 'min'),
    ('graph', 'zero'),
])
def test_pgm_explainer_classification(node_idx, task_level, perturbation_mode):
    """Check that PGMExplainer returns a node mask in [0, 1] and PGM stats."""
    config = ModelConfig(
        mode='multiclass_classification',
        task_level=task_level,
        return_type='raw',
    )
    gcn = GCN(config)

    # Explain the model's own predictions on the shared module-level graph.
    predicted = gcn(x, edge_index).argmax(dim=1)

    algorithm = PGMExplainer(
        feature_index=[0],
        perturbation_mode=perturbation_mode,
    )
    explainer = Explainer(
        model=gcn,
        algorithm=algorithm,
        explanation_type='phenomenon',
        node_mask_type='object',
        model_config=config,
    )
    explanation = explainer(
        x=x,
        edge_index=edge_index,
        index=node_idx,
        target=predicted,
    )

    for key in ('node_mask', 'pgm_stats'):
        assert key in explanation

    node_mask = explanation.node_mask
    assert node_mask.size(0) == explanation.num_nodes
    assert node_mask.min() >= 0
    assert node_mask.max() <= 1
2 changes: 2 additions & 0 deletions torch_geometric/contrib/explain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .graphmask_explainer import GraphMaskExplainer
from .pgm_explainer import PGMExplainer

# Public names of this package; the same list object is also bound to
# `classes`.
__all__ = classes = [
'GraphMaskExplainer',
'PGMExplainer',
]
Loading

0 comments on commit ecf4020

Please sign in to comment.