Skip to content

Commit

Permalink
[Code Coverage] explain/pg_explainer.py (#6824)
Browse files Browse the repository at this point in the history
Improve code coverage for `explain/pg_explainer.py` also clean up
`test_explain_algorithm_utils.py`

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
zechengz and rusty1s authored Mar 1, 2023
1 parent c4a19e4 commit 7ed914b
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 48 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685))
- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613))
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736), [#6763](https://github.com/pyg-team/pytorch_geometric/pull/6763), [#6781](https://github.com/pyg-team/pytorch_geometric/pull/6781), [#6797](https://github.com/pyg-team/pytorch_geometric/pull/6797), [#6799](https://github.com/pyg-team/pytorch_geometric/pull/6799))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736), [#6763](https://github.com/pyg-team/pytorch_geometric/pull/6763), [#6781](https://github.com/pyg-team/pytorch_geometric/pull/6781), [#6797](https://github.com/pyg-team/pytorch_geometric/pull/6797), [#6799](https://github.com/pyg-team/pytorch_geometric/pull/6799), [#6824](https://github.com/pyg-team/pytorch_geometric/pull/6824))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
37 changes: 10 additions & 27 deletions test/explain/algorithm/test_explain_algorithm_utils.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,12 @@
import torch

from torch_geometric.data import HeteroData
from torch_geometric.explain.algorithm.utils import (
clear_masks,
set_hetero_masks,
)
from torch_geometric.nn import GCNConv, HeteroConv, SAGEConv, to_hetero


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
return torch.stack([row, col], dim=0)


def get_hetero_data():
data = HeteroData()
data['paper'].x = torch.randn(8, 16)
data['author'].x = torch.randn(10, 8)
data['paper', 'paper'].edge_index = get_edge_index(8, 8, 10)
data['author', 'paper'].edge_index = get_edge_index(10, 8, 10)
data['paper', 'author'].edge_index = get_edge_index(8, 10, 10)
return data


class HeteroModel(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -58,8 +41,7 @@ def forward(self, x, edge_index):
return self.conv2(x, edge_index)


def test_set_clear_mask():
data = get_hetero_data()
def test_set_clear_mask(hetero_data):
edge_mask_dict = {
('paper', 'to', 'paper'): torch.ones(200),
('author', 'to', 'paper'): torch.ones(100),
Expand All @@ -68,31 +50,32 @@ def test_set_clear_mask():

model = HeteroModel()

set_hetero_masks(model, edge_mask_dict, data.edge_index_dict)
for edge_type in data.edge_types: # Check that masks are correctly set:
set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict)
for edge_type in hetero_data.edge_types:
# Check that masks are correctly set:
str_edge_type = '__'.join(edge_type)
assert torch.allclose(model.conv1.convs[str_edge_type]._edge_mask,
edge_mask_dict[edge_type])
assert model.conv1.convs[str_edge_type].explain

clear_masks(model)
for edge_type in data.edge_types:
for edge_type in hetero_data.edge_types:
str_edge_type = '__'.join(edge_type)
assert model.conv1.convs[str_edge_type]._edge_mask is None
assert not model.conv1.convs[str_edge_type].explain

model = GraphSAGE()
model = to_hetero(GraphSAGE(), data.metadata(), debug=False)
model = to_hetero(GraphSAGE(), hetero_data.metadata(), debug=False)

set_hetero_masks(model, edge_mask_dict, data.edge_index_dict)
for edge_type in data.edge_types: # Check that masks are correctly set:
set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict)
for edge_type in hetero_data.edge_types:
# Check that masks are correctly set:
str_edge_type = '__'.join(edge_type)
assert torch.allclose(model.conv1[str_edge_type]._edge_mask,
edge_mask_dict[edge_type])
assert model.conv1[str_edge_type].explain

clear_masks(model)
for edge_type in data.edge_types:
for edge_type in hetero_data.edge_types:
str_edge_type = '__'.join(edge_type)
assert model.conv1[str_edge_type]._edge_mask is None
assert not model.conv1[str_edge_type].explain
112 changes: 92 additions & 20 deletions test/explain/algorithm/test_pg_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import torch

from torch_geometric.explain import Explainer, PGExplainer
from torch_geometric.explain.config import ModelConfig, ModelTaskLevel
from torch_geometric.explain.config import (
ModelConfig,
ModelMode,
ModelTaskLevel,
)
from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.testing import withCUDA

Expand All @@ -12,8 +16,13 @@ def __init__(self, model_config: ModelConfig):
super().__init__()
self.model_config = model_config

if model_config.mode == ModelMode.multiclass_classification:
out_channels = 7
else:
out_channels = 1

self.conv1 = GCNConv(3, 16)
self.conv2 = GCNConv(16, 7)
self.conv2 = GCNConv(16, out_channels)

def forward(self, x, edge_index, batch=None, edge_label_index=None):
x = self.conv1(x, edge_index).relu()
Expand All @@ -24,19 +33,26 @@ def forward(self, x, edge_index, batch=None, edge_label_index=None):


@withCUDA
def test_pg_explainer_node(device, check_explanation):
@pytest.mark.parametrize('mode', [
ModelMode.binary_classification,
ModelMode.multiclass_classification,
ModelMode.regression,
])
def test_pg_explainer_node(device, check_explanation, mode):
x = torch.randn(8, 3, device=device)
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
], device=device)
target = torch.randint(7, (x.size(0), ), device=device)

