diff --git a/CHANGELOG.md b/CHANGELOG.md index 196f79ba6f69..b27833378315 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for device conversions of `InMemoryDataset` ([#8402] (https://github.com/pyg-team/pytorch_geometric/pull/8402)) - Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372] (https://github.com/pyg-team/pytorch_geometric/pull/8372)) - Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363)) - Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357)) diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py index b0c03656aa87..04def1f8fb37 100644 --- a/test/nn/conv/test_hetero_conv.py +++ b/test/nn/conv/test_hetero_conv.py @@ -1,8 +1,11 @@ +import random + import pytest import torch import torch_geometric from torch_geometric.data import HeteroData +from torch_geometric.datasets import FakeHeteroDataset from torch_geometric.nn import ( GATConv, GCN2Conv, @@ -12,6 +15,7 @@ MessagePassing, SAGEConv, ) +from torch_geometric.profile import benchmark from torch_geometric.testing import ( disableExtensions, get_random_edge_index, @@ -205,3 +209,64 @@ def test_compile_hetero_conv_graph_breaks(device): assert len(out) == len(expected) for key in expected.keys(): assert torch.allclose(out[key], expected[key], atol=1e-6) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--backward', action='store_true') + args = parser.parse_args() + + dataset = FakeHeteroDataset(num_graphs=10).to(args.device) + + def gen_args(): + data = dataset[random.randrange(len(dataset))] + return data.x_dict, data.edge_index_dict + + class HeteroGNN(torch.nn.Module): + def __init__(self, channels: int = 32, num_layers: int = 2): + super().__init__() + self.convs = torch.nn.ModuleList() + + conv = HeteroConv({ + edge_type: + SAGEConv( + in_channels=( + dataset.num_features[edge_type[0]], + dataset.num_features[edge_type[-1]], + ), + out_channels=channels, + ) + for edge_type in dataset[0].edge_types + }) + self.convs.append(conv) + + for _ in range(num_layers - 1): + conv = HeteroConv({ + edge_type: + SAGEConv((channels, channels), channels) + for edge_type in dataset[0].edge_types + }) + self.convs.append(conv) + + self.lin = Linear(channels, 1) + + def forward(self, x_dict, edge_index_dict): + for conv in self.convs: + x_dict = conv(x_dict, edge_index_dict) + x_dict = {key: x.relu() for key, x in x_dict.items()} + return self.lin(x_dict['v0']) + + model = HeteroGNN().to(args.device) + compiled_model = torch_geometric.compile(model) + + benchmark( + funcs=[model, compiled_model], + func_names=['Vanilla', 'Compiled'], + args=gen_args, + num_steps=50 if args.device == 'cpu' else 500, + num_warmups=10 if args.device == 'cpu' else 100, + backward=args.backward, + ) diff --git a/torch_geometric/data/in_memory_dataset.py b/torch_geometric/data/in_memory_dataset.py index 62b22bf85c43..34bc2f5c1dde 100644 --- a/torch_geometric/data/in_memory_dataset.py +++ b/torch_geometric/data/in_memory_dataset.py @@ -307,6 +307,31 @@ def __getattr__(self, key: str) -> Any: raise AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{key}'") + def to(self, device: Union[int, str]) -> 'InMemoryDataset': + r"""Performs device conversion of the whole dataset.""" + if self._indices is not None: + raise ValueError("The given 'InMemoryDataset' only references a " + "subset of examples of the full dataset") + if self._data_list is not None: + raise ValueError("The data of the dataset is already cached") + self._data.to(device) + return self + + def cpu(self, *args: str) -> 'InMemoryDataset': + r"""Moves the dataset to CPU memory.""" + return self.to(torch.device('cpu')) + + def cuda( + self, + device: Optional[Union[int, str]] = None, + ) -> 'InMemoryDataset': + r"""Moves the dataset toto CUDA memory.""" + if isinstance(device, int): + device = f'cuda:{int}' + elif device is None: + device = 'cuda' + return self.to(device) + def nested_iter(node: Union[Mapping, Sequence]) -> Iterable: if isinstance(node, Mapping):