diff --git a/CHANGELOG.md b/CHANGELOG.md index b270e7ad05ad..54569d998a03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added `LinkNeighborLoader` support to lightning datamodule ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868)) - Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884)) - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) diff --git a/test/data/test_lightning_datamodule.py b/test/data/test_lightning_datamodule.py index 21423538365f..db9f6b546b30 100644 --- a/test/data/test_lightning_datamodule.py +++ b/test/data/test_lightning_datamodule.py @@ -4,7 +4,11 @@ import torch import torch.nn.functional as F -from torch_geometric.data import LightningDataset, LightningNodeData +from torch_geometric.data import ( + LightningDataset, + LightningLinkData, + LightningNodeData, +) from torch_geometric.nn import global_mean_pool from torch_geometric.testing import onlyFullTest, withCUDA, withPackage @@ -264,3 +268,22 @@ def test_lightning_hetero_node_data(get_dataset): offset += 5 * devices * math.ceil(400 / (devices * 32)) # `train` offset += 5 * devices * math.ceil(400 / (devices * 32)) # `val` assert torch.all(new_x > (old_x + offset - 4)) # Ensure shared data. + + +@withCUDA +@onlyFullTest +@withPackage('pytorch_lightning') +def test_lightning_hetero_link_data(get_dataset): + # TODO: Add more datasets.
+ dataset = get_dataset(name='DBLP') + data = dataset[0] + datamodule = LightningLinkData(data, loader='link_neighbor', + num_neighbors=[5], batch_size=32, + num_workers=3) + input_edges = (('author', 'dummy', 'paper'), data['author', + 'paper']['edge_index']) + loader = datamodule.dataloader(input_edges=input_edges, input_labels=None, + shuffle=True) + batch = next(iter(loader)) + assert (batch['author', 'dummy', + 'paper']['edge_label_index'].shape[1] == 32) diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py index d1a05a542058..d3d68baa18ab 100644 --- a/torch_geometric/data/__init__.py +++ b/torch_geometric/data/__init__.py @@ -4,7 +4,11 @@ from .batch import Batch from .dataset import Dataset from .in_memory_dataset import InMemoryDataset -from .lightning_datamodule import LightningDataset, LightningNodeData +from .lightning_datamodule import ( + LightningDataset, + LightningLinkData, + LightningNodeData, +) from .makedirs import makedirs from .download import download_url from .extract import extract_tar, extract_zip, extract_bz2, extract_gz @@ -18,6 +22,7 @@ 'InMemoryDataset', 'LightningDataset', 'LightningNodeData', + 'LightningLinkData', 'makedirs', 'download_url', 'extract_tar', diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 2c9f68ab018b..e2ac5039bd00 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -4,13 +4,14 @@ import torch from torch_geometric.data import Data, Dataset, HeteroData +from torch_geometric.loader import LinkNeighborLoader from torch_geometric.loader.dataloader import DataLoader from torch_geometric.loader.neighbor_loader import ( NeighborLoader, NeighborSampler, get_input_nodes, ) -from torch_geometric.typing import InputNodes +from torch_geometric.typing import InputEdges, InputNodes try: from pytorch_lightning import LightningDataModule as PLLightningDataModule @@ -245,9 +246,8 @@ def 
__init__( if input_val_nodes is None: input_val_nodes = infer_input_nodes(data, split='val') - - if input_val_nodes is None: - input_val_nodes = infer_input_nodes(data, split='valid') + if input_val_nodes is None: + input_val_nodes = infer_input_nodes(data, split='valid') if input_test_nodes is None: input_test_nodes = infer_input_nodes(data, split='test') @@ -352,6 +352,167 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({kwargs})' +# TODO: Unify implementation with LightningNodeData via a common base class. +class LightningLinkData(LightningDataModule): + r"""Converts a :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` object into a + :class:`pytorch_lightning.LightningDataModule` variant, which can be + automatically used as a :obj:`datamodule` for multi-GPU link-level + training (such as for link prediction) via `PyTorch Lightning + <https://www.pytorchlightning.ai>`_. :class:`LightningLinkData` will + take care of providing mini-batches via + :class:`~torch_geometric.loader.LinkNeighborLoader`. + + .. note:: + + Currently only the + :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and + :class:`pytorch_lightning.strategies.DDPSpawnStrategy` training + strategies of `PyTorch Lightning + <https://pytorch-lightning.readthedocs.io/en/latest/guides/ + speed.html>`__ are supported in order to correctly share data across + all devices/processes: + + .. code-block:: + + import pytorch_lightning as pl + trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", + devices=4) + trainer.fit(model, datamodule) + + Args: + data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` graph object. + input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): + The training edges. (default: :obj:`None`) + input_train_edge_label (Tensor, optional): + The labels of train edge indices. + input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): + The validation edges.
(default: :obj:`None`) + input_val_edge_label (Tensor, optional): + The labels of val edge indices. + input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]): + The test edges. (default: :obj:`None`) + input_test_edge_label (Tensor, optional): + The labels of test edge indices. + loader (str): The scalability technique to use (:obj:`"full"`, + :obj:`"link_neighbor"`). (default: :obj:`"link_neighbor"`) + batch_size (int, optional): How many samples per batch to load. + (default: :obj:`1`) + num_workers: How many subprocesses to use for data loading. + :obj:`0` means that the data will be loaded in the main process. + (default: :obj:`0`) + **kwargs (optional): Additional arguments of + :class:`torch_geometric.loader.LinkNeighborLoader`. + """ + def __init__( + self, + data: Union[Data, HeteroData], + input_train_edges: InputEdges = None, + input_train_edge_label: torch.Tensor = None, + input_val_edges: InputEdges = None, + input_val_edge_label: torch.Tensor = None, + input_test_edges: InputEdges = None, + input_test_edge_label: torch.Tensor = None, + loader: str = "link_neighbor", + batch_size: int = 1, + num_workers: int = 0, + **kwargs, + ): + + assert loader in ['full', 'link_neighbor'] + # TODO: Handle or document behavior where none of train, val, test + # edges are specified.
+ if loader == 'full' and batch_size != 1: + warnings.warn(f"Re-setting 'batch_size' to 1 in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got '{batch_size}')") + batch_size = 1 + + if loader == 'full' and num_workers != 0: + warnings.warn(f"Re-setting 'num_workers' to 0 in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got '{num_workers}')") + num_workers = 0 + + super().__init__( + has_val=input_val_edges is not None, + has_test=input_test_edges is not None, + batch_size=batch_size, + num_workers=num_workers, + **kwargs, + ) + + if loader == 'full': + if kwargs.get('pin_memory', False): + warnings.warn(f"Re-setting 'pin_memory' to 'False' in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got 'True')") + self.kwargs['pin_memory'] = False + + self.data = data + self.loader = loader + + self.input_train_edges = input_train_edges + self.input_train_edge_label = input_train_edge_label + self.input_val_edges = input_val_edges + self.input_val_edge_label = input_val_edge_label + self.input_test_edges = input_test_edges + self.input_test_edge_label = input_test_edge_label + + def prepare_data(self): + """""" + if self.loader == 'full': + try: + num_devices = self.trainer.num_devices + except AttributeError: + # PyTorch Lightning < 1.6 backward compatibility: + num_devices = self.trainer.num_processes + num_devices = max(num_devices, self.trainer.num_gpus) + + if num_devices > 1: + raise ValueError( + f"'{self.__class__.__name__}' with loader='full' requires " + f"training on a single device") + super().prepare_data() + + def dataloader(self, input_edges: InputEdges, input_labels: torch.Tensor, + shuffle: bool) -> DataLoader: + if self.loader == 'full': + warnings.filterwarnings('ignore', '.*does not have many workers.*') + warnings.filterwarnings('ignore', '.*data loading bottlenecks.*') + return torch.utils.data.DataLoader([self.data], shuffle=False, + collate_fn=lambda xs: xs[0], + **self.kwargs) + + if self.loader == 
'link_neighbor': + return LinkNeighborLoader(data=self.data, + edge_label_index=input_edges, + edge_label=input_labels, shuffle=shuffle, + **self.kwargs) + + raise NotImplementedError + + def train_dataloader(self) -> DataLoader: + """""" + return self.dataloader(self.input_train_edges, + self.input_train_edge_label, shuffle=True) + + def val_dataloader(self) -> DataLoader: + """""" + return self.dataloader(self.input_val_edges, self.input_val_edge_label, + shuffle=False) + + def test_dataloader(self) -> DataLoader: + """""" + return self.dataloader(self.input_test_edges, + self.input_test_edge_label, shuffle=False) + + def __repr__(self) -> str: + kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs) + return f'{self.__class__.__name__}({kwargs})' + + ############################################################################### diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 98d527d6e475..b3d4cc70bd63 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -108,7 +108,7 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): .. 
code-block:: python from torch_geometric.datasets import Planetoid - from torch_geometric.loader import NeighborLoader + from torch_geometric.loader import LinkNeighborLoader data = Planetoid(path, name='Cora')[0] @@ -276,7 +276,7 @@ def __init__( def filter_fn(self, out: Any) -> Union[Data, HeteroData]: if isinstance(self.data, Data): - node, row, col, edge, edge_label_index, edge_label = out + (node, row, col, edge, edge_label_index, edge_label) = out data = filter_data(self.data, node, row, col, edge, self.neighbor_sampler.perm) data.edge_label_index = edge_label_index @@ -355,7 +355,6 @@ def get_edge_label_index( edge_type, edge_label_index = edge_label_index edge_type = data._to_canonical(*edge_type) - assert edge_type in data.edge_types if edge_label_index is None: return edge_type, data[edge_type].edge_index