model_config = ModelConfig(
mode='multiclass_classification',
task_level='node',
return_type='raw',
)
if mode == ModelMode.binary_classification:
target = torch.randint(2, (x.size(0), ), device=device)
elif mode == ModelMode.multiclass_classification:
target = torch.randint(7, (x.size(0), ), device=device)
elif mode == ModelMode.regression:
target = torch.randn((x.size(0), 1), device=device)

model_config = ModelConfig(mode=mode, task_level='node', return_type='raw')

model = GCN(model_config).to(device)

Expand All @@ -49,7 +65,7 @@ def test_pg_explainer_node(device, check_explanation):
)

with pytest.raises(ValueError, match="not yet fully trained"):
explanation = explainer(x, edge_index, target=target)
explainer(x, edge_index, target=target)

explainer.algorithm.reset_parameters()
for epoch in range(2):
Expand All @@ -64,39 +80,95 @@ def test_pg_explainer_node(device, check_explanation):


@withCUDA
def test_pg_explainer_graph(device, check_explanation):
@pytest.mark.parametrize('mode', [
ModelMode.binary_classification,
ModelMode.multiclass_classification,
ModelMode.regression,
])
def test_pg_explainer_graph(device, check_explanation, mode):
x = torch.randn(8, 3, device=device)
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
], device=device)
target = torch.randint(7, (1, ), device=device)

model_config = ModelConfig(
mode='multiclass_classification',
task_level='graph',
return_type='raw',
)
if mode == ModelMode.binary_classification:
target = torch.randint(2, (1, ), device=device)
elif mode == ModelMode.multiclass_classification:
target = torch.randint(7, (1, ), device=device)
elif mode == ModelMode.regression:
target = torch.randn((1, 1), device=device)

model_config = ModelConfig(mode=mode, task_level='graph',
return_type='raw')

model = GCN(model_config).to(device)

explainer = Explainer(
model=model,
algorithm=PGExplainer(epochs=10).to(device),
algorithm=PGExplainer(epochs=2).to(device),
explanation_type='phenomenon',
edge_mask_type='object',
model_config=model_config,
)

with pytest.raises(ValueError, match="not yet fully trained"):
explanation = explainer(x, edge_index, target=target)
explainer(x, edge_index, target=target)

explainer.algorithm.reset_parameters()
for epoch in range(10):
for epoch in range(2):
loss = explainer.algorithm.train(epoch, model, x, edge_index,
target=target)
assert loss >= 0.0

explanation = explainer(x, edge_index, target=target)

check_explanation(explanation, None, explainer.edge_mask_type)


def test_pg_explainer_supports():
# Test unsupported model task level:
with pytest.raises(ValueError, match="not support the given explanation"):
model_config = ModelConfig(
mode='binary_classification',
task_level='edge',
return_type='raw',
)
Explainer(
model=GCN(model_config),
algorithm=PGExplainer(epochs=2),
explanation_type='phenomenon',
edge_mask_type='object',
model_config=model_config,
)

# Test unsupported explanation type:
with pytest.raises(ValueError, match="not support the given explanation"):
model_config = ModelConfig(
mode='binary_classification',
task_level='node',
return_type='raw',
)
Explainer(
model=GCN(model_config),
algorithm=PGExplainer(epochs=2),
explanation_type='model',
edge_mask_type='object',
model_config=model_config,
)

# Test unsupported node mask:
with pytest.raises(ValueError, match="not support the given explanation"):
model_config = ModelConfig(
mode='binary_classification',
task_level='node',
return_type='raw',
)
Explainer(
model=GCN(model_config),
algorithm=PGExplainer(epochs=2),
explanation_type='model',
node_mask_type='object',
edge_mask_type='object',
model_config=model_config,
)

0 comments on commit 7ed914b

Please sign in to comment.