Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorFrame support in DataLoader #8151

Merged
merged 6 commits into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `NeuroGraphDataset` benchmark collection ([#8122](https://github.com/pyg-team/pytorch_geometric/pull/8122))
- Added support for a node-level `mask` tensor in `dense_to_sparse` ([#8117](https://github.com/pyg-team/pytorch_geometric/pull/8117))
- Added the `to_on_disk_dataset()` method to convert `InMemoryDataset` instances to `OnDiskDataset` instances ([#8116](https://github.com/pyg-team/pytorch_geometric/pull/8116))
- Added `torch-frame` support ([#8110](https://github.com/pyg-team/pytorch_geometric/pull/8110), [#8118](https://github.com/pyg-team/pytorch_geometric/pull/8118))
- Added `torch-frame` support ([#8110](https://github.com/pyg-team/pytorch_geometric/pull/8110), [#8118](https://github.com/pyg-team/pytorch_geometric/pull/8118), [#8151](https://github.com/pyg-team/pytorch_geometric/pull/8151))
- Added the `DistLoader` base class ([#8079](https://github.com/pyg-team/pytorch_geometric/pull/8079))
- Added `HyperGraphData` to support hypergraphs ([#7611](https://github.com/pyg-team/pytorch_geometric/pull/7611))
- Added the `PCQM4Mv2` dataset as a reference implementation for `OnDiskDataset` ([#8102](https://github.com/pyg-team/pytorch_geometric/pull/8102))
Expand Down
26 changes: 25 additions & 1 deletion test/loader/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from torch_geometric.data import Data, HeteroData, OnDiskDataset
from torch_geometric.loader import DataLoader
from torch_geometric.testing import get_random_edge_index, withCUDA
from torch_geometric.testing import (
get_random_edge_index,
withCUDA,
withPackage,
)

with_mp = sys.platform not in ['win32']
num_workers_list = [0, 2] if with_mp else [0]
Expand Down Expand Up @@ -181,6 +185,26 @@ def test_heterogeneous_dataloader(num_workers):
assert id(batch) == id(store._parent())


@withPackage('torch_frame')
def test_dataloader_tensor_frame(get_tensor_frame):
    """Check `DataLoader` batching of `TensorFrame` inputs.

    Two scenarios are exercised: a dataset made of bare tensor frames, and a
    dataset of `Data` objects that hold a tensor frame under the `tf`
    attribute. In both cases four samples are loaded with `batch_size=2`, so
    two batches are expected, each reporting 20 rows.
    """
    frame = get_tensor_frame(10)  # presumably a frame with 10 rows

    # Scenario 1: the dataset items are the tensor frames themselves.
    frame_loader = DataLoader([frame] * 4, batch_size=2, shuffle=False)
    assert len(frame_loader) == 2
    for frame_batch in frame_loader:
        assert frame_batch.num_rows == 20

    # Scenario 2: the tensor frame is an attribute of a graph sample.
    edge_index = get_random_edge_index(10, 10, 20)
    graph = Data(tf=frame, edge_index=edge_index)
    graph_loader = DataLoader([graph] * 4, batch_size=2, shuffle=False)
    assert len(graph_loader) == 2
    for graph_batch in graph_loader:
        assert graph_batch.num_graphs == len(graph_batch) == 2
        assert graph_batch.num_nodes == 20
        assert graph_batch.tf.num_rows == 20
        # At least one index must reach 10 or above in the batched graph.
        assert graph_batch.edge_index.max() >= 10


if __name__ == '__main__':
import argparse
import time
Expand Down
14 changes: 13 additions & 1 deletion torch_geometric/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
import torch_geometric.typing
from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage, NodeStorage
from torch_geometric.typing import SparseTensor, torch_sparse
from torch_geometric.typing import (
SparseTensor,
TensorFrame,
torch_frame,
torch_sparse,
)
from torch_geometric.utils import cumsum, is_sparse, is_torch_sparse_tensor
from torch_geometric.utils.sparse import cat

Expand Down Expand Up @@ -173,6 +178,13 @@ def _collate(

return value, slices, incs

elif isinstance(elem, TensorFrame):
key = str(key)
sizes = torch.tensor([value.num_rows for value in values])
slices = cumsum(sizes)
value = torch_frame.cat(values, along='row')
return value, slices, None

elif is_sparse(elem) and increment:
# Concatenate a list of `SparseTensor` along the `cat_dim`.
# NOTE: `cat_dim` may return a tuple to allow for diagonal stacking.
Expand Down
11 changes: 2 additions & 9 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
QueryType,
SparseTensor,
TensorFrame,
torch_frame,
)
from torch_geometric.utils import (
bipartite_subgraph,
Expand Down Expand Up @@ -912,15 +913,7 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
continue
values = [store[key] for store in self.node_stores]
if isinstance(values[0], TensorFrame):
# TODO (jinu) Implement `cat` function for TensorFrame.
feat_dict = {}
for stype in values[0].feat_dict.keys():
feat_dict[stype] = torch.cat(
[value.feat_dict[stype] for value in values], dim=0)
y = None
if values[0].y is not None:
y = torch.cat([value.y for value in values], dim=0)
value = TensorFrame(feat_dict, values[0].col_names_dict, y)
value = torch_frame.cat(values, along='row')
else:
dim = self.__cat_dim__(key, values[0], self.node_stores[0])
dim = values[0].dim() + dim if dim < 0 else dim
Expand Down
8 changes: 7 additions & 1 deletion torch_geometric/data/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage
from torch_geometric.typing import SparseTensor
from torch_geometric.typing import SparseTensor, TensorFrame
from torch_geometric.utils import narrow


Expand Down Expand Up @@ -81,6 +81,12 @@ def _separate(
value = value.narrow(dim, start, end - start)
return value

elif isinstance(value, TensorFrame):
key = str(key)
start, end = int(slices[idx]), int(slices[idx + 1])
value = value[start:end]
return value

elif isinstance(value, Mapping):
# Recursively separate elements of dictionaries.
return {
Expand Down
3 changes: 3 additions & 0 deletions torch_geometric/loader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.on_disk_dataset import OnDiskDataset
from torch_geometric.typing import TensorFrame, torch_frame


class Collater:
Expand All @@ -31,6 +32,8 @@ def __call__(self, batch: List[Any]) -> Any:
)
elif isinstance(elem, torch.Tensor):
return default_collate(batch)
elif isinstance(elem, TensorFrame):
return torch_frame.cat(batch, along='row')
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float)
elif isinstance(elem, int):
Expand Down
1 change: 1 addition & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor,
WITH_TORCH_FRAME = True
from torch_frame import TensorFrame
except Exception:
torch_frame = object
WITH_TORCH_FRAME = False

class TensorFrame:
Expand Down