Skip to content

Commit

Permalink
Improve inference loop on GPU devices (pyg-team#7896)
Browse files Browse the repository at this point in the history
In layer-wise inference loop, we perform computations as shown on the
following pseudocode:
```
for layer in layers
    for batch in loader
        do inference per layer
```
In models that have more than one layer, we can benefit from caching
batches during the
first walk through the data. This PR introduces `CachedLoader`, which
transfers
batches to a pointed device and caches them. Additionally, an auxiliary
function was provided,
`make_batches_cacheable`, which decorates `BasicGNN` instance with a
custom inference
loop.

Selected performance results (gained on Intel PVC):
```
Speedup:
gcn[2L]+Reddit: 1.53x
gcn[3L]+Reddit: 1.69x
sage[2L]+Reddit: 1.55x
sage[3L]+Reddit: 2.02x
gcn[2L]+ogbn-products: 1.72x
gcn[3L]+ogbn-products: 2.11x
sage[2L]+ogbn-products: 1.83x
sage[3L]+ogbn-products: 2.44x
```
Caching mechanism did not have a significant impact on models with a
single layer.

Drawbacks:
- User should be aware that caching mechanism requires additional device
memory to be allocated.
In experiments, approximately 1GB was needed for the `Reddit` dataset.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored and erfanloghmani committed Aug 31, 2023
1 parent 87df3fc commit 48e8a9d
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `MyketDataset` ([#7959](https://github.com/pyg-team/pytorch_geometric/pull/7959))
- Added `CachedLoader` implementation ([#7896](https://github.com/pyg-team/pytorch_geometric/pull/7896))
- Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925))
- Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917))
- Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918))
Expand Down
78 changes: 78 additions & 0 deletions test/loader/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.loader import CachedLoader, NeighborLoader
from torch_geometric.testing import withCUDA, withPackage


@withCUDA
@withPackage('pyg_lib')
def test_cached_loader(device):
x = torch.randn(14, 16)
edge_index = torch.tensor([
[2, 3, 4, 5, 7, 7, 10, 11, 12, 13],
[0, 1, 2, 3, 2, 3, 7, 7, 7, 7],
])

loader = NeighborLoader(
Data(x=x, edge_index=edge_index),
num_neighbors=[2],
batch_size=10,
shuffle=False,
)
cached_loader = CachedLoader(loader, device=device)

assert len(cached_loader) == len(loader)
assert len(cached_loader._cache) == 0

cache = []
for i, batch in enumerate(cached_loader):
assert len(cached_loader._cache) == i + 1
assert batch.x.device == device
assert batch.edge_index.device == device

cache.append(batch)

for i, batch in enumerate(cached_loader):
assert batch == cache[i]

cached_loader.clear()
assert len(cached_loader._cache) == 0


@withCUDA
@withPackage('pyg_lib')
def test_cached_loader_transform(device):
x = torch.randn(14, 16)
edge_index = torch.tensor([
[2, 3, 4, 5, 7, 7, 10, 11, 12, 13],
[0, 1, 2, 3, 2, 3, 7, 7, 7, 7],
])

loader = NeighborLoader(
Data(x=x, edge_index=edge_index),
num_neighbors=[2],
batch_size=10,
shuffle=False,
)
cached_loader = CachedLoader(
loader,
device=device,
transform=lambda batch: batch.edge_index,
)

assert len(cached_loader) == len(loader)
assert len(cached_loader._cache) == 0

cache = []
for i, batch in enumerate(cached_loader):
assert len(cached_loader._cache) == i + 1
assert isinstance(batch, Tensor)
assert batch.dim() == 2 and batch.size(0) == 2
assert batch.device == device

cache.append(batch)

for i, batch in enumerate(cached_loader):
assert torch.equal(batch, cache[i])
23 changes: 23 additions & 0 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,29 @@ def my_custom_backend(gm, *args):
assert num_compile_calls - num_previous_compile_calls == 1


@withPackage('pyg_lib')
def test_basic_gnn_cache():
x = torch.randn(14, 16)
edge_index = torch.tensor([
[2, 3, 4, 5, 7, 7, 10, 11, 12, 13],
[0, 1, 2, 3, 2, 3, 7, 7, 7, 7],
])

loader = NeighborLoader(
Data(x=x, edge_index=edge_index),
num_neighbors=[-1],
batch_size=2,
)

