From 01f1a6e6907a73e280ec2fa05d71907ac427f06c Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 23 May 2023 07:10:58 +0000 Subject: [PATCH 1/2] update --- torch_geometric/datasets/planetoid.py | 4 ++-- torch_geometric/datasets/tu_dataset.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/torch_geometric/datasets/planetoid.py b/torch_geometric/datasets/planetoid.py index 9c16a378815b..2ace35493cb6 100644 --- a/torch_geometric/datasets/planetoid.py +++ b/torch_geometric/datasets/planetoid.py @@ -90,7 +90,7 @@ def __init__(self, root: str, name: str, split: str = "public", assert self.split in ['public', 'full', 'geom-gcn', 'random'] super().__init__(root, transform, pre_transform) - self.data, self.slices = torch.load(self.processed_paths[0]) + self.load(self.processed_paths[0]) if split == 'full': data = self.get(0) @@ -162,7 +162,7 @@ def process(self): data.test_mask = torch.stack(test_masks, dim=1) data = data if self.pre_transform is None else self.pre_transform(data) - torch.save(self.collate([data]), self.processed_paths[0]) + self.save([data], self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name}()' diff --git a/torch_geometric/datasets/tu_dataset.py b/torch_geometric/datasets/tu_dataset.py index 83524f4db55a..aa6835bcb9a3 100644 --- a/torch_geometric/datasets/tu_dataset.py +++ b/torch_geometric/datasets/tu_dataset.py @@ -5,7 +5,12 @@ import torch -from torch_geometric.data import InMemoryDataset, download_url, extract_zip +from torch_geometric.data import ( + Data, + InMemoryDataset, + download_url, + extract_zip, +) from torch_geometric.io import read_tu_data @@ -131,7 +136,8 @@ def __init__(self, root: str, name: str, "If this error occurred while loading an already existing " "dataset, remove the 'processed/' directory in the dataset's " "root folder and try again.") - self.data, self.slices, self.sizes = out + data, self.slices, self.sizes = out + self.data = Data.from_dict(data) if isinstance(data, dict) else data if self._data.x is not None and not use_node_attr: num_node_attributes = self.num_node_attributes @@ -199,7 +205,8 @@ def process(self): self.data, self.slices = self.collate(data_list) self._data_list = None # Reset cache. - torch.save((self._data, self.slices, sizes), self.processed_paths[0]) + torch.save((self._data.to_dict(), self.slices, sizes), + self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name}({len(self)})' From 434b978df4fc08c585fc0f2c8a693b7f6252ab6a Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 23 May 2023 07:11:54 +0000 Subject: [PATCH 2/2] update --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a05610aadbd..46163ff28439 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `NodePropertySplit` transform for creating node-level splits using structural node properties ([#6894](https://github.com/pyg-team/pytorch_geometric/pull/6894)) - Added an option to preserve directed graphs in `CitationFull` datasets ([#7275](https://github.com/pyg-team/pytorch_geometric/pull/7275)) - Added support for `torch.sparse.Tensor` in `DataLoader` ([#7252](https://github.com/pyg-team/pytorch_geometric/pull/7252)) -- Added `save` and `load` methods to `InMemoryDataset` ([#7250](https://github.com/pyg-team/pytorch_geometric/pull/7250)) +- Added `save` and `load` methods to `InMemoryDataset` ([#7250](https://github.com/pyg-team/pytorch_geometric/pull/7250), [#7413](https://github.com/pyg-team/pytorch_geometric/pull/7413)) - Added an example for heterogeneous GNN explanation via `CaptumExplainer` ([#7096](https://github.com/pyg-team/pytorch_geometric/pull/7096)) - Added `visualize_feature_importance` functionality to `HeteroExplanation` ([#7096](https://github.com/pyg-team/pytorch_geometric/pull/7096)) - Added a `AddRemainingSelfLoops` transform ([#7192](https://github.com/pyg-team/pytorch_geometric/pull/7192))