diff --git a/CHANGELOG.md b/CHANGELOG.md
index 900693c5152a..b5bab5330031 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
 - Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
 - Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
 - Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
@@ -47,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Fixed `InMemoryDataset` inferring the wrong `len` for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
+- Fixed `Batch.separate` when using it for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
 - Corrected the docstring for `SAGEConv` ([#4852](https://github.com/pyg-team/pytorch_geometric/pull/4852))
 - Fixed a bug in `TUDataset` where `pre_filter` was not applied whenever `pre_transform` was present
 - Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828))
diff --git a/test/data/test_batch.py b/test/data/test_batch.py
index 098710a8d7b4..2e41b0e1de13 100644
--- a/test/data/test_batch.py
+++ b/test/data/test_batch.py
@@ -387,3 +387,42 @@ def test_batch_with_empty_list():
     assert batch.nontensor == [[], []]
     assert batch[0].nontensor == []
     assert batch[1].nontensor == []
+
+
+def test_nested_follow_batch():
+    def tr(n, m):
+        return torch.rand((n, m))
+
+    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)], a={"aa": tr(11, 3)},
+              x=tr(10, 5))
+    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)], a={"aa": tr(2, 3)},
+              x=tr(11, 5))
+    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)], a={"aa": tr(4, 3)},
+              x=tr(9, 5))
+    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)], a={"aa": tr(8, 3)},
+              x=tr(8, 5))
+
+    # Collate all four examples into a single batch:
+    data_list = [d1, d2, d3, d4]
+
+    batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a'])
+
+    # Check the shapes of the concatenated attributes:
+    assert batch.xs[0].shape == (19, 3)
+    assert batch.xs[1].shape == (56, 4)
+    assert batch.xs[2].shape == (7, 2)
+    assert batch.a['aa'].shape == (25, 3)
+
+    assert len(batch.xs_batch) == 3
+    assert len(batch.a_batch) == 1
+
+    # Check the generated `*_batch` assignment vectors:
+    assert batch.xs_batch[0].tolist() == \
+        [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
+    assert batch.xs_batch[1].tolist() == \
+        [0] * 11 + [1] * 14 + [2] * 15 + [3] * 16
+    assert batch.xs_batch[2].tolist() == \
+        [0] * 1 + [1] * 3 + [2] * 2 + [3] * 1
+
+    assert batch.a_batch['aa'].tolist() == \
+        [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8
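
For context before the dataset-level tests, here is a minimal usage sketch (not part of the diff) of the feature the new test exercises: with this patch, `follow_batch` may name attributes that hold lists or dictionaries of tensors, and a `*_batch` assignment structure of the same nesting is generated. The attribute names `xs` and `a` below simply mirror the test above.

    import torch
    from torch_geometric.data import Batch, Data

    # Two examples, each carrying a list of tensors under `xs` and a
    # dictionary of tensors under `a`:
    d1 = Data(xs=[torch.rand(4, 3), torch.rand(1, 2)],
              a={'aa': torch.rand(11, 3)})
    d2 = Data(xs=[torch.rand(5, 3), torch.rand(3, 2)],
              a={'aa': torch.rand(2, 3)})

    batch = Batch.from_data_list([d1, d2], follow_batch=['xs', 'a'])

    # Every list entry (and dictionary value) gets its own assignment
    # vector, mapping each row back to its originating example:
    assert batch.xs_batch[0].tolist() == [0] * 4 + [1] * 5
    assert batch.a_batch['aa'].tolist() == [0] * 11 + [1] * 2
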
diff --git a/test/data/test_dataset.py b/test/data/test_dataset.py
index e3214fac22e9..6cafea226ced 100644
--- a/test/data/test_dataset.py
+++ b/test/data/test_dataset.py
@@ -1,4 +1,5 @@
 import torch
+from torch_sparse import SparseTensor
 
 from torch_geometric.data import Data, Dataset, HeteroData, InMemoryDataset
 
@@ -158,3 +159,57 @@ def _process(self):
     ds = DS3()
     assert not ds.enter_download
     assert not ds.enter_process
+
+
+class MyTestDataset2(InMemoryDataset):
+    def __init__(self, data_list):
+        super().__init__('/tmp/MyTestDataset2')
+        self.data, self.slices = self.collate(data_list)
+
+
+def test_lists_of_tensors_in_memory_dataset():
+    def tr(n, m):
+        return torch.rand((n, m))
+
+    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)])
+    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)])
+    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)])
+    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)])
+
+    data_list = [d1, d2, d3, d4]
+
+    dataset = MyTestDataset2(data_list)
+    assert len(dataset) == 4
+    assert dataset[0].xs[1].shape == (11, 4)
+    assert dataset[0].xs[2].shape == (1, 2)
+    assert dataset[1].xs[0].shape == (5, 3)
+    assert dataset[2].xs[1].shape == (15, 4)
+    assert dataset[3].xs[1].shape == (16, 4)
+
+
+class MyTestDataset3(InMemoryDataset):
+    def __init__(self, data_list):
+        super().__init__('/tmp/MyTestDataset3')
+        self.data, self.slices = self.collate(data_list)
+
+
+def test_lists_of_SparseTensors():
+    e1 = torch.tensor([[4, 1, 3, 2, 2, 3], [1, 3, 2, 3, 3, 2]])
+    e2 = torch.tensor([[0, 1, 4, 7, 2, 9], [7, 2, 2, 1, 4, 7]])
+    e3 = torch.tensor([[3, 5, 1, 2, 3, 3], [5, 0, 2, 1, 3, 7]])
+    e4 = torch.tensor([[0, 1, 9, 2, 0, 3], [1, 1, 2, 1, 3, 2]])
+    adj1 = SparseTensor.from_edge_index(e1, sparse_sizes=(11, 11))
+    adj2 = SparseTensor.from_edge_index(e2, sparse_sizes=(22, 22))
+    adj3 = SparseTensor.from_edge_index(e3, sparse_sizes=(12, 12))
+    adj4 = SparseTensor.from_edge_index(e4, sparse_sizes=(15, 15))
+
+    d1 = Data(adj_test=[adj1, adj2])
+    d2 = Data(adj_test=[adj3, adj4])
+
+    data_list = [d1, d2]
+    dataset = MyTestDataset3(data_list)
+    assert len(dataset) == 2
+    assert dataset[0].adj_test[0].sparse_sizes() == (11, 11)
+    assert dataset[0].adj_test[1].sparse_sizes() == (22, 22)
+    assert dataset[1].adj_test[0].sparse_sizes() == (12, 12)
+    assert dataset[1].adj_test[1].sparse_sizes() == (15, 15)
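
Before the `collate.py` changes themselves, a standalone sketch of what the default (one-dimensional) case of the new `_batch_and_ptr` helper computes, using plain `torch` ops in place of PyG's internal `repeat_interleave` and `cumsum` helpers: the cumulative `slices` boundaries become a per-row `batch` assignment vector plus a `ptr` offset vector.

    import torch

    # Cumulative row boundaries for 4 examples with 4, 5, 6 and 4 rows:
    slices = torch.tensor([0, 4, 9, 15, 19])

    repeats = slices[1:] - slices[:-1]  # rows per example: [4, 5, 6, 4]
    batch = torch.repeat_interleave(torch.arange(len(repeats)), repeats)
    ptr = torch.cat([torch.zeros(1, dtype=slices.dtype), repeats.cumsum(0)])

    assert batch.tolist() == [0] * 4 + [1] * 5 + [2] * 6 + [3] * 4
    assert ptr.tolist() == slices.tolist()
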
diff --git a/torch_geometric/data/collate.py b/torch_geometric/data/collate.py
index beddcee4afef..fd253301d8bf 100644
--- a/torch_geometric/data/collate.py
+++ b/torch_geometric/data/collate.py
@@ -96,12 +96,10 @@ def collate(
         inc_dict[attr] = incs
 
         # Add an additional batch vector for the given attribute:
-        if (attr in follow_batch and isinstance(slices, Tensor)
-                and slices.dim() == 1):
-            repeats = slices[1:] - slices[:-1]
-            batch = repeat_interleave(repeats.tolist(), device=device)
+        if attr in follow_batch:
+            batch, ptr = _batch_and_ptr(slices, device)
             out_store[f'{attr}_batch'] = batch
-            out_store[f'{attr}_ptr'] = cumsum(repeats.to(device))
+            out_store[f'{attr}_ptr'] = ptr
 
     # In case the storage holds nodes, we add a top-level batch vector to it:
     if (add_batch and isinstance(stores[0], NodeStorage)
@@ -199,6 +197,39 @@ def _collate(
     return values, slices, None
 
 
+def _batch_and_ptr(
+    slices: Any,
+    device: Optional[torch.device] = None,
+) -> Tuple[Any, Any]:
+    if (isinstance(slices, Tensor) and slices.dim() == 1):
+        # Default case: turn the `slices` tensor into `batch` and `ptr` vectors.
+        repeats = slices[1:] - slices[:-1]
+        batch = repeat_interleave(repeats.tolist(), device=device)
+        ptr = cumsum(repeats.to(device))
+        return batch, ptr
+
+    elif isinstance(slices, Mapping):
+        # Recursively batch the values of dictionaries.
+        batch, ptr = {}, {}
+        for k, v in slices.items():
+            batch[k], ptr[k] = _batch_and_ptr(v, device)
+        return batch, ptr
+
+    elif (isinstance(slices, Sequence) and not isinstance(slices, str)
+          and isinstance(slices[0], Tensor)):
+        # Recursively batch the elements of lists.
+        batch, ptr = [], []
+        for s in slices:
+            sub_batch, sub_ptr = _batch_and_ptr(s, device)
+            batch.append(sub_batch)
+            ptr.append(sub_ptr)
+        return batch, ptr
+
+    else:
+        # Batching failed, usually because `slices` is not one-dimensional:
+        return None, None
+
+
 ###############################################################################
diff --git a/torch_geometric/data/in_memory_dataset.py b/torch_geometric/data/in_memory_dataset.py
index 33929726f672..2d9f454615e5 100644
--- a/torch_geometric/data/in_memory_dataset.py
+++ b/torch_geometric/data/in_memory_dataset.py
@@ -1,5 +1,5 @@
 import copy
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
@@ -130,10 +130,13 @@ def copy(self, idx: Optional[IndexType] = None) -> 'InMemoryDataset':
         return dataset
 
 
-def nested_iter(mapping: Mapping) -> Iterable:
-    for key, value in mapping.items():
-        if isinstance(value, Mapping):
+def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
+    if isinstance(node, Mapping):
+        for key, value in node.items():
             for inner_key, inner_value in nested_iter(value):
                 yield inner_key, inner_value
-        else:
-            yield key, value
+    elif isinstance(node, Sequence):
+        for i, inner_value in enumerate(node):
+            yield i, inner_value
+    else:
+        yield None, node
diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py
index b4e708e11533..af8af03ef7ad 100644
--- a/torch_geometric/data/separate.py
+++ b/torch_geometric/data/separate.py
@@ -90,6 +90,11 @@ def _separate(
           and not isinstance(value[0], str) and len(value[0]) > 0
          and isinstance(value[0][0], (Tensor, SparseTensor))):
         # Select the `idx`-th element of each entry in lists of lists.
+        return [elem[idx] for elem in value]
+
+    elif (isinstance(value, Sequence) and not isinstance(value, str)
+          and isinstance(value[0], (Tensor, SparseTensor))):
+        # Recursively separate the elements of lists of Tensors/SparseTensors.
         return [
             _separate(key, elem, idx, slices[i],
                       incs[i] if decrement else None, batch, store, decrement)
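
Two closing sketches, to make the remaining changes concrete. First, the rewritten `nested_iter` is what lets `InMemoryDataset` infer the number of examples when a `slices` leaf is buried inside a list: the dataset's `len()` reads the first leaf boundary tensor yielded and returns its length minus one. The helper below is a compressed restatement of the function above, and `slices_dict` is a made-up stand-in for `self.slices`.

    import torch
    from collections.abc import Mapping, Sequence

    def nested_iter(node):
        # Descend through dictionaries and lists until a leaf is reached:
        if isinstance(node, Mapping):
            for value in node.values():
                yield from nested_iter(value)
        elif isinstance(node, Sequence):
            yield from enumerate(node)
        else:
            yield None, node

    # `xs` held a list of three tensors per example, so its slices are a
    # list of three boundary tensors (here: 4 examples each):
    slices_dict = {'xs': [torch.arange(5), torch.arange(5), torch.arange(5)]}

    _, leaf = next(nested_iter(slices_dict))
    assert len(leaf) - 1 == 4  # the number of examples, as `len()` computes it
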
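Second, the `separate.py` branch added above is what makes the round trip work: indexing a collated batch rebuilds per-example lists of tensors with their original shapes, as the new dataset tests check. A minimal sketch (the attribute name `xs` is arbitrary):

    import torch
    from torch_geometric.data import Batch, Data

    d1 = Data(xs=[torch.rand(4, 3), torch.rand(1, 2)])
    d2 = Data(xs=[torch.rand(5, 3), torch.rand(3, 2)])

    batch = Batch.from_data_list([d1, d2])

    # Indexing the batch calls `separate()` under the hood:
    out = batch[0]
    assert out.xs[0].shape == (4, 3)
    assert out.xs[1].shape == (1, 2)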