Skip to content

Commit

Permalink
Fix loading of InMemoryDataset in case a pre_transform modifies t…
Browse files Browse the repository at this point in the history
…he underlying `Data` class (#8692)

Fixes #8688
  • Loading branch information
rusty1s authored Dec 30, 2023
1 parent 5a5950c commit d1f305b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug in which `InMemoryDataset` did not reconstruct the correct data class when a `pre_transform` has modified it ([#8692](https://github.com/pyg-team/pytorch_geometric/pull/8692))
- Fixed a bug in which transforms were not applied for `OnDiskDataset` ([#8663](https://github.com/pyg-team/pytorch_geometric/pull/8663))
- Fixed mini-batch computation in `DMoNPooing` loss function ([#8285](https://github.com/pyg-team/pytorch_geometric/pull/8285))
- Fixed `NaN` handling in `SQLDatabase` ([#8479](https://github.com/pyg-team/pytorch_geometric/pull/8479))
Expand Down
18 changes: 13 additions & 5 deletions torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,22 @@ def get(self, idx: int) -> BaseData:
def save(cls, data_list: Sequence[BaseData], path: str) -> None:
r"""Saves a list of data objects to the file path :obj:`path`."""
data, slices = cls.collate(data_list)
fs.torch_save((data.to_dict(), slices), path)
fs.torch_save((data.to_dict(), slices, data.__class__), path)

def load(self, path: str, data_cls: Type[BaseData] = Data) -> None:
r"""Loads the dataset from the file path :obj:`path`."""
data, self.slices = fs.torch_load(path)
if isinstance(data, dict): # Backward compatibility.
data = data_cls.from_dict(data)
self.data = data
out = fs.torch_load(path)
assert isinstance(out, tuple)
assert len(out) == 2 or len(out) == 3
if len(out) == 2: # Backward compatibility.
data, self.slices = out
else:
data, self.slices, data_cls = out

if not isinstance(data, dict): # Backward compatibility.
self.data = data
else:
self.data = data_cls.from_dict(data)

@staticmethod
def collate(
Expand Down
22 changes: 17 additions & 5 deletions torch_geometric/datasets/tu_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,24 @@ def __init__(
force_reload=force_reload)

out = fs.torch_load(self.processed_paths[0])
if not isinstance(out, tuple) or len(out) != 3:
if not isinstance(out, tuple) or len(out) < 3:
raise RuntimeError(
"The 'data' object was created by an older version of PyG. "
"If this error occurred while loading an already existing "
"dataset, remove the 'processed/' directory in the dataset's "
"root folder and try again.")
data, self.slices, self.sizes = out
self.data = Data.from_dict(data) if isinstance(data, dict) else data
assert len(out) == 3 or len(out) == 4

if len(out) == 3: # Backward compatibility.
data, self.slices, self.sizes = out
data_cls = Data
else:
data, self.slices, self.sizes, data_cls = out

if not isinstance(data, dict): # Backward compatibility.
self.data = data
else:
self.data = data_cls.from_dict(data)

assert isinstance(self._data, Data)
if self._data.x is not None and not use_node_attr:
Expand Down Expand Up @@ -205,8 +215,10 @@ def process(self) -> None:
self._data_list = None # Reset cache.

assert isinstance(self._data, Data)
fs.torch_save((self._data.to_dict(), self.slices, sizes),
self.processed_paths[0])
fs.torch_save(
(self._data.to_dict(), self.slices, sizes, self._data.__class__),
self.processed_paths[0],
)

def __repr__(self) -> str:
return f'{self.name}({len(self)})'

0 comments on commit d1f305b

Please sign in to comment.