Add batch and ptr vectors for a list of tensors and nested dicts (#4837)

* Added support for `_ptr` and `_batch` for nested structures, matching the implementation of `_collate`. Added a test in `test_batch.py`.

* Changed comment; applied pre-commit hooks.

* Added support for batch vectors for recursive data structures.

* Update CHANGELOG.md

Updated CHANGELOG.md. Still missing pull link.

* Extended `InMemoryDataset` to infer `len()` correctly when using lists of tensors. Modified `separate` so that `InMemoryDataset` can deal with lists of tensors correctly. Added a test for it.

* Added a test to cover the special `elif` case in `separate`.

* Apply suggestions from code review

Fix `elif` bug, change `_batch` name, clean-up (by rusty1s)

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* Merged functions `_batch` and `_ptr` into `_batch_and_ptr`.

* Updated CHANGELOG.md

* Fixed type hint.

* Update collate.py

Removed an old/incorrect code comment.

Co-authored-by: janmeissnerRWTH <96341562+janmeissnerRWTH@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
3 people authored Jun 30, 2022
1 parent 3d6eb74 commit 9bb19b6
Showing 6 changed files with 147 additions and 11 deletions.
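
For orientation, a minimal usage sketch of the feature this commit adds (distilled from the tests below; not part of the diff itself): `follow_batch` now also produces batch vectors for attributes that are lists or dictionaries of tensors.

```python
import torch
from torch_geometric.data import Batch, Data

# Two graphs whose `xs` attribute is a list of tensors and whose `a`
# attribute is a dictionary of tensors:
d1 = Data(xs=[torch.rand(4, 3)], a={'aa': torch.rand(11, 3)})
d2 = Data(xs=[torch.rand(5, 3)], a={'aa': torch.rand(2, 3)})

batch = Batch.from_data_list([d1, d2], follow_batch=['xs', 'a'])

# One batch vector per list position / dictionary key:
assert batch.xs_batch[0].tolist() == [0] * 4 + [1] * 5
assert batch.a_batch['aa'].tolist() == [0] * 11 + [1] * 2
```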
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
- Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
- Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
@@ -47,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Fixed `InMemoryDataset` inferring wrong `len` for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
- Fixed `Batch.separate` when using it for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
- Correct docstring for SAGEConv ([#4852](https://github.com/pyg-team/pytorch_geometric/pull/4852))
- Fixed a bug in `TUDataset` where `pre_filter` was not applied whenever `pre_transform` was present
- Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828))
39 changes: 39 additions & 0 deletions test/data/test_batch.py
@@ -387,3 +387,42 @@ def test_batch_with_empty_list():
    assert batch.nontensor == [[], []]
    assert batch[0].nontensor == []
    assert batch[1].nontensor == []


def test_nested_follow_batch():
    def tr(n, m):
        return torch.rand((n, m))

    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)], a={"aa": tr(11, 3)},
              x=tr(10, 5))
    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)], a={"aa": tr(2, 3)},
              x=tr(11, 5))
    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)], a={"aa": tr(4, 3)},
              x=tr(9, 5))
    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)], a={"aa": tr(8, 3)},
              x=tr(8, 5))

    # Dataset
    data_list = [d1, d2, d3, d4]

    batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a'])

    # assert shapes
    assert batch.xs[0].shape == (19, 3)
    assert batch.xs[1].shape == (56, 4)
    assert batch.xs[2].shape == (7, 2)
    assert batch.a['aa'].shape == (25, 3)

    assert len(batch.xs_batch) == 3
    assert len(batch.a_batch) == 1

    # assert _batch
    assert batch.xs_batch[0].tolist() == \
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
    assert batch.xs_batch[1].tolist() == \
        [0] * 11 + [1] * 14 + [2] * 15 + [3] * 16
    assert batch.xs_batch[2].tolist() == \
        [0] * 1 + [1] * 3 + [2] * 2 + [3] * 1

    assert batch.a_batch['aa'].tolist() == \
        [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8
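
As a sanity check on the expected values: the first entry of `xs` contributes 4, 5, 6 and 4 rows across `d1`..`d4`, so the concatenated tensor has 19 rows and the batch vector repeats each graph index that many times. A standalone re-derivation with plain `torch` (illustrative only):

```python
import torch

sizes = torch.tensor([4, 5, 6, 4])  # rows of xs[0] in d1..d4
batch = torch.repeat_interleave(torch.arange(len(sizes)), sizes)
ptr = torch.cat([sizes.new_zeros(1), sizes.cumsum(0)])

assert batch.tolist() == [0] * 4 + [1] * 5 + [2] * 6 + [3] * 4
assert ptr.tolist() == [0, 4, 9, 15, 19]
```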
55 changes: 55 additions & 0 deletions test/data/test_dataset.py
@@ -1,4 +1,5 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.data import Data, Dataset, HeteroData, InMemoryDataset

@@ -158,3 +159,57 @@ def _process(self):
    ds = DS3()
    assert not ds.enter_download
    assert not ds.enter_process


class MyTestDataset2(InMemoryDataset):
    def __init__(self, data_list):
        super().__init__('/tmp/MyTestDataset2')
        self.data, self.slices = self.collate(data_list)


def test_lists_of_tensors_in_memory_dataset():
    def tr(n, m):
        return torch.rand((n, m))

    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)])
    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)])
    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)])
    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)])

    data_list = [d1, d2, d3, d4]

    dataset = MyTestDataset2(data_list)
    assert len(dataset) == 4
    assert dataset[0].xs[1].shape == (11, 4)
    assert dataset[0].xs[2].shape == (1, 2)
    assert dataset[1].xs[0].shape == (5, 3)
    assert dataset[2].xs[1].shape == (15, 4)
    assert dataset[3].xs[1].shape == (16, 4)


class MyTestDataset3(InMemoryDataset):
    def __init__(self, data_list):
        super().__init__('/tmp/MyTestDataset3')
        self.data, self.slices = self.collate(data_list)


def test_lists_of_SparseTensors():
    e1 = torch.tensor([[4, 1, 3, 2, 2, 3], [1, 3, 2, 3, 3, 2]])
    e2 = torch.tensor([[0, 1, 4, 7, 2, 9], [7, 2, 2, 1, 4, 7]])
    e3 = torch.tensor([[3, 5, 1, 2, 3, 3], [5, 0, 2, 1, 3, 7]])
    e4 = torch.tensor([[0, 1, 9, 2, 0, 3], [1, 1, 2, 1, 3, 2]])
    adj1 = SparseTensor.from_edge_index(e1, sparse_sizes=(11, 11))
    adj2 = SparseTensor.from_edge_index(e2, sparse_sizes=(22, 22))
    adj3 = SparseTensor.from_edge_index(e3, sparse_sizes=(12, 12))
    adj4 = SparseTensor.from_edge_index(e4, sparse_sizes=(15, 15))

    d1 = Data(adj_test=[adj1, adj2])
    d2 = Data(adj_test=[adj3, adj4])

    data_list = [d1, d2]
    dataset = MyTestDataset3(data_list)
    assert len(dataset) == 2
    assert dataset[0].adj_test[0].sparse_sizes() == (11, 11)
    assert dataset[0].adj_test[1].sparse_sizes() == (22, 22)
    assert dataset[1].adj_test[0].sparse_sizes() == (12, 12)
    assert dataset[1].adj_test[1].sparse_sizes() == (15, 15)
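
Why `len(dataset)` works here: `InMemoryDataset` infers the length from the first slice tensor yielded by `nested_iter` (see the `in_memory_dataset.py` diff below), which after this commit also covers slices stored in lists. Continuing from `test_lists_of_tensors_in_memory_dataset` above, and assuming `collate` stores one slice tensor per list position:

```python
# xs[0] has 4, 5, 6 and 4 rows in d1..d4, so its slice tensor holds the
# cumulative boundaries, and the graph count is its length minus one:
first_slice = dataset.slices['xs'][0]
assert first_slice.tolist() == [0, 4, 9, 15, 19]
assert len(dataset) == len(first_slice) - 1 == 4
```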
41 changes: 36 additions & 5 deletions torch_geometric/data/collate.py
@@ -96,12 +96,10 @@ def collate(
             inc_dict[attr] = incs

             # Add an additional batch vector for the given attribute:
-            if (attr in follow_batch and isinstance(slices, Tensor)
-                    and slices.dim() == 1):
-                repeats = slices[1:] - slices[:-1]
-                batch = repeat_interleave(repeats.tolist(), device=device)
+            if attr in follow_batch:
+                batch, ptr = _batch_and_ptr(slices, device)
                 out_store[f'{attr}_batch'] = batch
-                out_store[f'{attr}_ptr'] = cumsum(repeats.to(device))
+                out_store[f'{attr}_ptr'] = ptr

         # In case the storage holds nodes, we add a top-level batch vector to it:
         if (add_batch and isinstance(stores[0], NodeStorage)
@@ -199,6 +197,39 @@ def _collate(
        return values, slices, None


def _batch_and_ptr(
    slices: Any,
    device: Optional[torch.device] = None,
) -> Tuple[Any, Any]:
    if isinstance(slices, Tensor) and slices.dim() == 1:
        # Default case: turn the slices tensor into a batch vector.
        repeats = slices[1:] - slices[:-1]
        batch = repeat_interleave(repeats.tolist(), device=device)
        ptr = cumsum(repeats.to(device))
        return batch, ptr

    elif isinstance(slices, Mapping):
        # Recursively batch elements of dictionaries.
        batch, ptr = {}, {}
        for k, v in slices.items():
            batch[k], ptr[k] = _batch_and_ptr(v, device)
        return batch, ptr

    elif (isinstance(slices, Sequence) and not isinstance(slices, str)
          and isinstance(slices[0], Tensor)):
        # Recursively batch elements of lists.
        batch, ptr = [], []
        for s in slices:
            sub_batch, sub_ptr = _batch_and_ptr(s, device)
            batch.append(sub_batch)
            ptr.append(sub_ptr)
        return batch, ptr

    else:
        # Batching failed, usually because `slices.dim() != 1`.
        return None, None


###############################################################################


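
In the base case (a one-dimensional slices tensor), `_batch_and_ptr` converts cumulative slice boundaries into a batch assignment vector and a `ptr` vector; the dictionary and list branches just apply this recursively. A self-contained sketch of the base-case computation with plain `torch` (the function name here is made up; the real code uses PyG's internal `repeat_interleave` and `cumsum` helpers):

```python
import torch

def batch_and_ptr_sketch(slices: torch.Tensor):
    # Rows each graph contributes to the concatenated tensor:
    repeats = slices[1:] - slices[:-1]
    # Graph index repeated once per row:
    batch = torch.repeat_interleave(torch.arange(len(repeats)), repeats)
    # Cumulative offsets, i.e. the slice boundaries themselves:
    ptr = torch.cat([slices.new_zeros(1), repeats.cumsum(0)])
    return batch, ptr

# Matches `a_batch['aa']` from the test above (11, 2, 4 and 8 rows):
batch, ptr = batch_and_ptr_sketch(torch.tensor([0, 11, 13, 17, 25]))
assert batch.tolist() == [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8
assert ptr.tolist() == [0, 11, 13, 17, 25]
```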
15 changes: 9 additions & 6 deletions torch_geometric/data/in_memory_dataset.py
@@ -1,5 +1,5 @@
 import copy
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
@@ -130,10 +130,13 @@ def copy(self, idx: Optional[IndexType] = None) -> 'InMemoryDataset':
         return dataset


-def nested_iter(mapping: Mapping) -> Iterable:
-    for key, value in mapping.items():
-        if isinstance(value, Mapping):
+def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
+    if isinstance(node, Mapping):
+        for key, value in node.items():
             for inner_key, inner_value in nested_iter(value):
                 yield inner_key, inner_value
-        else:
-            yield key, value
+    elif isinstance(node, Sequence):
+        for i, inner_value in enumerate(node):
+            yield i, inner_value
+    else:
+        yield None, node
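
To see the new traversal in action: `nested_iter` now descends into sequences as well as mappings, yielding `(index, value)` pairs for list entries and `(None, value)` for scalar leaves. A small usage sketch (the example `slices` value is made up but mirrors what `collate` produces for a list attribute):

```python
import torch

from torch_geometric.data.in_memory_dataset import nested_iter

slices = {'xs': [torch.tensor([0, 4, 9, 15, 19]),
                 torch.tensor([0, 11, 25, 40, 56])]}

# One (list position, slice tensor) pair per entry of the list:
pairs = list(nested_iter(slices))
assert pairs[0][0] == 0 and pairs[0][1].tolist() == [0, 4, 9, 15, 19]
assert pairs[1][0] == 1

# `InMemoryDataset.len()` reads the graph count off the first pair:
assert len(pairs[0][1]) - 1 == 4
```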
5 changes: 5 additions & 0 deletions torch_geometric/data/separate.py
@@ -90,6 +90,11 @@ def _separate(
          and not isinstance(value[0], str) and len(value[0]) > 0
          and isinstance(value[0][0], (Tensor, SparseTensor))):
        # Recursively separate elements of lists of lists.
        return [elem[idx] for elem in value]

    elif (isinstance(value, Sequence) and not isinstance(value, str)
          and isinstance(value[0], (Tensor, SparseTensor))):
        # Recursively separate elements of lists of Tensors/SparseTensors.
        return [
            _separate(key, elem, idx, slices[i],
                      incs[i] if decrement else None, batch, store, decrement)
            for i, elem in enumerate(value)
        ]
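
The new branch separates attributes that are lists of tensors (or `SparseTensor`s) by slicing each list entry with its own slice tensor. A stripped-down sketch of that slicing idea for plain tensors (variable values are illustrative; the real branch goes through `_separate` so that increments are handled too):

```python
import torch

# Concatenated `xs` entries for four graphs, plus their slice tensors:
xs = [torch.rand(19, 3), torch.rand(56, 4)]
slices = [torch.tensor([0, 4, 9, 15, 19]),
          torch.tensor([0, 11, 25, 40, 56])]

idx = 1  # recover the second graph
out = [x[s[idx]:s[idx + 1]] for x, s in zip(xs, slices)]
assert out[0].shape == (5, 3) and out[1].shape == (14, 4)
```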
