Adding unbatching support for torch.sparse.Tensor (#7037)
This PR fixes `Batch.from_data_list`, `Batch.get_example`, and `Batch.__getitem__` when using PyTorch's sparse
tensors, such as those created via `torch.sparse_coo_tensor`. See issue #7022
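A minimal sketch of the behavior this enables, mirroring the new test added below (`to_sparse_coo` assumes a reasonably recent PyTorch release):

```python
import torch

from torch_geometric.data import Batch, Data

# Node features stored as `torch.sparse` COO tensors:
data1 = Data(x=torch.tensor([[1.0], [2.0], [3.0]]).to_sparse_coo())
data2 = Data(x=torch.tensor([[1.0], [2.0]]).to_sparse_coo())

batch = Batch.from_data_list([data1, data2])

# Unbatching via indexing or `to_data_list` no longer fails for sparse tensors:
assert batch[0].x.to_dense().tolist() == [[1.0], [2.0], [3.0]]
assert batch.to_data_list()[1].x.to_dense().tolist() == [[1.0], [2.0]]
```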

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
3 people authored Mar 27, 2023
1 parent d3f4471 commit 0e439f3
Showing 4 changed files with 52 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.4.0] - 2023-MM-DD

### Added

- Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037))
- Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026))

### Changed
42 changes: 42 additions & 0 deletions test/data/test_batch.py
@@ -166,6 +166,48 @@ def test_batch_with_sparse_tensor():
assert data_list[2].adj.coo()[1].tolist() == [1, 0, 2, 1, 3, 2]


def test_batch_with_torch_coo_tensor():
x = torch.tensor([[1.0], [2.0], [3.0]]).to_sparse_coo()
data1 = Data(x=x)

x = torch.tensor([[1.0], [2.0]]).to_sparse_coo()
data2 = Data(x=x)

x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]).to_sparse_coo()
data3 = Data(x=x)

batch = Batch.from_data_list([data1])
assert str(batch) == ('DataBatch(x=[3, 1], batch=[3], ptr=[2])')
assert batch.num_graphs == len(batch) == 1
assert batch.x.to_dense().tolist() == [[1], [2], [3]]
assert batch.batch.tolist() == [0, 0, 0]
assert batch.ptr.tolist() == [0, 3]

batch = Batch.from_data_list([data1, data2, data3])

assert str(batch) == ('DataBatch(x=[9, 1], batch=[9], ptr=[4])')
assert batch.num_graphs == len(batch) == 3
assert batch.x.to_dense().view(-1).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4]
assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2]
assert batch.ptr.tolist() == [0, 3, 5, 9]

assert str(batch[0]) == ("Data(x=[3, 1])")
assert str(batch[1]) == ("Data(x=[2, 1])")
assert str(batch[2]) == ("Data(x=[4, 1])")

data_list = batch.to_data_list()
assert len(data_list) == 3

assert len(data_list[0]) == 1
assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]]

assert len(data_list[1]) == 1
assert data_list[1].x.to_dense().tolist() == [[1], [2]]

assert len(data_list[2]) == 1
assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]]


def test_batching_with_new_dimension():
torch_geometric.set_debug(True)

3 changes: 2 additions & 1 deletion torch_geometric/data/separate.py
@@ -6,6 +6,7 @@
from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import narrow


def separate(cls, batch: BaseData, idx: int, slice_dict: Any,
Expand Down Expand Up @@ -62,7 +63,7 @@ def _separate(
key = str(key)
cat_dim = batch.__cat_dim__(key, value, store)
start, end = int(slices[idx]), int(slices[idx + 1])
value = value.narrow(cat_dim or 0, start, end - start)
value = narrow(value, cat_dim or 0, start, end - start)
value = value.squeeze(0) if cat_dim is None else value
if decrement and (incs.dim() > 1 or int(incs[idx]) != 0):
value = value - incs[idx].to(value.device)
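The swap from `value.narrow(...)` to the `narrow` utility is the heart of the fix: `Tensor.narrow` is not yet supported for `torch.sparse` tensors, so the utility (updated in `torch_geometric/utils/select.py` below) slices them via `index_select` instead. A rough sketch of that equivalence on a sparse COO tensor:

```python
import torch

x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]).to_sparse_coo()

# `x.narrow(0, 1, 2)` is not supported for the sparse COO layout, but an
# `index_select` over the same index range yields the equivalent slice:
index = torch.arange(1, 1 + 2, device=x.device)
out = x.index_select(0, index)

assert out.to_dense().tolist() == [[2.0], [3.0]]
```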
6 changes: 6 additions & 0 deletions torch_geometric/utils/select.py
@@ -4,6 +4,7 @@
from torch import Tensor

from torch_geometric.utils.mask import mask_select
from torch_geometric.utils.sparse import is_torch_sparse_tensor


def select(src: Union[Tensor, List[Any]], index_or_mask: Tensor,
Expand Down Expand Up @@ -41,6 +42,11 @@ def narrow(src: Union[Tensor, List[Any]], dim: int, start: int,
start (int): The index along :obj:`dim` at which to start narrowing.
length (int): The length of the narrowed slice.
"""
if is_torch_sparse_tensor(src):
# TODO Sparse tensors in `torch.sparse` do not yet support `narrow`.
index = torch.arange(start, start + length, device=src.device)
return src.index_select(dim, index)

if isinstance(src, Tensor):
return src.narrow(dim, start, length)

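With this change in place, the helper behaves the same for dense and `torch.sparse` inputs; a small usage sketch:

```python
import torch

from torch_geometric.utils import narrow

dense = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
sparse = dense.to_sparse_coo()

out_dense = narrow(dense, 0, 1, 2)    # uses `Tensor.narrow`
out_sparse = narrow(sparse, 0, 1, 2)  # falls back to `index_select`

assert out_dense.tolist() == out_sparse.to_dense().tolist() == [[2.0], [3.0]]
```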
