Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LinkNeighborLoader to Pytorch Lightning datamodule #4868

Merged
merged 20 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `LinkeNeighborLoader` support to lightning datamodule ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
Expand Down
26 changes: 25 additions & 1 deletion test/data/test_lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
import torch
import torch.nn.functional as F

from torch_geometric.data import LightningDataset, LightningNodeData
from torch_geometric.data import (
LightningDataset,
LightningLinkData,
LightningNodeData,
)
from torch_geometric.nn import global_mean_pool
from torch_geometric.testing import onlyFullTest, withCUDA, withPackage
from torch_geometric.typing import EdgeType

try:
from pytorch_lightning import LightningModule
Expand Down Expand Up @@ -264,3 +269,22 @@ def test_lightning_hetero_node_data(get_dataset):
offset += 5 * devices * math.ceil(400 / (devices * 32)) # `train`
offset += 5 * devices * math.ceil(400 / (devices * 32)) # `val`
assert torch.all(new_x > (old_x + offset - 4)) # Ensure shared data.


@withCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
def test_lightning_hetero_link_data(get_dataset):
ar104 marked this conversation as resolved.
Show resolved Hide resolved
import pytorch_lightning as pl

dataset = get_dataset(name='DBLP')
data = dataset[0]
datamodule = LightningLinkData(data, loader='link_neighbor',
num_neighbors=[5], batch_size=32,
num_workers=3)
input_edges = (('author', 'dummy', 'paper'), data['author',
'paper']['edge_index'])
loader = datamodule.dataloader(input_edges=input_edges, shuffle=True)
batch = next(iter(loader))
assert (batch['author', 'dummy',
'paper']['edge_label_index'].shape[1] == 32)
7 changes: 6 additions & 1 deletion torch_geometric/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from .batch import Batch
from .dataset import Dataset
from .in_memory_dataset import InMemoryDataset
from .lightning_datamodule import LightningDataset, LightningNodeData
from .lightning_datamodule import (
LightningDataset,
LightningLinkData,
LightningNodeData,
)
from .makedirs import makedirs
from .download import download_url
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
Expand All @@ -18,6 +22,7 @@
'InMemoryDataset',
'LightningDataset',
'LightningNodeData',
'LightningLinkData',
'makedirs',
'download_url',
'extract_tar',
Expand Down
155 changes: 148 additions & 7 deletions torch_geometric/data/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import torch

from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.loader.neighbor_loader import (
NeighborLoader,
NeighborSampler,
get_input_nodes,
)
from torch_geometric.typing import InputNodes
from torch_geometric.typing import InputEdges, InputNodes

