Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.compile benchmark for HeteroConv; Allow device conversions of datasets #8402

Merged
merged 3 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for device conversions of `InMemoryDataset` ([#8402] (https://github.com/pyg-team/pytorch_geometric/pull/8402))
- Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372] (https://github.com/pyg-team/pytorch_geometric/pull/8372))
- Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))
- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357))
Expand Down
65 changes: 65 additions & 0 deletions test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import random

import pytest
import torch

import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.datasets import FakeHeteroDataset
from torch_geometric.nn import (
GATConv,
GCN2Conv,
Expand All @@ -12,6 +15,7 @@
MessagePassing,
SAGEConv,
)
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
get_random_edge_index,
Expand Down Expand Up @@ -205,3 +209,64 @@ def test_compile_hetero_conv_graph_breaks(device):
assert len(out) == len(expected)
for key in expected.keys():
assert torch.allclose(out[key], expected[key], atol=1e-6)


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--backward', action='store_true')
args = parser.parse_args()

dataset = FakeHeteroDataset(num_graphs=10).to(args.device)

def gen_args():
data = dataset[random.randrange(len(dataset))]
return data.x_dict, data.edge_index_dict

class HeteroGNN(torch.nn.Module):
def __init__(self, channels: int = 32, num_layers: int = 2):
super().__init__()
self.convs = torch.nn.ModuleList()

conv = HeteroConv({
edge_type:
SAGEConv(
in_channels=(
dataset.num_features[edge_type[0]],
dataset.num_features[edge_type[-1]],
),
out_channels=channels,
)
for edge_type in dataset[0].edge_types
})
self.convs.append(conv)

for _ in range(num_layers - 1):
conv = HeteroConv({
edge_type:
SAGEConv((channels, channels), channels)
for edge_type in dataset[0].edge_types
})
self.convs.append(conv)

self.lin = Linear(channels, 1)

def forward(self, x_dict, edge_index_dict):
for conv in self.convs:
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: x.relu() for key, x in x_dict.items()}
return self.lin(x_dict['v0'])

model = HeteroGNN().to(args.device)
compiled_model = torch_geometric.compile(model)

benchmark(
funcs=[model, compiled_model],
func_names=['Vanilla', 'Compiled'],
args=gen_args,
num_steps=50 if args.device == 'cpu' else 500,
num_warmups=10 if args.device == 'cpu' else 100,
backward=args.backward,
)
25 changes: 25 additions & 0 deletions torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,31 @@
raise AttributeError(f"'{self.__class__.__name__}' object has no "
f"attribute '{key}'")

def to(self, device: Union[int, str]) -> 'InMemoryDataset':
r"""Performs device conversion of the whole dataset."""
if self._indices is not None:
raise ValueError("The given 'InMemoryDataset' only references a "

Check warning on line 313 in torch_geometric/data/in_memory_dataset.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/data/in_memory_dataset.py#L312-L313

Added lines #L312 - L313 were not covered by tests
"subset of examples of the full dataset")
if self._data_list is not None:
raise ValueError("The data of the dataset is already cached")
self._data.to(device)
return self

Check warning on line 318 in torch_geometric/data/in_memory_dataset.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/data/in_memory_dataset.py#L315-L318

Added lines #L315 - L318 were not covered by tests

def cpu(self, *args: str) -> 'InMemoryDataset':
r"""Moves the dataset to CPU memory."""
return self.to(torch.device('cpu'))

Check warning on line 322 in torch_geometric/data/in_memory_dataset.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/data/in_memory_dataset.py#L322

Added line #L322 was not covered by tests

def cuda(
self,
device: Optional[Union[int, str]] = None,
) -> 'InMemoryDataset':
r"""Moves the dataset toto CUDA memory."""
if isinstance(device, int):
device = f'cuda:{int}'
elif device is None:
device = 'cuda'
return self.to(device)

Check warning on line 333 in torch_geometric/data/in_memory_dataset.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/data/in_memory_dataset.py#L329-L333

Added lines #L329 - L333 were not covered by tests


def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
if isinstance(node, Mapping):
Expand Down
Loading