model = GCN(in_channels=16, hidden_channels=16, num_layers=2)
model.eval()

out1 = model.inference(loader, cache=False)
out2 = model.inference(loader, cache=True)

assert torch.allclose(out1, out2)


if __name__ == '__main__':
import argparse

Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .imbalanced_sampler import ImbalancedSampler
from .dynamic_batch_sampler import DynamicBatchSampler
from .prefetch import PrefetchLoader
from .cache import CachedLoader
from .mixin import AffinityMixin

__all__ = classes = [
Expand All @@ -44,6 +45,7 @@
'ImbalancedSampler',
'DynamicBatchSampler',
'PrefetchLoader',
'CachedLoader',
'AffinityMixin',
]

Expand Down
70 changes: 70 additions & 0 deletions torch_geometric/loader/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from collections.abc import Mapping
from typing import Any, Callable, List, Optional, Sequence

import torch
from torch.utils.data import DataLoader


def to_device(inputs: Any, device: Optional[torch.device] = None) -> Any:
if hasattr(inputs, 'to'):
return inputs.to(device)
elif isinstance(inputs, Mapping):
return {key: to_device(value, device) for key, value in inputs.items()}
elif isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
return type(inputs)(*(to_device(s, device) for s in zip(*inputs)))
elif isinstance(inputs, Sequence) and not isinstance(inputs, str):
return [to_device(s, device) for s in zip(*inputs)]

return inputs


class CachedLoader:
r"""A loader to cache mini-batch outputs, e.g., obtained during
:class:`NeighborLoader` iterations.
Args:
loader (torch.utils.data.DataLoader): The data loader.
device (torch.device, optional): The device to load the data to.
(default: :obj:`None`)
transform (callable, optional): A function/transform that takes in
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
"""
def __init__(
self,
loader: DataLoader,
device: Optional[torch.device] = None,
transform: Optional[Callable] = None,
):
self.loader = loader
self.device = device
self.transform = transform

self._cache: List[Any] = []

def clear(self):
r"""Clears the cache."""
self._cache = []

def __iter__(self) -> Any:
if len(self._cache):
for batch in self._cache:
yield batch
return

for batch in self.loader:

if self.transform is not None:
batch = self.transform(batch)

batch = to_device(batch, self.device)

self._cache.append(batch)

yield batch

def __len__(self) -> int:
return len(self.loader)

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.loader})'
22 changes: 21 additions & 1 deletion torch_geometric/nn/models/basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from torch.nn import Linear, ModuleList
from tqdm import tqdm

from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Data
from torch_geometric.loader import CachedLoader, NeighborLoader
from torch_geometric.nn.conv import (
EdgeConv,
GATConv,
Expand Down Expand Up @@ -303,6 +304,7 @@ def inference(
device: Optional[Union[str, torch.device]] = None,
embedding_device: Union[str, torch.device] = 'cpu',
progress_bar: bool = False,
cache: bool = False,
) -> Tensor:
r"""Performs layer-wise inference on large-graphs using a
:class:`~torch_geometric.loader.NeighborLoader`, where
Expand All @@ -324,6 +326,10 @@ def inference(
(default: :obj:`"cpu"`)
progress_bar (bool, optional): If set to :obj:`True`, will print a
progress bar during computation. (default: :obj:`False`)
cache (bool, optional): If set to :obj:`True`, caches intermediate
sampler outputs for usage in later epochs.
This will avoid repeated sampling to accelerate inference.
(default: :obj:`False`)
"""
assert self.jk_mode is None or self.jk_mode == 'last'
assert isinstance(loader, NeighborLoader)
Expand All @@ -337,6 +343,20 @@ def inference(

x_all = loader.data.x.to(embedding_device)

if cache:

# Only cache necessary attributes:
def transform(data: Data) -> Data:
kwargs = dict(n_id=data.n_id, batch_size=data.batch_size)
if hasattr(data, 'adj_t'):
kwargs['adj_t'] = data.adj_t
else:
kwargs['edge_index'] = data.edge_index

return Data.from_dict(kwargs)

loader = CachedLoader(loader, device=device, transform=transform)

for i in range(self.num_layers):
xs: List[Tensor] = []
for batch in loader:
Expand Down

0 comments on commit 48e8a9d

Please sign in to comment.