Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batch and ptr vectors for a list of tensors and nested dicts #4837

Merged
merged 17 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
- Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
- Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
Expand Down Expand Up @@ -47,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Fixed `InMemoryDataset` inferring wrong `len` for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
- Fixed `Batch.separate` when using it for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
- Fixed the docstring for `SAGEConv` ([#4852](https://github.com/pyg-team/pytorch_geometric/pull/4852))
- Fixed a bug in `TUDataset` where `pre_filter` was not applied whenever `pre_transform` was present
- Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828))
Expand Down
39 changes: 39 additions & 0 deletions test/data/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,42 @@ def test_batch_with_empty_list():
assert batch.nontensor == [[], []]
assert batch[0].nontensor == []
assert batch[1].nontensor == []


def test_nested_follow_batch():
    """Batching with `follow_batch` over list- and dict-valued attributes."""
    def rand(n, m):
        return torch.rand((n, m))

    # (xs row counts, a['aa'] row count, x row count) per graph:
    specs = [
        ((4, 11, 1), 11, 10),
        ((5, 14, 3), 2, 11),
        ((6, 15, 2), 4, 9),
        ((4, 16, 1), 8, 8),
    ]
    widths = (3, 4, 2)  # feature dims of the three `xs` entries
    data_list = [
        Data(
            xs=[rand(rows, m) for rows, m in zip(xs_rows, widths)],
            a={"aa": rand(aa_rows, 3)},
            x=rand(x_rows, 5),
        ) for xs_rows, aa_rows, x_rows in specs
    ]

    batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a'])

    # Concatenated shapes: rows sum over the four graphs.
    assert batch.xs[0].shape == (19, 3)
    assert batch.xs[1].shape == (56, 4)
    assert batch.xs[2].shape == (7, 2)
    assert batch.a['aa'].shape == (25, 3)

    assert len(batch.xs_batch) == 3
    assert len(batch.a_batch) == 1

    # Per-element batch assignment vectors:
    assert batch.xs_batch[0].tolist() == \
        [0] * 4 + [1] * 5 + [2] * 6 + [3] * 4
    assert batch.xs_batch[1].tolist() == \
        [0] * 11 + [1] * 14 + [2] * 15 + [3] * 16
    assert batch.xs_batch[2].tolist() == \
        [0] * 1 + [1] * 3 + [2] * 2 + [3] * 1

    assert batch.a_batch['aa'].tolist() == \
        [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8
55 changes: 55 additions & 0 deletions test/data/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.data import Data, Dataset, HeteroData, InMemoryDataset

Expand Down Expand Up @@ -158,3 +159,57 @@ def _process(self):
ds = DS3()
assert not ds.enter_download
assert not ds.enter_process


class MyTestDataset2(InMemoryDataset):
    """Minimal in-memory dataset that collates a given list of `Data`."""
    def __init__(self, data_list):
        super().__init__('/tmp/MyTestDataset2')
        data, slices = self.collate(data_list)
        self.data = data
        self.slices = slices


def test_lists_of_tensors_in_memory_dataset():
    """`InMemoryDataset` keeps per-item lengths for list-of-tensor attrs."""
    def rand(n, m):
        return torch.rand((n, m))

    widths = (3, 4, 2)  # feature dims of the three `xs` entries
    rows_per_graph = [(4, 11, 1), (5, 14, 3), (6, 15, 2), (4, 16, 1)]
    data_list = [
        Data(xs=[rand(rows, m) for rows, m in zip(row_spec, widths)])
        for row_spec in rows_per_graph
    ]

    dataset = MyTestDataset2(data_list)
    assert len(dataset) == 4
    assert dataset[0].xs[1].shape == (11, 4)
    assert dataset[0].xs[2].shape == (1, 2)
    assert dataset[1].xs[0].shape == (5, 3)
    assert dataset[2].xs[1].shape == (15, 4)
    assert dataset[3].xs[1].shape == (16, 4)


class MyTestDataset3(InMemoryDataset):
    """Minimal in-memory dataset that collates a given list of `Data`."""
    def __init__(self, data_list):
        super().__init__('/tmp/MyTestDataset3')
        data, slices = self.collate(data_list)
        self.data = data
        self.slices = slices


def test_lists_of_SparseTensors():
    """Collate/separate round-trip preserves sizes of `SparseTensor` lists."""
    edge_indices = [
        torch.tensor([[4, 1, 3, 2, 2, 3], [1, 3, 2, 3, 3, 2]]),
        torch.tensor([[0, 1, 4, 7, 2, 9], [7, 2, 2, 1, 4, 7]]),
        torch.tensor([[3, 5, 1, 2, 3, 3], [5, 0, 2, 1, 3, 7]]),
        torch.tensor([[0, 1, 9, 2, 0, 3], [1, 1, 2, 1, 3, 2]]),
    ]
    sizes = [11, 22, 12, 15]
    adjs = [
        SparseTensor.from_edge_index(edge_index, sparse_sizes=(size, size))
        for edge_index, size in zip(edge_indices, sizes)
    ]

    data_list = [
        Data(adj_test=adjs[:2]),
        Data(adj_test=adjs[2:]),
    ]
    dataset = MyTestDataset3(data_list)
    assert len(dataset) == 2
    assert dataset[0].adj_test[0].sparse_sizes() == (11, 11)
    assert dataset[0].adj_test[1].sparse_sizes() == (22, 22)
    assert dataset[1].adj_test[0].sparse_sizes() == (12, 12)
    assert dataset[1].adj_test[1].sparse_sizes() == (15, 15)
41 changes: 36 additions & 5 deletions torch_geometric/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,10 @@ def collate(
inc_dict[attr] = incs

# Add an additional batch vector for the given attribute:
if (attr in follow_batch and isinstance(slices, Tensor)
and slices.dim() == 1):
repeats = slices[1:] - slices[:-1]
batch = repeat_interleave(repeats.tolist(), device=device)
if attr in follow_batch:
batch, ptr = _batch_and_ptr(slices, device)
out_store[f'{attr}_batch'] = batch
out_store[f'{attr}_ptr'] = cumsum(repeats.to(device))
out_store[f'{attr}_ptr'] = ptr

# In case the storage holds nodes, we add a top-level batch vector to it:
if (add_batch and isinstance(stores[0], NodeStorage)
Expand Down Expand Up @@ -199,6 +197,39 @@ def _collate(
return values, slices, None


def _batch_and_ptr(
slices: Any,
device: Optional[torch.device] = None,
) -> Tuple[Any, Any]:
if (isinstance(slices, Tensor) and slices.dim() == 1):
# Default case, turn slices tensor into batch.
repeats = slices[1:] - slices[:-1]
batch = repeat_interleave(repeats.tolist(), device=device)
ptr = cumsum(repeats.to(device))
return batch, ptr

elif isinstance(slices, Mapping):
# Recursively batch elements of dictionaries.
batch, ptr = {}, {}
for k, v in slices.items():
batch[k], ptr[k] = _batch_and_ptr(v, device)
return batch, ptr

elif (isinstance(slices, Sequence) and not isinstance(slices, str)
and isinstance(slices[0], Tensor)):
# Recursively batch elements of lists.
batch, ptr = [], []
for s in slices:
sub_batch, sub_ptr = _batch_and_ptr(s, device)
batch.append(sub_batch)
ptr.append(sub_ptr)
return batch, ptr

else:
# Failure of batching, usually due to slices.dim() != 1
return None, None


###############################################################################


Expand Down
15 changes: 9 additions & 6 deletions torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -130,10 +130,13 @@ def copy(self, idx: Optional[IndexType] = None) -> 'InMemoryDataset':
return dataset


def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
    """Iterate over the leaf values of a nested dict/list structure.

    Yields ``(key, value)`` pairs, where ``key`` is the list index of the
    leaf (or ``None`` for a bare, non-sequence leaf; dictionary keys are
    not surfaced, only their leaves).
    """
    if isinstance(node, Mapping):
        # Recurse into dictionaries, yielding their leaves.
        for key, value in node.items():
            for inner_key, inner_value in nested_iter(value):
                yield inner_key, inner_value
    elif isinstance(node, Sequence) and not isinstance(node, str):
        # Yield list elements by position. `str` is excluded since it is
        # also a `Sequence` but must be treated as an atomic leaf
        # (consistent with the `str` guards in `collate.py`/`separate.py`).
        for i, inner_value in enumerate(node):
            yield i, inner_value
    else:
        yield None, node
5 changes: 5 additions & 0 deletions torch_geometric/data/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def _separate(
and not isinstance(value[0], str) and len(value[0]) > 0
and isinstance(value[0][0], (Tensor, SparseTensor))):
# Recursively separate elements of lists of lists.
return [elem[idx] for elem in value]

elif (isinstance(value, Sequence) and not isinstance(value, str)
and isinstance(value[0], (Tensor, SparseTensor))):
# Recursively separate elements of lists of Tensors/SparseTensors.
return [
_separate(key, elem, idx, slices[i],
incs[i] if decrement else None, batch, store, decrement)
Expand Down