forked from pyg-team/pytorch_geometric
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve inference loop on GPU devices (pyg-team#7896)
In layer-wise inference loop, we perform computations as shown on the following pseudocode: ``` for layer in layers for batch in loader do inference per layer ``` In models that have more than one layer, we can benefit from caching batches during the first walk through the data. This PR introduces `CachedLoader`, which transfers batches to a pointed device and caches them. Additionally, an auxiliary function was provided, `make_batches_cacheable`, which decorates `BasicGNN` instance with a custom inference loop. Selected performance results (gained on Intel PVC): ``` Speedup: gcn[2L]+Reddit: 1.53x gcn[3L]+Reddit: 1.69x sage[2L]+Reddit: 1.55x sage[3L]+Reddit: 2.02x gcn[2L]+ogbn-products: 1.72x gcn[3L]+ogbn-products: 2.11x sage[2L]+ogbn-products: 1.83x sage[3L]+ogbn-products: 2.44x ``` Caching mechanism did not have a significant impact on models with a single layer. Drawbacks: - User should be aware that caching mechanism requires additional device memory to be allocated. In experiments, approximately 1GB was needed for the `Reddit` dataset. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
- Loading branch information
1 parent
87df3fc
commit 48e8a9d
Showing
6 changed files
with
195 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
from torch_geometric.data import Data | ||
from torch_geometric.loader import CachedLoader, NeighborLoader | ||
from torch_geometric.testing import withCUDA, withPackage | ||
|
||
|
||
@withCUDA
@withPackage('pyg_lib')
def test_cached_loader(device):
    # Small toy graph: 14 nodes with 16 features each and 10 directed edges.
    x = torch.randn(14, 16)
    edge_index = torch.tensor([
        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],
        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],
    ])

    loader = NeighborLoader(
        Data(x=x, edge_index=edge_index),
        num_neighbors=[2],
        batch_size=10,
        shuffle=False,
    )
    cached_loader = CachedLoader(loader, device=device)

    # The cached loader mirrors the wrapped loader's length and starts
    # with an empty cache:
    assert len(cached_loader) == len(loader)
    assert len(cached_loader._cache) == 0

    # First pass: each batch is transferred to `device` and cached.
    seen = []
    for step, batch in enumerate(cached_loader):
        assert len(cached_loader._cache) == step + 1
        assert batch.x.device == device
        assert batch.edge_index.device == device

        seen.append(batch)

    # Second pass: batches are replayed from the cache unchanged.
    for step, batch in enumerate(cached_loader):
        assert batch == seen[step]

    cached_loader.clear()
    assert len(cached_loader._cache) == 0
|
||
|
||
@withCUDA
@withPackage('pyg_lib')
def test_cached_loader_transform(device):
    # Same toy graph as in `test_cached_loader`:
    x = torch.randn(14, 16)
    edge_index = torch.tensor([
        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],
        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],
    ])

    loader = NeighborLoader(
        Data(x=x, edge_index=edge_index),
        num_neighbors=[2],
        batch_size=10,
        shuffle=False,
    )
    # The transform reduces each mini-batch to its `edge_index` tensor:
    cached_loader = CachedLoader(
        loader,
        device=device,
        transform=lambda batch: batch.edge_index,
    )

    assert len(cached_loader) == len(loader)
    assert len(cached_loader._cache) == 0

    # First pass: transformed outputs (plain tensors) are cached on `device`.
    seen = []
    for step, batch in enumerate(cached_loader):
        assert len(cached_loader._cache) == step + 1
        assert isinstance(batch, Tensor)
        assert batch.dim() == 2 and batch.size(0) == 2
        assert batch.device == device

        seen.append(batch)

    # Second pass: cached tensors are returned as-is.
    for step, batch in enumerate(cached_loader):
        assert torch.equal(batch, seen[step])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from collections.abc import Mapping | ||
from typing import Any, Callable, List, Optional, Sequence | ||
|
||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
def to_device(inputs: Any, device: Optional[torch.device] = None) -> Any:
    r"""Recursively transfers :obj:`inputs` to the given device.

    Supports any object exposing a :meth:`to` method (*e.g.*, tensors or
    :class:`~torch_geometric.data.Data` objects), mappings, named tuples,
    and generic (non-string) sequences.  Any other value is returned
    unchanged.

    Args:
        inputs (Any): The (possibly nested) input to transfer.
        device (torch.device, optional): The destination device.
            (default: :obj:`None`)
    """
    if hasattr(inputs, 'to'):  # Tensors, `Data`/`Batch` objects, etc.
        return inputs.to(device)
    elif isinstance(inputs, Mapping):
        return {key: to_device(value, device) for key, value in inputs.items()}
    elif isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
        # Named tuple: rebuild the same type with each field transferred.
        # NOTE: Iterate over the fields directly - the previous
        # `zip(*inputs)` transposed the field contents and broke arity.
        return type(inputs)(*(to_device(s, device) for s in inputs))
    elif isinstance(inputs, Sequence) and not isinstance(inputs, str):
        # Generic sequences are returned as lists of transferred elements.
        # NOTE: `zip(*inputs)` here raised `TypeError` for sequences of
        # non-iterables (e.g., `[1, 2, 3]`); iterate the elements directly.
        return [to_device(s, device) for s in inputs]

    return inputs
|
||
|
||
class CachedLoader:
    r"""A loader to cache mini-batch outputs, e.g., obtained during
    :class:`NeighborLoader` iterations.

    Args:
        loader (torch.utils.data.DataLoader): The data loader.
        device (torch.device, optional): The device to load the data to.
            (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        loader: DataLoader,
        device: Optional[torch.device] = None,
        transform: Optional[Callable] = None,
    ):
        self.loader = loader
        self.device = device
        self.transform = transform

        # Holds the transformed, device-resident batches after the first
        # full pass over `loader`:
        self._cache: List[Any] = []

    def clear(self):
        r"""Clears the cache."""
        self._cache = []

    def __iter__(self) -> Any:
        # Fast path: replay batches cached during a previous epoch.
        if self._cache:
            yield from self._cache
            return

        # First pass: transform, transfer, cache, and stream each batch.
        for batch in self.loader:
            if self.transform is not None:
                batch = self.transform(batch)

            batch = to_device(batch, self.device)

            self._cache.append(batch)
            yield batch

    def __len__(self) -> int:
        return len(self.loader)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.loader})'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters