From 7b7fdc1434a0cf1c1eb2f4e814e47cc336c6deda Mon Sep 17 00:00:00 2001 From: Jinu Sunil Date: Sat, 7 Oct 2023 18:47:55 +0530 Subject: [PATCH] Add `TensorFrame` support in `DataLoader` (#8151) Co-authored-by: rusty1s --- CHANGELOG.md | 2 +- test/loader/test_dataloader.py | 26 +++++++++++++++++++++++++- torch_geometric/data/collate.py | 14 +++++++++++++- torch_geometric/data/hetero_data.py | 11 ++--------- torch_geometric/data/separate.py | 8 +++++++- torch_geometric/loader/dataloader.py | 3 +++ torch_geometric/typing.py | 1 + 7 files changed, 52 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3ac19f6bc43..b5b4c9732615 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the `NeuroGraphDataset` benchmark collection ([#8122](https://github.com/pyg-team/pytorch_geometric/pull/8122)) - Added support for a node-level `mask` tensor in `dense_to_sparse` ([#8117](https://github.com/pyg-team/pytorch_geometric/pull/8117)) - Added the `to_on_disk_dataset()` method to convert `InMemoryDataset` instances to `OnDiskDataset` instances ([#8116](https://github.com/pyg-team/pytorch_geometric/pull/8116)) -- Added `torch-frame` support ([#8110](https://github.com/pyg-team/pytorch_geometric/pull/8110), [#8118](https://github.com/pyg-team/pytorch_geometric/pull/8118)) +- Added `torch-frame` support ([#8110](https://github.com/pyg-team/pytorch_geometric/pull/8110), [#8118](https://github.com/pyg-team/pytorch_geometric/pull/8118), [#8151](https://github.com/pyg-team/pytorch_geometric/pull/8151)) - Added the `DistLoader` base class ([#8079](https://github.com/pyg-team/pytorch_geometric/pull/8079)) - Added `HyperGraphData` to support hypergraphs ([#7611](https://github.com/pyg-team/pytorch_geometric/pull/7611)) - Added the `PCQM4Mv2` dataset as a reference implementation for `OnDiskDataset` ([#8102](https://github.com/pyg-team/pytorch_geometric/pull/8102)) diff --git a/test/loader/test_dataloader.py b/test/loader/test_dataloader.py index 50113f380bb3..8ebe9ca010cd 100644 --- a/test/loader/test_dataloader.py +++ b/test/loader/test_dataloader.py @@ -7,7 +7,11 @@ from torch_geometric.data import Data, HeteroData, OnDiskDataset from torch_geometric.loader import DataLoader -from torch_geometric.testing import get_random_edge_index, withCUDA +from torch_geometric.testing import ( + get_random_edge_index, + withCUDA, + withPackage, +) with_mp = sys.platform not in ['win32'] num_workers_list = [0, 2] if with_mp else [0] @@ -181,6 +185,26 @@ def test_heterogeneous_dataloader(num_workers): assert id(batch) == id(store._parent()) +@withPackage('torch_frame') +def test_dataloader_tensor_frame(get_tensor_frame): + tf = get_tensor_frame(10) + loader = DataLoader([tf, tf, tf, tf], batch_size=2, shuffle=False) + assert len(loader) == 2 + + for batch in loader: + assert batch.num_rows == 20 + + data = Data(tf=tf, edge_index=get_random_edge_index(10, 10, 20)) + loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False) + assert len(loader) == 2 + + for batch in loader: + assert batch.num_graphs == len(batch) == 2 + assert batch.num_nodes == 20 + assert batch.tf.num_rows == 20 + assert batch.edge_index.max() >= 10 + + if __name__ == '__main__': import argparse import time diff --git a/torch_geometric/data/collate.py b/torch_geometric/data/collate.py index bef83df9843d..fb5eeb196fe2 100644 --- a/torch_geometric/data/collate.py +++ b/torch_geometric/data/collate.py @@ -8,7 +8,12 @@ import torch_geometric.typing from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage, NodeStorage -from torch_geometric.typing import SparseTensor, torch_sparse +from torch_geometric.typing import ( + SparseTensor, + TensorFrame, + torch_frame, + torch_sparse, +) from torch_geometric.utils import cumsum, is_sparse, is_torch_sparse_tensor from torch_geometric.utils.sparse import cat @@ -173,6 +178,13 @@ def _collate( return value, slices, incs + elif isinstance(elem, TensorFrame): + key = str(key) + sizes = torch.tensor([value.num_rows for value in values]) + slices = cumsum(sizes) + value = torch_frame.cat(values, along='row') + return value, slices, None + elif is_sparse(elem) and increment: # Concatenate a list of `SparseTensor` along the `cat_dim`. # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index d67c2b41ea95..8f1d1fba43c4 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -23,6 +23,7 @@ QueryType, SparseTensor, TensorFrame, + torch_frame, ) from torch_geometric.utils import ( bipartite_subgraph, @@ -912,15 +913,7 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]: continue values = [store[key] for store in self.node_stores] if isinstance(values[0], TensorFrame): - # TODO (jinu) Implement `cat` function for TensorFrame. - feat_dict = {} - for stype in values[0].feat_dict.keys(): - feat_dict[stype] = torch.cat( - [value.feat_dict[stype] for value in values], dim=0) - y = None - if values[0].y is not None: - y = torch.cat([value.y for value in values], dim=0) - value = TensorFrame(feat_dict, values[0].col_names_dict, y) + value = torch_frame.cat(values, along='row') else: dim = self.__cat_dim__(key, values[0], self.node_stores[0]) dim = values[0].dim() + dim if dim < 0 else dim diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index 51429b9b7639..5412e7c133d6 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -5,7 +5,7 @@ from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage -from torch_geometric.typing import SparseTensor +from torch_geometric.typing import SparseTensor, TensorFrame from torch_geometric.utils import narrow @@ -81,6 +81,12 @@ def _separate( value = value.narrow(dim, start, end - start) return value + elif isinstance(value, TensorFrame): + key = str(key) + start, end = int(slices[idx]), int(slices[idx + 1]) + value = value[start:end] + return value + elif isinstance(value, Mapping): # Recursively separate elements of dictionaries. return { diff --git a/torch_geometric/loader/dataloader.py b/torch_geometric/loader/dataloader.py index c803dbba8711..295da9f09299 100644 --- a/torch_geometric/loader/dataloader.py +++ b/torch_geometric/loader/dataloader.py @@ -8,6 +8,7 @@ from torch_geometric.data.data import BaseData from torch_geometric.data.datapipes import DatasetAdapter from torch_geometric.data.on_disk_dataset import OnDiskDataset +from torch_geometric.typing import TensorFrame, torch_frame class Collater: @@ -31,6 +32,8 @@ def __call__(self, batch: List[Any]) -> Any: ) elif isinstance(elem, torch.Tensor): return default_collate(batch) + elif isinstance(elem, TensorFrame): + return torch_frame.cat(batch, along='row') elif isinstance(elem, float): return torch.tensor(batch, dtype=torch.float) elif isinstance(elem, int): diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 38c935c09f3d..ccae087c9672 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -216,6 +216,7 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, WITH_TORCH_FRAME = True from torch_frame import TensorFrame except Exception: + torch_frame = object WITH_TORCH_FRAME = False class TensorFrame: