Skip to content

Commit

Permalink
Add TensorFrame support in DataLoader (#8151)
Browse files Browse the repository at this point in the history
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
wsad1 and rusty1s authored Oct 7, 2023
1 parent c3ea7ab commit 7b7fdc1
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `NeuroGraphDataset` benchmark collection ([#8122](https://github.com/pyg-team/pytorch_geometric/pull/8122))
- Added support for a node-level `mask` tensor in `dense_to_sparse` ([#8117](https://github.com/pyg-team/pytorch_geometric/pull/8117))
- Added the `to_on_disk_dataset()` method to convert `InMemoryDataset` instances to `OnDiskDataset` instances ([#8116](https://github.com/pyg-team/pytorch_geometric/pull/8116))
- Added `torch-frame` support ([#8110](https://github.com/pyg-team/pytorch_geometric/pull/8110), [#8118](https://github.com/pyg-team/pytorch_geometric/pull/8118))
- Added `torch-frame` support ([#8110](https://github.com/pyg-team/pytorch_geometric/pull/8110), [#8118](https://github.com/pyg-team/pytorch_geometric/pull/8118), [#8151](https://github.com/pyg-team/pytorch_geometric/pull/8151))
- Added the `DistLoader` base class ([#8079](https://github.com/pyg-team/pytorch_geometric/pull/8079))
- Added `HyperGraphData` to support hypergraphs ([#7611](https://github.com/pyg-team/pytorch_geometric/pull/7611))
- Added the `PCQM4Mv2` dataset as a reference implementation for `OnDiskDataset` ([#8102](https://github.com/pyg-team/pytorch_geometric/pull/8102))
Expand Down
26 changes: 25 additions & 1 deletion test/loader/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from torch_geometric.data import Data, HeteroData, OnDiskDataset
from torch_geometric.loader import DataLoader
from torch_geometric.testing import get_random_edge_index, withCUDA
from torch_geometric.testing import (
get_random_edge_index,
withCUDA,
withPackage,
)

with_mp = sys.platform not in ['win32']
num_workers_list = [0, 2] if with_mp else [0]
Expand Down Expand Up @@ -181,6 +185,26 @@ def test_heterogeneous_dataloader(num_workers):
assert id(batch) == id(store._parent())


@withPackage('torch_frame')
def test_dataloader_tensor_frame(get_tensor_frame):
tf = get_tensor_frame(10)
loader = DataLoader([tf, tf, tf, tf], batch_size=2, shuffle=False)
assert len(loader) == 2

for batch in loader:
assert batch.num_rows == 20

data = Data(tf=tf, edge_index=get_random_edge_index(10, 10, 20))
loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False)
assert len(loader) == 2

for batch in loader:
assert batch.num_graphs == len(batch) == 2
assert batch.num_nodes == 20
assert batch.tf.num_rows == 20
assert batch.edge_index.max() >= 10


if __name__ == '__main__':
import argparse
import time
Expand Down
14 changes: 13 additions & 1 deletion torch_geometric/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
import torch_geometric.typing
from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage, NodeStorage
from torch_geometric.typing import SparseTensor, torch_sparse
from torch_geometric.typing import (
SparseTensor,
TensorFrame,
torch_frame,
torch_sparse,
)
from torch_geometric.utils import cumsum, is_sparse, is_torch_sparse_tensor
from torch_geometric.utils.sparse import cat

Expand Down Expand Up @@ -173,6 +178,13 @@ def _collate(

return value, slices, incs

elif isinstance(elem, TensorFrame):
key = str(key)
sizes = torch.tensor([value.num_rows for value in values])
slices = cumsum(sizes)
value = torch_frame.cat(values, along='row')
return value, slices, None

elif is_sparse(elem) and increment:
# Concatenate a list of `SparseTensor` along the `cat_dim`.
# NOTE: `cat_dim` may return a tuple to allow for diagonal stacking.
Expand Down
11 changes: 2 additions & 9 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
QueryType,
SparseTensor,
TensorFrame,
torch_frame,
)
from torch_geometric.utils import (
bipartite_subgraph,
Expand Down Expand Up @@ -912,15 +913,7 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
continue
values = [store[key] for store in self.node_stores]
if isinstance(values[0], TensorFrame):
# TODO (jinu) Implement `cat` function for TensorFrame.
feat_dict = {}
for stype in values[0].feat_dict.keys():
feat_dict[stype] = torch.cat(
[value.feat_dict[stype] for value in values], dim=0)
y = None
if values[0].y is not None:
y = torch.cat([value.y for value in values], dim=0)
value = TensorFrame(feat_dict, values[0].col_names_dict, y)
value = torch_frame.cat(values, along='row')
else:
dim = self.__cat_dim__(key, values[0], self.node_stores[0])
dim = values[0].dim() + dim if dim < 0 else dim
Expand Down
8 changes: 7 additions & 1 deletion torch_geometric/data/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage
from torch_geometric.typing import SparseTensor
from torch_geometric.typing import SparseTensor, TensorFrame
from torch_geometric.utils import narrow


Expand Down Expand Up @@ -81,6 +81,12 @@ def _separate(
value = value.narrow(dim, start, end - start)
return value

elif isinstance(value, TensorFrame):
key = str(key)
start, end = int(slices[idx]), int(slices[idx + 1])
value = value[start:end]
return value

elif isinstance(value, Mapping):
# Recursively separate elements of dictionaries.
return {
Expand Down
3 changes: 3 additions & 0 deletions torch_geometric/loader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.on_disk_dataset import OnDiskDataset
from torch_geometric.typing import TensorFrame, torch_frame


class Collater:
Expand All @@ -31,6 +32,8 @@ def __call__(self, batch: List[Any]) -> Any:
)
elif isinstance(elem, torch.Tensor):
return default_collate(batch)
elif isinstance(elem, TensorFrame):
return torch_frame.cat(batch, along='row')
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float)
elif isinstance(elem, int):
Expand Down
1 change: 1 addition & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor,
WITH_TORCH_FRAME = True
from torch_frame import TensorFrame
except Exception:
torch_frame = object
WITH_TORCH_FRAME = False

class TensorFrame:
Expand Down

0 comments on commit 7b7fdc1

Please sign in to comment.