Skip to content

Commit

Permalink
Add support for predict_dataloader in LightningNodeData (#4884)
Browse files Browse the repository at this point in the history
* Update lightning_datamodule.py

* Update lightning_datamodule.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Jun 29, 2022
1 parent 9fc80f3 commit 65ab1e0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
Expand Down
36 changes: 30 additions & 6 deletions torch_geometric/data/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,30 @@ class LightningNodeData(LightningDataModule):
data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` graph object.
input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The
indices of training nodes. If not given, will try to automatically
infer them from the :obj:`data` object. (default: :obj:`None`)
indices of training nodes.
If not given, will try to automatically infer them from the
:obj:`data` object by searching for :obj:`train_mask`,
:obj:`train_idx`, or :obj:`train_index` attributes.
(default: :obj:`None`)
input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The
indices of validation nodes. If not given, will try to
automatically infer them from the :obj:`data` object.
indices of validation nodes.
If not given, will try to automatically infer them from the
:obj:`data` object by searching for :obj:`val_mask`,
:obj:`valid_mask`, :obj:`val_idx`, :obj:`valid_idx`,
:obj:`val_index`, or :obj:`valid_index` attributes.
(default: :obj:`None`)
input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The
indices of test nodes. If not given, will try to automatically
infer them from the :obj:`data` object. (default: :obj:`None`)
indices of test nodes.
If not given, will try to automatically infer them from the
:obj:`data` object by searching for :obj:`test_mask`,
:obj:`test_idx`, or :obj:`test_index` attributes.
(default: :obj:`None`)
input_pred_nodes (torch.Tensor or str or (str, torch.Tensor)): The
indices of prediction nodes.
If not given, will try to automatically infer them from the
:obj:`data` object by searching for :obj:`pred_mask`,
:obj:`pred_idx`, or :obj:`pred_index` attributes.
(default: :obj:`None`)
loader (str): The scalability technique to use (:obj:`"full"`,
:obj:`"neighbor"`). (default: :obj:`"neighbor"`)
batch_size (int, optional): How many samples per batch to load.
Expand All @@ -216,6 +231,7 @@ def __init__(
input_train_nodes: InputNodes = None,
input_val_nodes: InputNodes = None,
input_test_nodes: InputNodes = None,
input_pred_nodes: InputNodes = None,
loader: str = "neighbor",
batch_size: int = 1,
num_workers: int = 0,
Expand All @@ -236,6 +252,9 @@ def __init__(
if input_test_nodes is None:
input_test_nodes = infer_input_nodes(data, split='test')

if input_pred_nodes is None:
input_pred_nodes = infer_input_nodes(data, split='pred')

if loader == 'full' and batch_size != 1:
warnings.warn(f"Re-setting 'batch_size' to 1 in "
f"'{self.__class__.__name__}' for loader='full' "
Expand Down Expand Up @@ -279,6 +298,7 @@ def __init__(
self.input_train_nodes = input_train_nodes
self.input_val_nodes = input_val_nodes
self.input_test_nodes = input_test_nodes
self.input_pred_nodes = input_pred_nodes

def prepare_data(self):
""""""
Expand Down Expand Up @@ -323,6 +343,10 @@ def test_dataloader(self) -> DataLoader:
""""""
return self.dataloader(self.input_test_nodes, shuffle=False)

def predict_dataloader(self) -> DataLoader:
""""""
return self.dataloader(self.input_pred_nodes, shuffle=False)

def __repr__(self) -> str:
kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
return f'{self.__class__.__name__}({kwargs})'
Expand Down

0 comments on commit 65ab1e0

Please sign in to comment.