diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e352af49689..b270e7ad05ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884)) - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815)) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index cafe28d4c77a..2c9f68ab018b 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -191,15 +191,30 @@ class LightningNodeData(LightningDataModule): data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of training nodes. If not given, will try to automatically - infer them from the :obj:`data` object. (default: :obj:`None`) + indices of training nodes. + If not given, will try to automatically infer them from the + :obj:`data` object by searching for :obj:`train_mask`, + :obj:`train_idx`, or :obj:`train_index` attributes. + (default: :obj:`None`) input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of validation nodes. If not given, will try to - automatically infer them from the :obj:`data` object. + indices of validation nodes. + If not given, will try to automatically infer them from the + :obj:`data` object by searching for :obj:`val_mask`, + :obj:`valid_mask`, :obj:`val_idx`, :obj:`valid_idx`, + :obj:`val_index`, or :obj:`valid_index` attributes. (default: :obj:`None`) input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of test nodes. If not given, will try to automatically - infer them from the :obj:`data` object. (default: :obj:`None`) + indices of test nodes. + If not given, will try to automatically infer them from the + :obj:`data` object by searching for :obj:`test_mask`, + :obj:`test_idx`, or :obj:`test_index` attributes. + (default: :obj:`None`) + input_pred_nodes (torch.Tensor or str or (str, torch.Tensor)): The + indices of prediction nodes. + If not given, will try to automatically infer them from the + :obj:`data` object by searching for :obj:`pred_mask`, + :obj:`pred_idx`, or :obj:`pred_index` attributes. + (default: :obj:`None`) loader (str): The scalability technique to use (:obj:`"full"`, :obj:`"neighbor"`). (default: :obj:`"neighbor"`) batch_size (int, optional): How many samples per batch to load. @@ -216,6 +231,7 @@ def __init__( input_train_nodes: InputNodes = None, input_val_nodes: InputNodes = None, input_test_nodes: InputNodes = None, + input_pred_nodes: InputNodes = None, loader: str = "neighbor", batch_size: int = 1, num_workers: int = 0, @@ -236,6 +252,9 @@ def __init__( if input_test_nodes is None: input_test_nodes = infer_input_nodes(data, split='test') + if input_pred_nodes is None: + input_pred_nodes = infer_input_nodes(data, split='pred') + if loader == 'full' and batch_size != 1: warnings.warn(f"Re-setting 'batch_size' to 1 in " f"'{self.__class__.__name__}' for loader='full' " @@ -279,6 +298,7 @@ def __init__( self.input_train_nodes = input_train_nodes self.input_val_nodes = input_val_nodes self.input_test_nodes = input_test_nodes + self.input_pred_nodes = input_pred_nodes def prepare_data(self): """""" @@ -323,6 +343,10 @@ def test_dataloader(self) -> DataLoader: """""" return self.dataloader(self.input_test_nodes, shuffle=False) + def predict_dataloader(self) -> DataLoader: + """""" + return self.dataloader(self.input_pred_nodes, shuffle=False) + def __repr__(self) -> str: kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs) return f'{self.__class__.__name__}({kwargs})'