-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of the PGM explainer and examples for node and graph explanations. For now, I have left out the PGM generation part from the original repo (https://github.com/vunhatminh/PGMExplainer/), as there were issues getting it to work with the example dataset (one of its steps seems to cause a memory explosion in pandas). I have also combined the node perturbation functions from the original repo into one, as there is so much overlap between them. This sort of works, but I would appreciate feedback/sanity checks that I have not misunderstood anything about the explainer framework. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinu Sunil <jinu.sunil@gmail.com>
- Loading branch information
1 parent
6ef6330
commit ecf4020
Showing
7 changed files
with
709 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
""" | ||
This is an example of using the PGM explainer algorithm | ||
on a graph classification task | ||
""" | ||
import os.path as osp | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
import torch_geometric.transforms as T | ||
from torch_geometric.contrib.explain import PGMExplainer | ||
from torch_geometric.datasets import MNISTSuperpixels | ||
from torch_geometric.explain import Explainer | ||
from torch_geometric.loader import DataLoader | ||
from torch_geometric.nn import ( | ||
NNConv, | ||
global_mean_pool, | ||
graclus, | ||
max_pool, | ||
max_pool_x, | ||
) | ||
from torch_geometric.utils import normalized_cut | ||
|
||
# The dataset is stored under ../data/MNIST relative to this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
# The same transform is reused by ``Net.forward`` after pooling.
transform = T.Cartesian(cat=False)
train_dataset = MNISTSuperpixels(path, train=True, transform=transform)
test_dataset = MNISTSuperpixels(path, train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
# One graph per batch, so each explanation covers a single sample.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# Short alias that ``Net`` reads for the feature and class counts.
d = train_dataset
|
||
|
||
def normalized_cut_2d(edge_index, pos):
    """Return normalized-cut weights based on Euclidean edge lengths.

    Every edge is assigned the L2 distance between the positions of its two
    endpoints; those lengths are then handed to ``normalized_cut``.
    """
    src, dst = edge_index
    edge_length = (pos[src] - pos[dst]).norm(p=2, dim=1)
    return normalized_cut(edge_index, edge_length, num_nodes=pos.size(0))
|
||
|
||
class Net(nn.Module):
    """Two-layer NNConv graph classifier with graclus-based pooling.

    The input width and the number of classes are read from the
    module-level dataset alias ``d``.
    """
    def __init__(self):
        super().__init__()
        in_channels = d.num_features

        # Edge network for conv1: turns a size-2 edge attribute into the
        # flattened ``in_channels * 32`` weights used by NNConv.
        edge_net1 = nn.Sequential(
            nn.Linear(2, 25),
            nn.ReLU(),
            nn.Linear(25, in_channels * 32),
        )
        self.conv1 = NNConv(in_channels, 32, edge_net1, aggr='mean')

        # Edge network for conv2 (32 -> 64 channels).
        edge_net2 = nn.Sequential(
            nn.Linear(2, 25),
            nn.ReLU(),
            nn.Linear(25, 32 * 64),
        )
        self.conv2 = NNConv(32, 64, edge_net2, aggr='mean')

        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, d.num_classes)

    def forward(self, x, edge_index, **kwargs):
        """Return per-graph class log-probabilities.

        The batch object has to be passed as the ``data`` keyword argument;
        its ``edge_attr``, ``pos`` and ``batch`` fields are read here. All
        modifications are made on a detached clone of that object.
        """
        graph = kwargs.get('data').detach().clone()

        # First convolution on the full-resolution graph.
        x = F.elu(self.conv1(x, edge_index, graph.edge_attr))
        cut_weight = normalized_cut_2d(edge_index, graph.pos)
        cluster = graclus(edge_index, cut_weight, x.size(0))

        # Store the convolved features and the edge_index that was passed
        # in on the clone before pooling; edge_attr is cleared first.
        graph.edge_attr = None
        graph.x = x
        graph.edge_index = edge_index
        graph = max_pool(cluster, graph, transform=transform)

        # Second convolution on the coarsened graph.
        graph.x = F.elu(
            self.conv2(graph.x, graph.edge_index, graph.edge_attr))
        cut_weight = normalized_cut_2d(graph.edge_index, graph.pos)
        cluster = graclus(graph.edge_index, cut_weight, graph.x.size(0))
        pooled, batch = max_pool_x(cluster, graph.x, graph.batch)

        # Graph-level readout followed by the classification head.
        hidden = global_mean_pool(pooled, batch)
        hidden = F.elu(self.fc1(hidden))
        hidden = F.dropout(hidden, training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)
|
||
|
||
def train(model, dataloader):
    """Run one training epoch of ``model`` over ``dataloader``.

    Relies on the module-level ``device`` and ``optimizer`` that the
    ``__main__`` section creates before calling this function.
    """
    model.train()

    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        # ``Net.forward`` has the signature ``(x, edge_index, **kwargs)`` and
        # reads the batch object from the ``data`` keyword. Passing the batch
        # positionally would bind it to ``edge_index`` and leave ``data``
        # unset.
        out = model(data.x, data.edge_index, data=data)
        F.nll_loss(out, data.y).backward()
        optimizer.step()
|
||
|
||
if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'current device: {device}')
    model = Net().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Two epochs are used for this demonstration.
    for epoch in range(2):
        train(model, train_loader)

    explainer = Explainer(
        model=model,
        # ``feature_index`` is the keyword used by the PGMExplainer unit test
        # added in the same change for the list of features to perturb.
        algorithm=PGMExplainer(feature_index=[0], perturbation_mode="mean"),
        explanation_type='phenomenon',
        node_mask_type="object",
        model_config=dict(mode="multiclass_classification",
                          task_level="graph", return_type="raw"),
    )

    # Explain the first three test graphs only.
    for num_explained, explain_dataset in enumerate(test_loader, start=1):
        explain_dataset = explain_dataset.to(device)
        explanation = explainer(x=explain_dataset.x,
                                edge_index=explain_dataset.edge_index,
                                target=explain_dataset.y,
                                edge_attr=explain_dataset.edge_attr,
                                data=explain_dataset)
        for k in explanation.available_explanations:
            print(explanation[k])
        if num_explained > 2:
            break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
