From 5cad77bfcec35fd2a85d758891b61d84d14c0512 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Wed, 22 Jun 2022 18:47:28 +0000 Subject: [PATCH 01/16] First cut changes. --- torch_geometric/data/lightning_datamodule.py | 28 ++++++++++++++++--- .../loader/link_neighbor_loader.py | 1 - 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index cafe28d4c77a..001b53b96fff 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -4,6 +4,7 @@ import torch from torch_geometric.data import Data, Dataset, HeteroData +from torch_geometric.loader import LinkNeighborLoader from torch_geometric.loader.dataloader import DataLoader from torch_geometric.loader.neighbor_loader import ( NeighborLoader, @@ -222,19 +223,31 @@ def __init__( **kwargs, ): - assert loader in ['full', 'neighbor'] + assert loader in ['full', 'neighbor', 'link_neighbor'] if input_train_nodes is None: - input_train_nodes = infer_input_nodes(data, split='train') + if loader == 'link_neighbor': + input_train_nodes = kwargs['edge_label_index_train'] + del kwargs['edge_label_index_train'] + else: + input_train_nodes = infer_input_nodes(data, split='train') if input_val_nodes is None: - input_val_nodes = infer_input_nodes(data, split='val') + if loader == 'link_neighbor': + input_val_nodes = kwargs['edge_label_index_val'] + del kwargs['edge_label_index_val'] + else: + input_val_nodes = infer_input_nodes(data, split='val') if input_val_nodes is None: input_val_nodes = infer_input_nodes(data, split='valid') if input_test_nodes is None: - input_test_nodes = infer_input_nodes(data, split='test') + if loader == 'link_neighbor': + input_test_nodes = kwargs['edge_label_index_test'] + del kwargs['edge_label_index_test'] + else: + input_test_nodes = infer_input_nodes(data, split='test') if loader == 'full' and batch_size != 1: warnings.warn(f"Re-setting 'batch_size' to 
1 in " @@ -309,6 +322,11 @@ def dataloader(self, input_nodes: InputNodes, shuffle: bool) -> DataLoader: neighbor_sampler=self.neighbor_sampler, shuffle=shuffle, **self.kwargs) + if self.loader == 'link_neighbor': + return LinkNeighborLoader(data=self.data, + edge_label_index=input_nodes, + shuffle=shuffle, **self.kwargs) + raise NotImplementedError def train_dataloader(self) -> DataLoader: @@ -341,6 +359,7 @@ def infer_input_nodes(data: Union[Data, HeteroData], split: str) -> InputNodes: attr_name = f'{split}_index' if attr_name is None: + print(f'Returning none for split {split}') return None if isinstance(data, Data): @@ -351,6 +370,7 @@ def infer_input_nodes(data: Union[Data, HeteroData], split: str) -> InputNodes: raise ValueError(f"Could not automatically determine the input " f"nodes of {data} since there exists multiple " f"types with attribute '{attr_name}'") + print(f'located data for split {split}') return list(input_nodes_dict.items())[0] return None diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index e8162ea66fd9..f4673a0aee5f 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -348,7 +348,6 @@ def get_edge_label_index( edge_type, edge_label_index = edge_label_index edge_type = data._to_canonical(*edge_type) - assert edge_type in data.edge_types if edge_label_index is None: return edge_type, data[edge_type].edge_index From d6bfba42ce2f6f2b63849a3c849baffe5a8b7e78 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Thu, 23 Jun 2022 17:55:58 +0000 Subject: [PATCH 02/16] Include batch size in the data object returned from the sampler. 
--- torch_geometric/loader/link_neighbor_loader.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index f4673a0aee5f..908da6ce0183 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -80,7 +80,8 @@ def __call__(self, query: List[Tuple[Tensor]]): self.directed, ) - return node, row, col, edge, edge_label_index, edge_label + return (node, row, col, edge, edge_label_index, edge_label, + edge_label_index.shape[1]) elif issubclass(self.data_cls, HeteroData): sample_fn = torch.ops.torch_sparse.hetero_neighbor_sample @@ -114,7 +115,7 @@ def __call__(self, query: List[Tuple[Tensor]]): ) return (node_dict, row_dict, col_dict, edge_dict, edge_label_index, - edge_label) + edge_label, edge_label_index.shape[1]) class LinkNeighborLoader(torch.utils.data.DataLoader): @@ -279,21 +280,24 @@ def __init__( def transform_fn(self, out: Any) -> Union[Data, HeteroData]: if isinstance(self.data, Data): - node, row, col, edge, edge_label_index, edge_label = out + (node, row, col, edge, edge_label_index, edge_label, + batch_size) = out data = filter_data(self.data, node, row, col, edge, self.neighbor_sampler.perm) data.edge_label_index = edge_label_index + data.batch_size = batch_size if edge_label is not None: data.edge_label = edge_label elif isinstance(self.data, HeteroData): (node_dict, row_dict, col_dict, edge_dict, edge_label_index, - edge_label) = out + edge_label, batch_size) = out data = filter_hetero_data(self.data, node_dict, row_dict, col_dict, edge_dict, self.neighbor_sampler.perm_dict) edge_type = self.neighbor_sampler.input_type data[edge_type].edge_label_index = edge_label_index + data[edge_type].batch_size = batch_size if edge_label is not None: data[edge_type].edge_label = edge_label From 657444e71cd2d3a7f6cbdd3c49fdf0884f59deb4 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Mon, 
27 Jun 2022 00:12:07 +0000 Subject: [PATCH 03/16] No need to return batch size. --- torch_geometric/data/lightning_datamodule.py | 66 ++++++++----------- .../loader/link_neighbor_loader.py | 14 ++-- 2 files changed, 33 insertions(+), 47 deletions(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 001b53b96fff..153ce8b5103e 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -1,4 +1,5 @@ import warnings +from multiprocessing.sharedctypes import Value from typing import Optional, Union import torch @@ -11,7 +12,7 @@ NeighborSampler, get_input_nodes, ) -from torch_geometric.typing import InputNodes +from torch_geometric.typing import InputEdges, InputNodes try: from pytorch_lightning import LightningDataModule as PLLightningDataModule @@ -169,7 +170,8 @@ class LightningNodeData(LightningDataModule): automatically used as a :obj:`datamodule` for multi-GPU node-level training via `PyTorch Lightning `_. :class:`LightningDataset` will take care of providing mini-batches via - :class:`~torch_geometric.loader.NeighborLoader`. + :class:`~torch_geometric.loader.NeighborLoader` or + :class:`~torch_geometric.loader.LinkNeighborLoader`. .. note:: @@ -192,17 +194,20 @@ class LightningNodeData(LightningDataModule): data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of training nodes. If not given, will try to automatically - infer them from the :obj:`data` object. (default: :obj:`None`) + indices of training nodes. If not given, and loader is not + :obj:`"link_neighbor"`, will try to automatically infer them from + the :obj:`data` object. (default: :obj:`None`) input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of validation nodes. 
If not given, will try to - automatically infer them from the :obj:`data` object. - (default: :obj:`None`) + indices of validation nodes. If not given, and loader is not + :obj:`"link_neighbor"`, will try to automatically infer them from + the :obj:`data` object. (default: :obj:`None`) input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of test nodes. If not given, will try to automatically - infer them from the :obj:`data` object. (default: :obj:`None`) + indices of test nodes. If not given, If not given, and loader is + not :obj:`"link_neighbor"`, will try to automatically infer them + from the :obj:`data` object. (default: :obj:`None`) loader (str): The scalability technique to use (:obj:`"full"`, - :obj:`"neighbor"`). (default: :obj:`"neighbor"`) + :obj:`"neighbor"`, :obj:`"link_neighbor"`). + (default: :obj:`"neighbor"`) batch_size (int, optional): How many samples per batch to load. (default: :obj:`1`) num_workers: How many subprocesses to use for data loading. 
@@ -214,9 +219,9 @@ class LightningNodeData(LightningDataModule): def __init__( self, data: Union[Data, HeteroData], - input_train_nodes: InputNodes = None, - input_val_nodes: InputNodes = None, - input_test_nodes: InputNodes = None, + input_train_nodes: Union[InputEdges, InputNodes] = None, + input_val_nodes: Union[InputEdges, InputNodes] = None, + input_test_nodes: Union[InputEdges, InputNodes] = None, loader: str = "neighbor", batch_size: int = 1, num_workers: int = 0, @@ -225,29 +230,16 @@ def __init__( assert loader in ['full', 'neighbor', 'link_neighbor'] - if input_train_nodes is None: - if loader == 'link_neighbor': - input_train_nodes = kwargs['edge_label_index_train'] - del kwargs['edge_label_index_train'] - else: - input_train_nodes = infer_input_nodes(data, split='train') - - if input_val_nodes is None: - if loader == 'link_neighbor': - input_val_nodes = kwargs['edge_label_index_val'] - del kwargs['edge_label_index_val'] - else: - input_val_nodes = infer_input_nodes(data, split='val') - - if input_val_nodes is None: - input_val_nodes = infer_input_nodes(data, split='valid') - - if input_test_nodes is None: - if loader == 'link_neighbor': - input_test_nodes = kwargs['edge_label_index_test'] - del kwargs['edge_label_index_test'] - else: - input_test_nodes = infer_input_nodes(data, split='test') + if input_train_nodes is None and loader != 'link_neighbor': + input_train_nodes = infer_input_nodes(data, split='train') + + if input_val_nodes is None and loader != 'link_neighbor': + input_val_nodes = kwargs.pop('edge_label_index_val', None) + if input_val_nodes is None: + input_val_nodes = infer_input_nodes(data, split='valid') + + if input_test_nodes is None and loader != 'link_neighbor': + input_test_nodes = infer_input_nodes(data, split='test') if loader == 'full' and batch_size != 1: warnings.warn(f"Re-setting 'batch_size' to 1 in " @@ -359,7 +351,6 @@ def infer_input_nodes(data: Union[Data, HeteroData], split: str) -> InputNodes: attr_name = 
f'{split}_index' if attr_name is None: - print(f'Returning none for split {split}') return None if isinstance(data, Data): @@ -370,7 +361,6 @@ def infer_input_nodes(data: Union[Data, HeteroData], split: str) -> InputNodes: raise ValueError(f"Could not automatically determine the input " f"nodes of {data} since there exists multiple " f"types with attribute '{attr_name}'") - print(f'located data for split {split}') return list(input_nodes_dict.items())[0] return None diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 50aaf2e5a408..834515419296 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -80,8 +80,7 @@ def __call__(self, query: List[Tuple[Tensor]]): self.directed, ) - return (node, row, col, edge, edge_label_index, edge_label, - edge_label_index.shape[1]) + return (node, row, col, edge, edge_label_index, edge_label) elif issubclass(self.data_cls, HeteroData): sample_fn = torch.ops.torch_sparse.hetero_neighbor_sample @@ -115,7 +114,7 @@ def __call__(self, query: List[Tuple[Tensor]]): ) return (node_dict, row_dict, col_dict, edge_dict, edge_label_index, - edge_label, edge_label_index.shape[1]) + edge_label) class LinkNeighborLoader(torch.utils.data.DataLoader): @@ -133,7 +132,7 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): .. code-block:: python from torch_geometric.datasets import Planetoid - from torch_geometric.loader import NeighborLoader + from torch_geometric.loader import LinkNeighborLoader data = Planetoid(path, name='Cora')[0] @@ -281,24 +280,21 @@ def __init__( def transform_fn(self, out: Any) -> Union[Data, HeteroData]: # NOTE This function will always be executed on the main thread! 
if isinstance(self.data, Data): - (node, row, col, edge, edge_label_index, edge_label, - batch_size) = out + (node, row, col, edge, edge_label_index, edge_label) = out data = filter_data(self.data, node, row, col, edge, self.neighbor_sampler.perm) data.edge_label_index = edge_label_index - data.batch_size = batch_size if edge_label is not None: data.edge_label = edge_label elif isinstance(self.data, HeteroData): (node_dict, row_dict, col_dict, edge_dict, edge_label_index, - edge_label, batch_size) = out + edge_label) = out data = filter_hetero_data(self.data, node_dict, row_dict, col_dict, edge_dict, self.neighbor_sampler.perm_dict) edge_type = self.neighbor_sampler.input_type data[edge_type].edge_label_index = edge_label_index - data[edge_type].batch_size = batch_size if edge_label is not None: data[edge_type].edge_label = edge_label From c089c8b5602e1bd98db9a15c99b5124f18db7ee9 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Mon, 27 Jun 2022 04:13:15 +0000 Subject: [PATCH 04/16] Remove unused import. --- torch_geometric/data/lightning_datamodule.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 153ce8b5103e..79fa665f5b0c 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -1,5 +1,4 @@ import warnings -from multiprocessing.sharedctypes import Value from typing import Optional, Union import torch From 5bc64ef2235adc70db9a1fe207644de7f2ca306a Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Mon, 27 Jun 2022 09:54:53 -0400 Subject: [PATCH 05/16] Refactor into a LightningLinkData module. 
--- torch_geometric/data/lightning_datamodule.py | 180 ++++++++++++++++--- 1 file changed, 156 insertions(+), 24 deletions(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 79fa665f5b0c..3a729a5ae657 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -169,8 +169,7 @@ class LightningNodeData(LightningDataModule): automatically used as a :obj:`datamodule` for multi-GPU node-level training via `PyTorch Lightning `_. :class:`LightningDataset` will take care of providing mini-batches via - :class:`~torch_geometric.loader.NeighborLoader` or - :class:`~torch_geometric.loader.LinkNeighborLoader`. + :class:`~torch_geometric.loader.NeighborLoader`. .. note:: @@ -193,20 +192,16 @@ class LightningNodeData(LightningDataModule): data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of training nodes. If not given, and loader is not - :obj:`"link_neighbor"`, will try to automatically infer them from - the :obj:`data` object. (default: :obj:`None`) + indices of training nodes. If not given, will try to automatically + infer them from the :obj:`data` object. (default: :obj:`None`) input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of validation nodes. If not given, and loader is not - :obj:`"link_neighbor"`, will try to automatically infer them from - the :obj:`data` object. (default: :obj:`None`) + indices of validation nodes. If not given will try to automatically + infer them from the :obj:`data` object. (default: :obj:`None`) input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The - indices of test nodes. If not given, If not given, and loader is - not :obj:`"link_neighbor"`, will try to automatically infer them - from the :obj:`data` object. 
(default: :obj:`None`) + indices of test nodes. If not given, will try to automatically + infer them from the :obj:`data` object. (default: :obj:`None`) loader (str): The scalability technique to use (:obj:`"full"`, - :obj:`"neighbor"`, :obj:`"link_neighbor"`). - (default: :obj:`"neighbor"`) + :obj:`"neighbor"`). (default: :obj:`"neighbor"`) batch_size (int, optional): How many samples per batch to load. (default: :obj:`1`) num_workers: How many subprocesses to use for data loading. @@ -218,26 +213,26 @@ class LightningNodeData(LightningDataModule): def __init__( self, data: Union[Data, HeteroData], - input_train_nodes: Union[InputEdges, InputNodes] = None, - input_val_nodes: Union[InputEdges, InputNodes] = None, - input_test_nodes: Union[InputEdges, InputNodes] = None, + input_train_nodes: InputNodes = None, + input_val_nodes: InputNodes = None, + input_test_nodes: InputNodes = None, loader: str = "neighbor", batch_size: int = 1, num_workers: int = 0, **kwargs, ): - assert loader in ['full', 'neighbor', 'link_neighbor'] + assert loader in ['full', 'neighbor'] - if input_train_nodes is None and loader != 'link_neighbor': + if input_train_nodes is None: input_train_nodes = infer_input_nodes(data, split='train') - if input_val_nodes is None and loader != 'link_neighbor': + if input_val_nodes is None: input_val_nodes = kwargs.pop('edge_label_index_val', None) if input_val_nodes is None: input_val_nodes = infer_input_nodes(data, split='valid') - if input_test_nodes is None and loader != 'link_neighbor': + if input_test_nodes is None: input_test_nodes = infer_input_nodes(data, split='test') if loader == 'full' and batch_size != 1: @@ -313,24 +308,161 @@ def dataloader(self, input_nodes: InputNodes, shuffle: bool) -> DataLoader: neighbor_sampler=self.neighbor_sampler, shuffle=shuffle, **self.kwargs) + raise NotImplementedError + + def train_dataloader(self) -> DataLoader: + """""" + return self.dataloader(self.input_train_nodes, shuffle=True) + + def val_dataloader(self) 
-> DataLoader: + """""" + return self.dataloader(self.input_val_nodes, shuffle=False) + + def test_dataloader(self) -> DataLoader: + """""" + return self.dataloader(self.input_test_nodes, shuffle=False) + + def __repr__(self) -> str: + kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs) + return f'{self.__class__.__name__}({kwargs})' + + +class LightningLinkData(LightningDataModule): + r"""Converts a :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` object into a + :class:`pytorch_lightning.LightningDataModule` variant, which can be + automatically used as a :obj:`datamodule` for multi-GPU link-level + training (such as for link prediction) via `PyTorch Lightning + `_. :class:`LightningDataset` will + take care of providing mini-batches via + :class:`~torch_geometric.loader.LinkNeighborLoader`. + + .. note:: + + Currently only the + :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and + :class:`pytorch_lightning.strategies.DDPSpawnStrategy` training + strategies of `PyTorch Lightning + `__ are supported in order to correctly share data across + all devices/processes: + + .. code-block:: + + import pytorch_lightning as pl + trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", + devices=4) + trainer.fit(model, datamodule) + + Args: + data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` graph object. + input_train_edges (torch.Tensor or str or (str, torch.Tensor)): The + training edges. (default: :obj:`None`) + input_val_edges (torch.Tensor or str or (str, torch.Tensor)): The + validation edges. (default: :obj:`None`) + input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The + test edges. (default: :obj:`None`) + loader (str): The scalability technique to use (:obj:`"full"`, + :obj:`"link_neighbor"`). (default: :obj:`"link_neighbor"`) + batch_size (int, optional): How many samples per batch to load. 
+ (default: :obj:`1`) + num_workers: How many subprocesses to use for data loading. + :obj:`0` means that the data will be loaded in the main process. + (default: :obj:`0`) + **kwargs (optional): Additional arguments of + :class:`torch_geometric.loader.LinkNeighborLoader`. + """ + def __init__( + self, + data: Union[Data, HeteroData], + input_train_edges: InputEdges = None, + input_val_edges: InputEdges = None, + input_test_edges: InputEdges = None, + loader: str = "link_neighbor", + batch_size: int = 1, + num_workers: int = 0, + **kwargs, + ): + + assert loader in ['full', 'link_neighbor'] + + if loader == 'full' and batch_size != 1: + warnings.warn(f"Re-setting 'batch_size' to 1 in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got '{batch_size}')") + batch_size = 1 + + if loader == 'full' and num_workers != 0: + warnings.warn(f"Re-setting 'num_workers' to 0 in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got '{num_workers}')") + num_workers = 0 + + super().__init__( + has_val=input_val_edges is not None, + has_test=input_test_edges is not None, + batch_size=batch_size, + num_workers=num_workers, + **kwargs, + ) + + if loader == 'full': + if kwargs.get('pin_memory', False): + warnings.warn(f"Re-setting 'pin_memory' to 'False' in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got 'True')") + self.kwargs['pin_memory'] = False + + self.data = data + self.loader = loader + + self.input_train_edges = input_train_edges + self.input_val_edges = input_val_edges + self.input_test_edges = input_test_edges + + def prepare_data(self): + """""" + if self.loader == 'full': + try: + num_devices = self.trainer.num_devices + except AttributeError: + # PyTorch Lightning < 1.6 backward compatibility: + num_devices = self.trainer.num_processes + num_devices = max(num_devices, self.trainer.num_gpus) + + if num_devices > 1: + raise ValueError( + f"'{self.__class__.__name__}' with loader='full' requires " + f"training on a single device") + 
super().prepare_data() + + def dataloader(self, input_edges: InputEdges, shuffle: bool) -> DataLoader: + if self.loader == 'full': + warnings.filterwarnings('ignore', '.*does not have many workers.*') + warnings.filterwarnings('ignore', '.*data loading bottlenecks.*') + return torch.utils.data.DataLoader([self.data], shuffle=False, + collate_fn=lambda xs: xs[0], + **self.kwargs) + if self.loader == 'link_neighbor': return LinkNeighborLoader(data=self.data, - edge_label_index=input_nodes, + edge_label_index=input_edges, shuffle=shuffle, **self.kwargs) raise NotImplementedError def train_dataloader(self) -> DataLoader: """""" - return self.dataloader(self.input_train_nodes, shuffle=True) + return self.dataloader(self.input_train_edges, shuffle=True) def val_dataloader(self) -> DataLoader: """""" - return self.dataloader(self.input_val_nodes, shuffle=False) + return self.dataloader(self.input_val_edges, shuffle=False) def test_dataloader(self) -> DataLoader: """""" - return self.dataloader(self.input_test_nodes, shuffle=False) + return self.dataloader(self.input_test_edges, shuffle=False) def __repr__(self) -> str: kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs) From d5b99de5cd1e26355c325586af8952d43068af5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Jun 2022 13:57:00 +0000 Subject: [PATCH 06/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/data/lightning_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 3a729a5ae657..c90e1439048a 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -332,8 +332,8 @@ class LightningLinkData(LightningDataModule): :class:`~torch_geometric.data.HeteroData` object into a 
:class:`pytorch_lightning.LightningDataModule` variant, which can be automatically used as a :obj:`datamodule` for multi-GPU link-level - training (such as for link prediction) via `PyTorch Lightning - `_. :class:`LightningDataset` will + training (such as for link prediction) via `PyTorch Lightning + `_. :class:`LightningDataset` will take care of providing mini-batches via :class:`~torch_geometric.loader.LinkNeighborLoader`. From ab86c1acf5c6e02ccaf847edc3cf52abd53d5955 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Mon, 27 Jun 2022 14:21:44 +0000 Subject: [PATCH 07/16] Fixes. --- torch_geometric/data/__init__.py | 21 ++++++++++++-------- torch_geometric/data/lightning_datamodule.py | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py index d1a05a542058..aa5e90589fc7 100644 --- a/torch_geometric/data/__init__.py +++ b/torch_geometric/data/__init__.py @@ -4,7 +4,11 @@ from .batch import Batch from .dataset import Dataset from .in_memory_dataset import InMemoryDataset -from .lightning_datamodule import LightningDataset, LightningNodeData +from .lightning_datamodule import ( + LightningDataset, + LightningLinkData, + LightningNodeData, +) from .makedirs import makedirs from .download import download_url from .extract import extract_tar, extract_zip, extract_bz2, extract_gz @@ -18,6 +22,7 @@ 'InMemoryDataset', 'LightningDataset', 'LightningNodeData', + 'LightningLinkData', 'makedirs', 'download_url', 'extract_tar', @@ -29,18 +34,18 @@ classes = __all__ from torch_geometric.deprecation import deprecated # noqa -from torch_geometric.loader import NeighborSampler # noqa from torch_geometric.loader import ClusterData # noqa from torch_geometric.loader import ClusterLoader # noqa -from torch_geometric.loader import GraphSAINTSampler # noqa -from torch_geometric.loader import GraphSAINTNodeSampler # noqa +from torch_geometric.loader import DataListLoader # noqa +from 
torch_geometric.loader import DataLoader # noqa +from torch_geometric.loader import DenseDataLoader # noqa from torch_geometric.loader import GraphSAINTEdgeSampler # noqa +from torch_geometric.loader import GraphSAINTNodeSampler # noqa from torch_geometric.loader import GraphSAINTRandomWalkSampler # noqa -from torch_geometric.loader import ShaDowKHopSampler # noqa +from torch_geometric.loader import GraphSAINTSampler # noqa +from torch_geometric.loader import NeighborSampler # noqa from torch_geometric.loader import RandomNodeSampler # noqa -from torch_geometric.loader import DataLoader # noqa -from torch_geometric.loader import DataListLoader # noqa -from torch_geometric.loader import DenseDataLoader # noqa +from torch_geometric.loader import ShaDowKHopSampler # noqa NeighborSampler = deprecated("use 'loader.NeighborSampler' instead", 'data.NeighborSampler')(NeighborSampler) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 3a729a5ae657..a50da6fd4d7f 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -228,7 +228,7 @@ def __init__( input_train_nodes = infer_input_nodes(data, split='train') if input_val_nodes is None: - input_val_nodes = kwargs.pop('edge_label_index_val', None) + input_val_nodes = infer_input_nodes(data, split='val') if input_val_nodes is None: input_val_nodes = infer_input_nodes(data, split='valid') From 8329fea12c83ce92ef8672e7e757c3d3fb0bdbe3 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Mon, 27 Jun 2022 14:54:33 +0000 Subject: [PATCH 08/16] Add changelog message and fix import order. --- CHANGELOG.md | 1 + torch_geometric/data/__init__.py | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18f00c8b36fd..e39612e29f14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
## [2.0.5] - 2022-MM-DD ### Added +- Added `LinkeNeighborLoader` support to lightning datamodule ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815)) - Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857)) - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py index aa5e90589fc7..d3d68baa18ab 100644 --- a/torch_geometric/data/__init__.py +++ b/torch_geometric/data/__init__.py @@ -34,18 +34,18 @@ classes = __all__ from torch_geometric.deprecation import deprecated # noqa +from torch_geometric.loader import NeighborSampler # noqa from torch_geometric.loader import ClusterData # noqa from torch_geometric.loader import ClusterLoader # noqa -from torch_geometric.loader import DataListLoader # noqa -from torch_geometric.loader import DataLoader # noqa -from torch_geometric.loader import DenseDataLoader # noqa -from torch_geometric.loader import GraphSAINTEdgeSampler # noqa +from torch_geometric.loader import GraphSAINTSampler # noqa from torch_geometric.loader import GraphSAINTNodeSampler # noqa +from torch_geometric.loader import GraphSAINTEdgeSampler # noqa from torch_geometric.loader import GraphSAINTRandomWalkSampler # noqa -from torch_geometric.loader import GraphSAINTSampler # noqa -from torch_geometric.loader import NeighborSampler # noqa -from torch_geometric.loader import RandomNodeSampler # noqa from torch_geometric.loader import ShaDowKHopSampler # noqa 
+from torch_geometric.loader import RandomNodeSampler # noqa +from torch_geometric.loader import DataLoader # noqa +from torch_geometric.loader import DataListLoader # noqa +from torch_geometric.loader import DenseDataLoader # noqa NeighborSampler = deprecated("use 'loader.NeighborSampler' instead", 'data.NeighborSampler')(NeighborSampler) From fd95b6b4cf931dc87c48552bf06358c6104daec3 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Mon, 27 Jun 2022 16:17:31 +0000 Subject: [PATCH 09/16] Add a link neighbor loader test. --- test/data/test_lightning_datamodule.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/test/data/test_lightning_datamodule.py b/test/data/test_lightning_datamodule.py index 21423538365f..7601aa0cd301 100644 --- a/test/data/test_lightning_datamodule.py +++ b/test/data/test_lightning_datamodule.py @@ -4,9 +4,14 @@ import torch import torch.nn.functional as F -from torch_geometric.data import LightningDataset, LightningNodeData +from torch_geometric.data import ( + LightningDataset, + LightningLinkData, + LightningNodeData, +) from torch_geometric.nn import global_mean_pool from torch_geometric.testing import onlyFullTest, withCUDA, withPackage +from torch_geometric.typing import EdgeType try: from pytorch_lightning import LightningModule @@ -264,3 +269,22 @@ def test_lightning_hetero_node_data(get_dataset): offset += 5 * devices * math.ceil(400 / (devices * 32)) # `train` offset += 5 * devices * math.ceil(400 / (devices * 32)) # `val` assert torch.all(new_x > (old_x + offset - 4)) # Ensure shared data. 
+ + +@withCUDA +@onlyFullTest +@withPackage('pytorch_lightning') +def test_lightning_hetero_link_data(get_dataset): + import pytorch_lightning as pl + + dataset = get_dataset(name='DBLP') + data = dataset[0] + datamodule = LightningLinkData(data, loader='link_neighbor', + num_neighbors=[5], batch_size=32, + num_workers=3) + input_edges = (('author', 'dummy', 'paper'), data['author', + 'paper']['edge_index']) + loader = datamodule.dataloader(input_edges=input_edges, shuffle=True) + batch = next(iter(loader)) + assert (batch['author', 'dummy', + 'paper']['edge_label_index'].shape[1] == 32) From 9b55dbd648c3da1bb31de767311f2349580e23c0 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Wed, 29 Jun 2022 10:55:13 -0400 Subject: [PATCH 10/16] Update torch_geometric/data/lightning_datamodule.py Co-authored-by: Jinu Sunil --- torch_geometric/data/lightning_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 0fbfa654f9b0..4a547d202964 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -361,7 +361,7 @@ class LightningLinkData(LightningDataModule): training edges. (default: :obj:`None`) input_val_edges (torch.Tensor or str or (str, torch.Tensor)): The validation edges. (default: :obj:`None`) - input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The + input_test_edges (torch.Tensor or str or (str, torch.Tensor)): The test edges. (default: :obj:`None`) loader (str): The scalability technique to use (:obj:`"full"`, :obj:`"link_neighbor"`). 
(default: :obj:`"link_neighbor"`) From 860b815489eb6acc71e3ce122b31f083a6918eae Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Wed, 29 Jun 2022 10:55:46 -0400 Subject: [PATCH 11/16] Update torch_geometric/data/lightning_datamodule.py Co-authored-by: Jinu Sunil --- torch_geometric/data/lightning_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 4a547d202964..d0fc9d194d0f 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -357,7 +357,7 @@ class LightningLinkData(LightningDataModule): Args: data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. - input_train_edges (torch.Tensor or str or (str, torch.Tensor)): The + input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The training edges. (default: :obj:`None`) input_val_edges (torch.Tensor or str or (str, torch.Tensor)): The validation edges. (default: :obj:`None`) From 9e68fefdaa7bcee0b2230f3bb14cba3edb883014 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Wed, 29 Jun 2022 15:07:10 +0000 Subject: [PATCH 12/16] Address review comments. --- test/data/test_lightning_datamodule.py | 1 + torch_geometric/data/lightning_datamodule.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/data/test_lightning_datamodule.py b/test/data/test_lightning_datamodule.py index 7601aa0cd301..6e3f7f3f4d0d 100644 --- a/test/data/test_lightning_datamodule.py +++ b/test/data/test_lightning_datamodule.py @@ -277,6 +277,7 @@ def test_lightning_hetero_node_data(get_dataset): def test_lightning_hetero_link_data(get_dataset): import pytorch_lightning as pl + # TODO: Add more datasets. 
dataset = get_dataset(name='DBLP') data = dataset[0] datamodule = LightningLinkData(data, loader='link_neighbor', diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index d0fc9d194d0f..106a90ad2b56 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -327,6 +327,7 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({kwargs})' +# TODO: Unify implementation with LightningNodeData via a common base class. class LightningLinkData(LightningDataModule): r"""Converts a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object into a @@ -357,12 +358,12 @@ class LightningLinkData(LightningDataModule): Args: data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. - input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The - training edges. (default: :obj:`None`) - input_val_edges (torch.Tensor or str or (str, torch.Tensor)): The - validation edges. (default: :obj:`None`) - input_test_edges (torch.Tensor or str or (str, torch.Tensor)): The - test edges. (default: :obj:`None`) + input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], + optional): The training edges. (default: :obj:`None`) + input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], + optional): The validation edges. (default: :obj:`None`) + input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], + optional): The test edges. (default: :obj:`None`) loader (str): The scalability technique to use (:obj:`"full"`, :obj:`"link_neighbor"`). (default: :obj:`"link_neighbor"`) batch_size (int, optional): How many samples per batch to load. @@ -386,7 +387,8 @@ def __init__( ): assert loader in ['full', 'link_neighbor'] - + # TODO: Handle or document behavior where none of train, val, test + # edges are specified. 
if loader == 'full' and batch_size != 1: warnings.warn(f"Re-setting 'batch_size' to 1 in " f"'{self.__class__.__name__}' for loader='full' " From c52b35efe9276bf25d5ca6a854fd23127f8bd27f Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Wed, 29 Jun 2022 15:26:23 +0000 Subject: [PATCH 13/16] Address review comments. --- test/data/test_lightning_datamodule.py | 3 ++- torch_geometric/data/lightning_datamodule.py | 24 ++++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/test/data/test_lightning_datamodule.py b/test/data/test_lightning_datamodule.py index 6e3f7f3f4d0d..efcec7d261f9 100644 --- a/test/data/test_lightning_datamodule.py +++ b/test/data/test_lightning_datamodule.py @@ -285,7 +285,8 @@ def test_lightning_hetero_link_data(get_dataset): num_workers=3) input_edges = (('author', 'dummy', 'paper'), data['author', 'paper']['edge_index']) - loader = datamodule.dataloader(input_edges=input_edges, shuffle=True) + loader = datamodule.dataloader(input_edges=input_edges, input_labels=None, + shuffle=True) batch = next(iter(loader)) assert (batch['author', 'dummy', 'paper']['edge_label_index'].shape[1] == 32) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 106a90ad2b56..6a693b262781 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -360,10 +360,13 @@ class LightningLinkData(LightningDataModule): :class:`~torch_geometric.data.HeteroData` graph object. input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): The training edges. (default: :obj:`None`) + input_train_edge_label (Tensor): The labels of train edge indices. input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): The validation edges. (default: :obj:`None`) + input_val_edge_label (Tensor): The labels of val edge indices. input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): The test edges. 
(default: :obj:`None`) + input_val_edge_label (Tensor): The labels of train edge indices. loader (str): The scalability technique to use (:obj:`"full"`, :obj:`"link_neighbor"`). (default: :obj:`"link_neighbor"`) batch_size (int, optional): How many samples per batch to load. @@ -378,8 +381,11 @@ def __init__( self, data: Union[Data, HeteroData], input_train_edges: InputEdges = None, + input_train_edge_label: torch.Tensor = None, input_val_edges: InputEdges = None, + input_val_edge_label: torch.Tensor = None, input_test_edges: InputEdges = None, + input_test_edge_label: torch.Tensor = None, loader: str = "link_neighbor", batch_size: int = 1, num_workers: int = 0, @@ -420,8 +426,11 @@ def __init__( self.loader = loader self.input_train_edges = input_train_edges + self.input_train_edge_label = input_train_edge_label self.input_val_edges = input_val_edges + self.input_val_edge_label = input_val_edge_label self.input_test_edges = input_test_edges + self.input_test_edge_label = input_test_edge_label def prepare_data(self): """""" @@ -439,7 +448,8 @@ def prepare_data(self): f"training on a single device") super().prepare_data() - def dataloader(self, input_edges: InputEdges, shuffle: bool) -> DataLoader: + def dataloader(self, input_edges: InputEdges, input_labels: torch.Tensor, + shuffle: bool) -> DataLoader: if self.loader == 'full': warnings.filterwarnings('ignore', '.*does not have many workers.*') warnings.filterwarnings('ignore', '.*data loading bottlenecks.*') @@ -450,21 +460,25 @@ def dataloader(self, input_edges: InputEdges, shuffle: bool) -> DataLoader: if self.loader == 'link_neighbor': return LinkNeighborLoader(data=self.data, edge_label_index=input_edges, - shuffle=shuffle, **self.kwargs) + edge_label=input_labels, shuffle=shuffle, + **self.kwargs) raise NotImplementedError def train_dataloader(self) -> DataLoader: """""" - return self.dataloader(self.input_train_edges, shuffle=True) + return self.dataloader(self.input_train_edges, + 
self.input_train_edge_label, shuffle=True) def val_dataloader(self) -> DataLoader: """""" - return self.dataloader(self.input_val_edges, shuffle=False) + return self.dataloader(self.input_val_edges, self.input_val_edge_label, + shuffle=False) def test_dataloader(self) -> DataLoader: """""" - return self.dataloader(self.input_test_edges, shuffle=False) + return self.dataloader(self.input_test_edges, + self.input_test_edge_label, shuffle=False) def __repr__(self) -> str: kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs) From 804a2fec5edf0218aca5da819dd8e39e601fa3e7 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Wed, 29 Jun 2022 13:21:51 -0400 Subject: [PATCH 14/16] Fix docstring issues. --- torch_geometric/data/lightning_datamodule.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index e764ec6abc04..c0993a261b2b 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -383,15 +383,15 @@ class LightningLinkData(LightningDataModule): Args: data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. - input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], - optional): The training edges. (default: :obj:`None`) - input_train_edge_label (Tensor): The labels of train edge indices. - input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], - optional): The validation edges. (default: :obj:`None`) - input_val_edge_label (Tensor): The labels of val edge indices. - input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], - optional): The test edges. (default: :obj:`None`) - input_val_edge_label (Tensor): The labels of train edge indices. + input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): + The training edges. 
(default: :obj:`None`) + input_train_edge_label (Tensor, optional): The labels of train edge indices. + input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): + The validation edges. (default: :obj:`None`) + input_val_edge_label (Tensor, optional): The labels of val edge indices. + input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): + The test edges. (default: :obj:`None`) + input_test_edge_label (Tensor, optional): The labels of train edge indices. loader (str): The scalability technique to use (:obj:`"full"`, :obj:`"link_neighbor"`). (default: :obj:`"link_neighbor"`) batch_size (int, optional): How many samples per batch to load. From 5d96a57cac6a6fb542844fbcc64892d7e4e28f80 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Wed, 29 Jun 2022 13:24:58 -0400 Subject: [PATCH 15/16] Remove unused imports. --- test/data/test_lightning_datamodule.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/data/test_lightning_datamodule.py b/test/data/test_lightning_datamodule.py index efcec7d261f9..db9f6b546b30 100644 --- a/test/data/test_lightning_datamodule.py +++ b/test/data/test_lightning_datamodule.py @@ -11,7 +11,6 @@ ) from torch_geometric.nn import global_mean_pool from torch_geometric.testing import onlyFullTest, withCUDA, withPackage -from torch_geometric.typing import EdgeType try: from pytorch_lightning import LightningModule @@ -275,8 +274,6 @@ def test_lightning_hetero_node_data(get_dataset): @onlyFullTest @withPackage('pytorch_lightning') def test_lightning_hetero_link_data(get_dataset): - import pytorch_lightning as pl - # TODO: Add more datasets. dataset = get_dataset(name='DBLP') data = dataset[0] From 3e7c7e2ed275fa70e07f82d55587ac87c95b2d40 Mon Sep 17 00:00:00 2001 From: Amitabha Roy Date: Wed, 29 Jun 2022 13:43:49 -0400 Subject: [PATCH 16/16] Fix long lines. 
--- torch_geometric/data/lightning_datamodule.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index c0993a261b2b..e2ac5039bd00 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -383,15 +383,18 @@ class LightningLinkData(LightningDataModule): Args: data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. - input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): + input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The training edges. (default: :obj:`None`) - input_train_edge_label (Tensor, optional): The labels of train edge indices. - input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): + input_train_edge_label (Tensor, optional): + The labels of train edge indices. + input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The validation edges. (default: :obj:`None`) - input_val_edge_label (Tensor, optional): The labels of val edge indices. - input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor], optional): + input_val_edge_label (Tensor, optional): + The labels of val edge indices. + input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The test edges. (default: :obj:`None`) - input_test_edge_label (Tensor, optional): The labels of train edge indices. + input_test_edge_label (Tensor, optional): + The labels of test edge indices. loader (str): The scalability technique to use (:obj:`"full"`, :obj:`"link_neighbor"`). (default: :obj:`"link_neighbor"`) batch_size (int, optional): How many samples per batch to load.