From e325ced1358e37186ce37365f58ac18f557d7a11 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Sat, 25 Mar 2023 11:48:35 -0400 Subject: [PATCH 1/5] Adding a different logic for the unbatching of torch --- torch_geometric/data/separate.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index a3850a251198..8c234c723940 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from typing import Any -from torch import Tensor +from torch import Tensor, sparse_coo, arange, long from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage @@ -56,7 +56,7 @@ def _separate( decrement: bool, ) -> Any: - if isinstance(value, Tensor): + if isinstance(value, Tensor) and not value.is_sparse: # Narrow a `torch.Tensor` based on `slices`. # NOTE: We need to take care of decrementing elements appropriately. key = str(key) @@ -68,6 +68,17 @@ def _separate( value = value - incs[idx].to(value.device) return value + elif isinstance(value, Tensor) and value.is_sparse: + # Allows to unbatch a sparse tensors from pytorch, including `sparse_coo_tensor` + key = str(key) + cat_dim = batch.__cat_dim__(key, value, store) + start, end = int(slices[idx]), int(slices[idx + 1]) + indices = arange(start, end, dtype=long) + value = value.index_select(cat_dim or 0, indices) + if decrement and (incs.dim() > 1 or int(incs[idx]) != 0): + value = value - incs[idx].to(value.device) + return value + elif isinstance(value, SparseTensor) and decrement: # Narrow a `SparseTensor` based on `slices`. # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. From f149bab0212f93cf62023c5eea0a4d54f37b61a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 25 Mar 2023 15:53:24 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/data/separate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index 8c234c723940..62e5665e5dc5 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from typing import Any -from torch import Tensor, sparse_coo, arange, long +from torch import Tensor, arange, long, sparse_coo from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage From 56b8c5939f172a4a1ac0c8096a3f70d0b19f9e1b Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 27 Mar 2023 08:22:11 +0000 Subject: [PATCH 3/5] update --- CHANGELOG.md | 2 ++ test/data/test_batch.py | 42 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d782d9ec530..1f44c14dea2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.4.0] - 2023-MM-DD ### Added + +- Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037)) - Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026)) ### Changed diff --git a/test/data/test_batch.py b/test/data/test_batch.py index 5a521c03e00b..fc63c16c70fc 100644 --- a/test/data/test_batch.py +++ b/test/data/test_batch.py @@ -166,6 +166,48 @@ def test_batch_with_sparse_tensor(): assert data_list[2].adj.coo()[1].tolist() == [1, 0, 2, 1, 3, 2] +def test_batch_with_torch_coo_tensor(): + x = torch.tensor([[1.0], [2.0], [3.0]]).to_sparse_coo() + data1 = Data(x=x) + + x = torch.tensor([[1.0], [2.0]]).to_sparse_coo() + data2 = Data(x=x) + + x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]).to_sparse_coo() + data3 = Data(x=x) + + batch = Batch.from_data_list([data1]) + assert str(batch) == ('DataBatch(x=[3, 1], batch=[3], ptr=[2])') + assert batch.num_graphs == len(batch) == 1 + assert batch.x.to_dense().tolist() == [[1], [2], [3]] + assert batch.batch.tolist() == [0, 0, 0] + assert batch.ptr.tolist() == [0, 3] + + batch = Batch.from_data_list([data1, data2, data3]) + + assert str(batch) == ('DataBatch(x=[9, 1], batch=[9], ptr=[4])') + assert batch.num_graphs == len(batch) == 3 + assert batch.x.to_dense().view(-1).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4] + assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] + assert batch.ptr.tolist() == [0, 3, 5, 9] + + assert str(batch[0]) == ("Data(x=[3, 1])") + assert str(batch[1]) == ("Data(x=[2, 1])") + assert str(batch[2]) == ("Data(x=[4, 1])") + + data_list = batch.to_data_list() + assert len(data_list) == 3 + + assert len(data_list[0]) == 1 + assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]] + + assert len(data_list[1]) == 1 + assert data_list[1].x.to_dense().tolist() == [[1], [2]] + + assert len(data_list[2]) == 1 + assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]] + + def test_batching_with_new_dimension(): torch_geometric.set_debug(True) From ff187d73a0001e44f516ee918c98d2cceb86faa9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 27 Mar 2023 08:28:13 +0000 Subject: [PATCH 4/5] update --- torch_geometric/data/separate.py | 16 +++------------- torch_geometric/utils/select.py | 6 ++++++ 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index 62e5665e5dc5..ffe6e74e6a72 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -6,6 +6,7 @@ from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage from torch_geometric.typing import SparseTensor +from torch_geometric.utils import narrow def separate(cls, batch: BaseData, idx: int, slice_dict: Any, @@ -56,29 +57,18 @@ def _separate( decrement: bool, ) -> Any: - if isinstance(value, Tensor) and not value.is_sparse: + if isinstance(value, Tensor): # Narrow a `torch.Tensor` based on `slices`. # NOTE: We need to take care of decrementing elements appropriately. key = str(key) cat_dim = batch.__cat_dim__(key, value, store) start, end = int(slices[idx]), int(slices[idx + 1]) - value = value.narrow(cat_dim or 0, start, end - start) + value = narrow(value, cat_dim or 0, start, end - start) value = value.squeeze(0) if cat_dim is None else value if decrement and (incs.dim() > 1 or int(incs[idx]) != 0): value = value - incs[idx].to(value.device) return value - elif isinstance(value, Tensor) and value.is_sparse: - # Allows to unbatch a sparse tensors from pytorch, including `sparse_coo_tensor` - key = str(key) - cat_dim = batch.__cat_dim__(key, value, store) - start, end = int(slices[idx]), int(slices[idx + 1]) - indices = arange(start, end, dtype=long) - value = value.index_select(cat_dim or 0, indices) - if decrement and (incs.dim() > 1 or int(incs[idx]) != 0): - value = value - incs[idx].to(value.device) - return value - elif isinstance(value, SparseTensor) and decrement: # Narrow a `SparseTensor` based on `slices`. # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. diff --git a/torch_geometric/utils/select.py b/torch_geometric/utils/select.py index 210b127df2cc..54d0a3961f5d 100644 --- a/torch_geometric/utils/select.py +++ b/torch_geometric/utils/select.py @@ -4,6 +4,7 @@ from torch import Tensor from torch_geometric.utils.mask import mask_select +from torch_geometric.utils.sparse import is_torch_sparse_tensor def select(src: Union[Tensor, List[Any]], index_or_mask: Tensor, @@ -41,6 +42,11 @@ def narrow(src: Union[Tensor, List[Any]], dim: int, start: int, start (int): The starting dimension. length (int): The distance to the ending dimension. """ + if is_torch_sparse_tensor(src): + # TODO Sparse tensors in `torch.sparse` do not yet support `narrow`. + index = torch.arange(start, start + length, device=src.device) + return src.index_select(dim, index) + if isinstance(src, Tensor): return src.narrow(dim, start, length) From 097aa6674733001387c3161b561f6793d535fe85 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 27 Mar 2023 08:28:38 +0000 Subject: [PATCH 5/5] update --- torch_geometric/data/separate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index ffe6e74e6a72..5fcb98b9203c 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from typing import Any -from torch import Tensor, arange, long, sparse_coo +from torch import Tensor from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage