Skip to content

Commit

Permalink
Connect configurations to ExplainerAlgorithm (#6089)
Browse files — browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 29, 2022
1 parent d3a493f commit 8adfd02
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 147 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added missing test labels in `HGBDataset` ([#5233](https://github.com/pyg-team/pytorch_geometric/pull/5233))
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804), [#6054](https://github.com/pyg-team/pytorch_geometric/pull/6054))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804), [#6054](https://github.com/pyg-team/pytorch_geometric/pull/6054), [#6089](https://github.com/pyg-team/pytorch_geometric/pull/6089))
### Changed
- Moved and adapted `GNNExplainer` from `torch_geometric.nn` to `torch_geometric.explain.algorithm` ([#5967](https://github.com/pyg-team/pytorch_geometric/pull/5967), [#6065](https://github.com/pyg-team/pytorch_geometric/pull/6065))
- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052))
Expand Down
69 changes: 47 additions & 22 deletions torch_geometric/explain/algorithm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ def forward(
x: Tensor,
edge_index: Tensor,
*,
explainer_config: ExplainerConfig,
model_config: ModelConfig,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
target_index: Optional[int] = None,
Expand All @@ -36,8 +34,6 @@ def forward(
model (torch.nn.Module): The model to explain.
x (torch.Tensor): The input node features.
edge_index (torch.Tensor): The input edge indices.
explainer_config (ExplainerConfig): The explainer configuration.
model_config (ModelConfig): The model configuration.
target (torch.Tensor): The target of the model.
index (Union[int, Tensor], optional): The index of the model
output to explain. Can be a single index or a tensor of
Expand All @@ -52,23 +48,54 @@ def forward(
"""

@abstractmethod
def supports(
def supports(self) -> bool:
r"""Checks if the explainer supports the user-defined settings provided
in :obj:`self.explainer_config` and :obj:`self.model_config`."""
pass

###########################################################################

@property
def explainer_config(self) -> ExplainerConfig:
r"""Returns the connected explainer configuration."""
if not hasattr(self, '_explainer_config'):
raise ValueError(
f"The explanation algorithm '{self.__class__.__name__}' is "
f"not yet connected to any explainer configuration. Please "
f"call `{self.__class__.__name__}.connect(...)` before "
f"proceeding.")
return self._explainer_config

@property
def model_config(self) -> ModelConfig:
r"""Returns the connected model configuration."""
if not hasattr(self, '_model_config'):
raise ValueError(
f"The explanation algorithm '{self.__class__.__name__}' is "
f"not yet connected to any model configuration. Please call "
f"`{self.__class__.__name__}.connect(...)` before "
f"proceeding.")
return self._model_config

def connect(
self,
explainer_config: ExplainerConfig,
model_config: ModelConfig,
) -> bool:
r"""Checks if the explainer supports the user-defined settings.
):
r"""Connects an explainer and model configuration to the explainer
algorithm."""
self._explainer_config = ExplainerConfig.cast(explainer_config)
self._model_config = ModelConfig.cast(model_config)

Args:
explainer_config (ExplainerConfig): The explainer configuration.
model_config (ModelConfig): the model configuration.
"""
pass
if not self.supports():
raise ValueError(
f"The explanation algorithm '{self.__class__.__name__}' does "
f"not support the given explanation settings.")

# Helper functions ########################################################

@staticmethod
def _post_process_mask(
self,
mask: Optional[Tensor],
num_elems: int,
hard_mask: Optional[Tensor] = None,
Expand All @@ -92,8 +119,8 @@ def _post_process_mask(

return mask

@staticmethod
def _get_hard_masks(
self,
model: torch.nn.Module,
index: Optional[Union[int, Tensor]],
edge_index: Tensor,
Expand All @@ -106,10 +133,10 @@ def _get_hard_masks(

index, _, _, edge_mask = k_hop_subgraph(
index,
num_hops=self._num_hops(model),
num_hops=ExplainerAlgorithm._num_hops(model),
edge_index=edge_index,
num_nodes=num_nodes,
flow=self._flow(model),
flow=ExplainerAlgorithm._flow(model),
)

node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
Expand All @@ -136,18 +163,16 @@ def _flow(model: torch.nn.Module) -> str:
return module.flow
return 'source_to_target'

@staticmethod
def _to_log_prob(y: Tensor, return_type: ModelReturnType) -> Tensor:
def _to_log_prob(self, y: Tensor) -> Tensor:
r"""Converts the model output to log-probabilities.
Args:
y (Tensor): The output of the model.
return_type (ModelReturnType): The model return type.
"""
if return_type == ModelReturnType.probs:
if self.model_config.return_type == ModelReturnType.probs:
return y.log()
if return_type == ModelReturnType.raw:
if self.model_config.return_type == ModelReturnType.raw:
return y.log_softmax(dim=-1)
if return_type == ModelReturnType.log_probs:
if self.model_config.return_type == ModelReturnType.log_probs:
return y
raise NotImplementedError
28 changes: 15 additions & 13 deletions torch_geometric/explain/algorithm/dummy_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,26 @@

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import ExplainerConfig, ModelConfig


class DummyExplainer(ExplainerAlgorithm):
r"""A dummy explainer for testing purposes."""
def forward(self, x: Tensor, edge_index: Tensor,
edge_attr: Optional[Tensor] = None, **kwargs) -> Explanation:
def forward(
self,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
**kwargs,
) -> Explanation:
r"""Returns random explanations based on the shape of the inputs.
Args:
x (torch.Tensor): node features.
edge_index (torch.Tensor): edge indices.
edge_attr (torch.Tensor, optional): edge attributes.
Defaults to None.
model (torch.nn.Module): The model to explain.
x (torch.Tensor): The node features.
edge_index (torch.Tensor): The edge indices.
edge_attr (torch.Tensor, optional): The edge attributes.
(default: :obj:`None`)
Returns:
Explanation: A random explanation based on the shape of the inputs.
Expand All @@ -31,7 +37,7 @@ def forward(self, x: Tensor, edge_index: Tensor,
mask_dict['edge_mask'] = torch.rand(num_edges, device=x.device)

if edge_attr is not None:
mask_dict["edge_features_mask"] = torch.rand_like(edge_attr)
mask_dict['edge_feat_mask'] = torch.rand_like(edge_attr)

return Explanation(
edge_index=edge_index,
Expand All @@ -40,9 +46,5 @@ def forward(self, x: Tensor, edge_index: Tensor,
**mask_dict,
)

def supports(
self,
explainer_config: ExplainerConfig,
model_config: ModelConfig,
) -> bool:
def supports(self) -> bool:
return True
Loading

0 comments on commit 8adfd02

Please sign in to comment.