""" | ||
This is an example of using the PGM explainer algorithm | ||
on a node classification task | ||
""" | ||
import os.path as osp | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
import torch_geometric.transforms as T | ||
from torch_geometric.contrib.explain import PGMExplainer | ||
from torch_geometric.datasets import Planetoid | ||
from torch_geometric.explain import Explainer, ModelConfig | ||
from torch_geometric.nn import GCNConv | ||
|
||
# Cora is stored under ../data/Planetoid relative to this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
transform = T.Compose([
    T.GCNNorm(),
    T.NormalizeFeatures(),
])
dataset = Planetoid(path, 'Cora', transform=transform)
# The dataset's first (index 0) graph is the one trained on and explained.
data = dataset[0]
|
||
|
||
class Net(torch.nn.Module):
    """Two-layer GCN that returns per-node class log-probabilities.

    Both layers are built with ``normalize=False`` and receive the
    ``edge_weight`` that is passed to ``forward``. Layer sizes come from
    the module-level ``dataset``.
    """
    def __init__(self):
        super().__init__()
        hidden_channels = 16
        self.conv1 = GCNConv(dataset.num_features, hidden_channels,
                             normalize=False)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes,
                             normalize=False)

    def forward(self, x, edge_index, edge_weight):
        hidden = self.conv1(x, edge_index, edge_weight)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(F.relu(hidden), training=self.training)
        logits = self.conv2(hidden, edge_index, edge_weight)
        return F.log_softmax(logits, dim=1)
|
||
|
||
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net().to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                                 weight_decay=5e-4)
    x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

    # Full-batch training; the loss is computed on the training nodes only.
    model.train()
    for epoch in range(1, 500):
        optimizer.zero_grad()
        log_logits = model(x, edge_index, edge_weight)
        loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

    # The model's own predictions are what gets explained below. Only the
    # argmax is needed, so no autograd graph has to be recorded.
    model.eval()
    with torch.no_grad():
        log_logits = model(x, edge_index, edge_weight)
    predicted_target = log_logits.argmax(dim=1)

    explainer = Explainer(
        model=model, algorithm=PGMExplainer(), node_mask_type='attributes',
        explanation_type='phenomenon',
        model_config=ModelConfig(mode='multiclass_classification',
                                 task_level='node', return_type='raw'))
    node_idx = 100
    explanation = explainer(x=data.x, edge_index=edge_index, index=node_idx,
                            target=predicted_target, edge_weight=edge_weight)
    print('significance of relevant neighbours using pgm explainer :',
          explanation.pgm_stats)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.contrib.explain import PGMExplainer | ||
