diff --git a/CHANGELOG.md b/CHANGELOG.md index c2d1fcd28c2e..4bd7fd80d2e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.3.0] - 2023-MM-DD ### Added +- Added the `BA2MotifDataset` explainer dataset ([#6257](https://github.com/pyg-team/pytorch_geometric/pull/6257)) - Added `CycleMotif` motif generator to generate `n`-node cycle shaped motifs ([#6256](https://github.com/pyg-team/pytorch_geometric/pull/6256)) - Added the `InfectionDataset` to evaluate explanations ([#6222](https://github.com/pyg-team/pytorch_geometric/pull/6222)) - Added `characterization_score` and `fidelity_curve_auc` explainer metrics ([#6188](https://github.com/pyg-team/pytorch_geometric/pull/6188)) diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index f54d00370e1a..e5120398fe58 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -80,6 +80,7 @@ from .hydro_net import HydroNet from .explainer_dataset import ExplainerDataset from .infection_dataset import InfectionDataset +from .ba2motif_dataset import BA2MotifDataset import torch_geometric.datasets.utils # noqa @@ -169,6 +170,7 @@ 'HydroNet', 'ExplainerDataset', 'InfectionDataset', + 'BA2MotifDataset', ] classes = __all__ diff --git a/torch_geometric/datasets/ba2motif_dataset.py b/torch_geometric/datasets/ba2motif_dataset.py new file mode 100644 index 000000000000..e801bb8f3e9a --- /dev/null +++ b/torch_geometric/datasets/ba2motif_dataset.py @@ -0,0 +1,120 @@ +import pickle +from typing import Callable, List, Optional + +import torch + +from torch_geometric.data import Data, InMemoryDataset, download_url + + +class BA2MotifDataset(InMemoryDataset): + r"""The synthetic BA-2motifs graph classification dataset for evaluating + explainabilty algorithms, as described in the `"Parameterized Explainer + for Graph Neural Network" `_ paper. + :class:`~torch_geometric.datasets.BA2MotifDataset` contains 1000 random + Barabasi-Albert (BA) graphs. + Half of the graphs are attached with a + :class:`~torch_geometric.datasets.motif_generator.HouseMotif`, and the rest + are attached with a five-node + :class:`~torch_geometric.datasets.motif_generator.CycleMotif`. + The graphs are assigned to one of the two classes according to the type of + attached motifs. + + This dataset is pre-computed from the official implementation. If you want + to create own variations of it, you can make use of the + :class:`~torch_geometric.datasets.ExplainerDataset`: + + .. code-block:: python + + import torch + from torch_geometric.datasets import ExplainerDataset + from torch_geometric.datasets.graph_generator import BAGraph + from torch_geometric.datasets.motif_generator import HouseMotif + from torch_geometric.datasets.motif_generator import CycleMotif + + dataset1 = ExplainerDataset( + graph_generator=BAGraph(num_nodes=25, num_edges=1), + motif_generator=HouseMotif(), + num_motifs=1, + num_graphs=500, + ) + + dataset2 = ExplainerDataset( + graph_generator=BAGraph(num_nodes=25, num_edges=1), + motif_generator=CycleMotif(5), + num_motifs=1, + num_graphs=500, + ) + + dataset = torch.utils.data.ConcatDataset([dataset1, dataset2]) + + Args: + root (string): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + + Stats: + .. list-table:: + :widths: 10 10 10 10 10 + :header-rows: 1 + + * - #graphs + - #nodes + - #edges + - #features + - #classes + * - 1000 + - 25 + - ~51.0 + - 10 + - 2 + """ + url = 'https://github.com/flyingdoog/PGExplainer/raw/master/dataset' + filename = 'BA-2motif.pkl' + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + ): + super().__init__(root, transform, pre_transform) + self.data, self.slices = torch.load(self.processed_paths[0]) + + def raw_file_names(self) -> str: + return self.filename + + @property + def processed_file_names(self) -> str: + return 'data.pt' + + def download(self): + download_url(f'{self.url}/{self.filename}', self.raw_dir) + + def process(self): + with open(self.raw_paths[0], 'rb') as f: + adj, x, y = pickle.load(f) + + adjs = torch.from_numpy(adj) + xs = torch.from_numpy(x).to(torch.float) + ys = torch.from_numpy(y) + + data_list: List[Data] = [] + for i in range(xs.size(0)): + edge_index = adjs[i].nonzero().t() + x = xs[i] + y = int(ys[i].nonzero()) + + data = Data(x=x, edge_index=edge_index, y=y) + + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + torch.save(self.collate(data_list), self.processed_paths[0])