
Commit

Do not fill InMemoryDataset cache on dataset.num_features (#5264)
* do not fill cache

* changelog
rusty1s authored Aug 23, 2022
1 parent 65715da commit 7b6e199
Showing 2 changed files with 9 additions and 2 deletions.
CHANGELOG.md (5 changes: 3 additions & 2 deletions)
@@ -13,8 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
 - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
 ### Changed
-- Changed tests relying on `dblp` datasets to instead use synthetic data. ([#5250](https://github.com/pyg-team/pytorch_geometric/pull/5250))
-- Fixed a bug for the initialization of activation function examples in `custom_graphgym`. ([#5243](https://github.com/pyg-team/pytorch_geometric/pull/5243))
+- Do not fill `InMemoryDataset` cache on `dataset.num_features` ([#5264](https://github.com/pyg-team/pytorch_geometric/pull/5264))
+- Changed tests relying on `dblp` datasets to instead use synthetic data ([#5250](https://github.com/pyg-team/pytorch_geometric/pull/5250))
+- Fixed a bug for the initialization of activation function examples in `custom_graphgym` ([#5243](https://github.com/pyg-team/pytorch_geometric/pull/5243))
 ### Removed
 
 ## [2.1.0] - 2022-08-17
torch_geometric/data/dataset.py (6 changes: 6 additions & 0 deletions)
@@ -101,6 +101,9 @@ def processed_dir(self) -> str:
     def num_node_features(self) -> int:
         r"""Returns the number of features per node in the dataset."""
         data = self[0]
+        # Do not fill cache for `InMemoryDataset`:
+        if hasattr(self, '_data_list') and self._data_list is not None:
+            self._data_list[0] = None
         data = data[0] if isinstance(data, tuple) else data
         if hasattr(data, 'num_node_features'):
             return data.num_node_features
@@ -117,6 +120,9 @@ def num_features(self) -> int:
     def num_edge_features(self) -> int:
         r"""Returns the number of features per edge in the dataset."""
         data = self[0]
+        # Do not fill cache for `InMemoryDataset`:
+        if hasattr(self, '_data_list') and self._data_list is not None:
+            self._data_list[0] = None
         data = data[0] if isinstance(data, tuple) else data
         if hasattr(data, 'num_edge_features'):
             return data.num_edge_features
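For context, a minimal sketch (not part of the commit) of the behavior this patch changes: on an `InMemoryDataset`, `dataset.num_node_features` goes through `self[0]`, which separates the first graph and would previously store a copy of it in the private `_data_list` cache as a side effect; the patch resets that slot so a pure metadata query no longer keeps an extra `Data` object alive. The `TinyDataset` class below is hypothetical and only follows the common pattern of collating a list of `Data` objects in a subclass.

import torch
from torch_geometric.data import Data, InMemoryDataset

class TinyDataset(InMemoryDataset):
    # Hypothetical in-memory dataset: three graphs with 16-dimensional node features.
    def __init__(self):
        super().__init__()
        data_list = [Data(x=torch.randn(4, 16)) for _ in range(3)]
        self.data, self.slices = self.collate(data_list)

dataset = TinyDataset()

# The metadata query works as before ...
assert dataset.num_node_features == 16

# ... but with this patch it no longer leaves a separated copy of the first
# graph behind in the internal cache.
assert dataset._data_list[0] is None

# Explicit indexing still fills the cache slot, as it did before the change.
dataset[0]
assert dataset._data_list[0] is not None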