from torch_geometric.explain import Explainer | ||
from torch_geometric.explain.config import ModelConfig | ||
from torch_geometric.nn import GCNConv, global_add_pool | ||
from torch_geometric.testing import withPackage | ||
|
||
|
||
class GCN(torch.nn.Module):
    """Small two-layer GCN whose output head follows ``model_config``.

    Multiclass models emit 7 channels, every other mode emits 1. When the
    task level is ``'graph'`` the node outputs are sum-pooled per graph.
    """
    def __init__(self, model_config: ModelConfig):
        super().__init__()
        self.model_config = model_config

        is_multiclass = model_config.mode.value == 'multiclass_classification'
        out_channels = 7 if is_multiclass else 1

        self.conv1 = GCNConv(3, 16)
        self.conv2 = GCNConv(16, out_channels)

    def forward(self, x, edge_index, edge_weight=None, batch=None, **kwargs):
        hidden = self.conv1(x, edge_index, edge_weight).relu()
        out = self.conv2(hidden, edge_index, edge_weight).relu()

        if self.model_config.task_level.value == 'graph':
            out = global_add_pool(out, batch)

        # Apply the output activation that matches the configured mode;
        # any other mode returns the values unchanged.
        mode = self.model_config.mode.value
        if mode == 'binary_classification':
            return out.sigmoid()
        if mode == 'multiclass_classification':
            return out.log_softmax(dim=-1)
        return out
|
||
|
||
# Shared fixtures: 8 nodes with 3 random features each.
x = torch.randn(8, 3)
# Path graph 0-1-2-...-7 with every edge stored in both directions, as the
# column sequence (0, 1), (1, 0), (1, 2), (2, 1), ..., (6, 7), (7, 6).
edge_index = torch.tensor(
    [pair for i in range(7) for pair in ((i, i + 1), (i + 1, i))],
).t().contiguous()
target = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
edge_label_index = torch.tensor([
    [0, 1, 2],
    [3, 4, 5],
])
|
||
|
||
@withPackage('pgmpy')
@withPackage('pandas')
@pytest.mark.parametrize('node_idx', [2, 6])
@pytest.mark.parametrize('task_level, perturbation_mode', [
    ('node', 'randint'),
    ('graph', 'mean'),
    ('graph', 'max'),
    ('graph', 'min'),
    ('graph', 'zero'),
])
def test_pgm_explainer_classification(node_idx, task_level, perturbation_mode):
    """Check that PGMExplainer returns a node mask in [0, 1] and PGM stats."""
    model_config = ModelConfig(mode='multiclass_classification',
                               task_level=task_level, return_type='raw')
    model = GCN(model_config)

    # The model's own predictions are used as the explanation target.
    predicted = model(x, edge_index).argmax(dim=1)

    algorithm = PGMExplainer(feature_index=[0],
                             perturbation_mode=perturbation_mode)
    explainer = Explainer(model=model, algorithm=algorithm,
                          explanation_type='phenomenon',
                          node_mask_type='object', model_config=model_config)

    explanation = explainer(x=x, edge_index=edge_index, index=node_idx,
                            target=predicted)

    for key in ('node_mask', 'pgm_stats'):
        assert key in explanation

    node_mask = explanation.node_mask
    assert node_mask.size(0) == explanation.num_nodes
    assert node_mask.min() >= 0
    assert node_mask.max() <= 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from .graphmask_explainer import GraphMaskExplainer | ||
from .pgm_explainer import PGMExplainer | ||
|
||
# Names exported by this package; ``classes`` is bound to the same list.
__all__ = classes = ['GraphMaskExplainer', 'PGMExplainer']
Oops, something went wrong.