try:
from pytorch_lightning import LightningDataModule as PLLightningDataModule
Expand Down Expand Up @@ -194,9 +195,8 @@ class LightningNodeData(LightningDataModule):
indices of training nodes. If not given, will try to automatically
infer them from the :obj:`data` object. (default: :obj:`None`)
input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The
indices of validation nodes. If not given, will try to
automatically infer them from the :obj:`data` object.
(default: :obj:`None`)
indices of validation nodes. If not given will try to automatically
infer them from the :obj:`data` object. (default: :obj:`None`)
input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The
indices of test nodes. If not given, will try to automatically
infer them from the :obj:`data` object. (default: :obj:`None`)
Expand Down Expand Up @@ -229,9 +229,8 @@ def __init__(

if input_val_nodes is None:
input_val_nodes = infer_input_nodes(data, split='val')

if input_val_nodes is None:
input_val_nodes = infer_input_nodes(data, split='valid')
if input_val_nodes is None:
input_val_nodes = infer_input_nodes(data, split='valid')

if input_test_nodes is None:
input_test_nodes = infer_input_nodes(data, split='test')
Expand Down Expand Up @@ -328,6 +327,148 @@ def __repr__(self) -> str:
return f'{self.__class__.__name__}({kwargs})'


class LightningLinkData(LightningDataModule):
r"""Converts a :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` object into a
:class:`pytorch_lightning.LightningDataModule` variant, which can be
automatically used as a :obj:`datamodule` for multi-GPU link-level
training (such as for link prediction) via `PyTorch Lightning
<https://www.pytorchlightning.ai>`_. :class:`LightningDataset` will
take care of providing mini-batches via
:class:`~torch_geometric.loader.LinkNeighborLoader`.

.. note::

Currently only the
:class:`pytorch_lightning.strategies.SingleDeviceStrategy` and
:class:`pytorch_lightning.strategies.DDPSpawnStrategy` training
strategies of `PyTorch Lightning
<https://pytorch-lightning.readthedocs.io/en/latest/guides/
speed.html>`__ are supported in order to correctly share data across
all devices/processes:

.. code-block::

import pytorch_lightning as pl
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
devices=4)
trainer.fit(model, datamodule)

Args:
data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` graph object.
input_train_edges (torch.Tensor or str or (str, torch.Tensor)): The
ar104 marked this conversation as resolved.
Show resolved Hide resolved
ar104 marked this conversation as resolved.
Show resolved Hide resolved
ar104 marked this conversation as resolved.
Show resolved Hide resolved
training edges. (default: :obj:`None`)
input_val_edges (torch.Tensor or str or (str, torch.Tensor)): The
validation edges. (default: :obj:`None`)
input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The
ar104 marked this conversation as resolved.
Show resolved Hide resolved
test edges. (default: :obj:`None`)
loader (str): The scalability technique to use (:obj:`"full"`,
:obj:`"link_neighbor"`). (default: :obj:`"link_neighbor"`)
batch_size (int, optional): How many samples per batch to load.
(default: :obj:`1`)
num_workers: How many subprocesses to use for data loading.
:obj:`0` means that the data will be loaded in the main process.
(default: :obj:`0`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.loader.LinkNeighborLoader`.
"""
def __init__(
self,
data: Union[Data, HeteroData],
input_train_edges: InputEdges = None,
input_val_edges: InputEdges = None,
input_test_edges: InputEdges = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear what would happen when train/test/val edges are None. If input_val_edges or input_test_edges are None then all edges would be considered.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. I would be okay to make them required for now, and add a TODO to revisit this later on.

loader: str = "link_neighbor",
batch_size: int = 1,
num_workers: int = 0,
**kwargs,
):

assert loader in ['full', 'link_neighbor']
ar104 marked this conversation as resolved.
Show resolved Hide resolved

if loader == 'full' and batch_size != 1:
warnings.warn(f"Re-setting 'batch_size' to 1 in "
f"'{self.__class__.__name__}' for loader='full' "
f"(got '{batch_size}')")
batch_size = 1

if loader == 'full' and num_workers != 0:
warnings.warn(f"Re-setting 'num_workers' to 0 in "
f"'{self.__class__.__name__}' for loader='full' "
f"(got '{num_workers}')")
num_workers = 0

super().__init__(
has_val=input_val_edges is not None,
has_test=input_test_edges is not None,
batch_size=batch_size,
num_workers=num_workers,
**kwargs,
)

if loader == 'full':
if kwargs.get('pin_memory', False):
warnings.warn(f"Re-setting 'pin_memory' to 'False' in "
f"'{self.__class__.__name__}' for loader='full' "
f"(got 'True')")
self.kwargs['pin_memory'] = False

self.data = data
self.loader = loader

self.input_train_edges = input_train_edges
self.input_val_edges = input_val_edges
self.input_test_edges = input_test_edges

def prepare_data(self):
""""""
if self.loader == 'full':
try:
num_devices = self.trainer.num_devices
except AttributeError:
# PyTorch Lightning < 1.6 backward compatibility:
num_devices = self.trainer.num_processes
num_devices = max(num_devices, self.trainer.num_gpus)

if num_devices > 1:
raise ValueError(
f"'{self.__class__.__name__}' with loader='full' requires "
f"training on a single device")
super().prepare_data()

def dataloader(self, input_edges: InputEdges, shuffle: bool) -> DataLoader:
if self.loader == 'full':
warnings.filterwarnings('ignore', '.*does not have many workers.*')
warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')
return torch.utils.data.DataLoader([self.data], shuffle=False,
collate_fn=lambda xs: xs[0],
**self.kwargs)

if self.loader == 'link_neighbor':
return LinkNeighborLoader(data=self.data,
edge_label_index=input_edges,
shuffle=shuffle, **self.kwargs)

raise NotImplementedError

def train_dataloader(self) -> DataLoader:
""""""
return self.dataloader(self.input_train_edges, shuffle=True)

def val_dataloader(self) -> DataLoader:
""""""
return self.dataloader(self.input_val_edges, shuffle=False)

def test_dataloader(self) -> DataLoader:
""""""
return self.dataloader(self.input_test_edges, shuffle=False)

def __repr__(self) -> str:
kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
return f'{self.__class__.__name__}({kwargs})'


###############################################################################


Expand Down
7 changes: 3 additions & 4 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __call__(self, query: List[Tuple[Tensor]]):
self.directed,
)

return node, row, col, edge, edge_label_index, edge_label
return (node, row, col, edge, edge_label_index, edge_label)

elif issubclass(self.data_cls, HeteroData):
sample_fn = torch.ops.torch_sparse.hetero_neighbor_sample
Expand Down Expand Up @@ -132,7 +132,7 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
.. code-block:: python

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
from torch_geometric.loader import LinkNeighborLoader

data = Planetoid(path, name='Cora')[0]

Expand Down Expand Up @@ -280,7 +280,7 @@ def __init__(
def transform_fn(self, out: Any) -> Union[Data, HeteroData]:
# NOTE This function will always be executed on the main thread!
if isinstance(self.data, Data):
node, row, col, edge, edge_label_index, edge_label = out
(node, row, col, edge, edge_label_index, edge_label) = out
data = filter_data(self.data, node, row, col, edge,
self.neighbor_sampler.perm)
data.edge_label_index = edge_label_index
Expand Down Expand Up @@ -349,7 +349,6 @@ def get_edge_label_index(

edge_type, edge_label_index = edge_label_index
edge_type = data._to_canonical(*edge_type)
assert edge_type in data.edge_types

if edge_label_index is None:
return edge_type, data[edge_type].edge_index
Expand Down