Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added force_reload option to Dataset and InMemoryDataset [1/n] #8352

Merged
merged 4 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352))
- Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345))
- Added support for `torch.compile` in `HeteroConv` ([#8344](https://github.com/pyg-team/pytorch_geometric/pull/8344))
- Added support for weighted `sparse_cross_entropy` ([#8340](https://github.com/pyg-team/pytorch_geometric/pull/8340))
Expand Down
18 changes: 11 additions & 7 deletions torch_geometric/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class Dataset(torch.utils.data.Dataset, ABC):
included in the final dataset. (default: :obj:`None`)
log (bool, optional): Whether to print any console output while
downloading and processing the dataset. (default: :obj:`True`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""
@property
def raw_file_names(self) -> Union[str, List[str], Tuple]:
Expand Down Expand Up @@ -84,6 +86,7 @@ def __init__(
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
log: bool = True,
force_reload: bool = False,
):
super().__init__()

Expand All @@ -96,6 +99,7 @@ def __init__(
self.pre_filter = pre_filter
self.log = log
self._indices: Optional[Sequence] = None
self.force_reload = force_reload

if self.has_download:
self._download()
Expand Down Expand Up @@ -217,20 +221,20 @@ def _process(self):
f = osp.join(self.processed_dir, 'pre_transform.pt')
if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
warnings.warn(
f"The `pre_transform` argument differs from the one used in "
f"the pre-processed version of this dataset. If you want to "
f"make use of another pre-processing technique, make sure to "
f"delete '{self.processed_dir}' first")
"The `pre_transform` argument differs from the one used in "
"the pre-processed version of this dataset. If you want to "
"make use of another pre-processing technique, pass "
"`force_reload=True` explicitly to reload the dataset.")

f = osp.join(self.processed_dir, 'pre_filter.pt')
if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
warnings.warn(
"The `pre_filter` argument differs from the one used in "
"the pre-processed version of this dataset. If you want to "
"make use of another pre-fitering technique, make sure to "
"delete '{self.processed_dir}' first")
"make use of another pre-filtering technique, pass "
"`force_reload=True` explicitly to reload the dataset.")

if files_exist(self.processed_paths): # pragma: no cover
if not self.force_reload and files_exist(self.processed_paths):
return

if self.log and 'pytest' not in sys.modules:
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class InMemoryDataset(Dataset, ABC):
included in the final dataset. (default: :obj:`None`)
log (bool, optional): Whether to print any console output while
downloading and processing the dataset. (default: :obj:`True`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""
@property
def raw_file_names(self) -> Union[str, List[str], Tuple]:
Expand All @@ -72,8 +74,10 @@ def __init__(
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
log: bool = True,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform, pre_filter, log)
super().__init__(root, transform, pre_transform, pre_filter, log,
force_reload)
self._data = None
self.slices = None
self._data_list: Optional[List[BaseData]] = None
Expand Down
14 changes: 11 additions & 3 deletions torch_geometric/datasets/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class Actor(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)

**STATS:**

Expand All @@ -47,9 +49,15 @@ class Actor(InMemoryDataset):

url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
super().__init__(root, transform, pre_transform)
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
Loading