diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d782d9ec530..1f44c14dea2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.4.0] - 2023-MM-DD ### Added + +- Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037)) - Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026)) ### Changed diff --git a/test/data/test_batch.py b/test/data/test_batch.py index 5a521c03e00b..fc63c16c70fc 100644 --- a/test/data/test_batch.py +++ b/test/data/test_batch.py @@ -166,6 +166,48 @@ def test_batch_with_sparse_tensor(): assert data_list[2].adj.coo()[1].tolist() == [1, 0, 2, 1, 3, 2] +def test_batch_with_torch_coo_tensor(): + x = torch.tensor([[1.0], [2.0], [3.0]]).to_sparse_coo() + data1 = Data(x=x) + + x = torch.tensor([[1.0], [2.0]]).to_sparse_coo() + data2 = Data(x=x) + + x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]).to_sparse_coo() + data3 = Data(x=x) + + batch = Batch.from_data_list([data1]) + assert str(batch) == ('DataBatch(x=[3, 1], batch=[3], ptr=[2])') + assert batch.num_graphs == len(batch) == 1 + assert batch.x.to_dense().tolist() == [[1], [2], [3]] + assert batch.batch.tolist() == [0, 0, 0] + assert batch.ptr.tolist() == [0, 3] + + batch = Batch.from_data_list([data1, data2, data3]) + + assert str(batch) == ('DataBatch(x=[9, 1], batch=[9], ptr=[4])') + assert batch.num_graphs == len(batch) == 3 + assert batch.x.to_dense().view(-1).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4] + assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] + assert batch.ptr.tolist() == [0, 3, 5, 9] + + assert str(batch[0]) == ("Data(x=[3, 1])") + assert str(batch[1]) == ("Data(x=[2, 1])") + assert str(batch[2]) == ("Data(x=[4, 1])") + + data_list = batch.to_data_list() + assert len(data_list) == 3 + + assert len(data_list[0]) == 1 + assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]] + + assert len(data_list[1]) == 1 + assert data_list[1].x.to_dense().tolist() == [[1], [2]] + + assert len(data_list[2]) == 1 + assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]] + + def test_batching_with_new_dimension(): torch_geometric.set_debug(True) diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index a3850a251198..5fcb98b9203c 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -6,6 +6,7 @@ from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage from torch_geometric.typing import SparseTensor +from torch_geometric.utils import narrow def separate(cls, batch: BaseData, idx: int, slice_dict: Any, @@ -62,7 +63,7 @@ def _separate( key = str(key) cat_dim = batch.__cat_dim__(key, value, store) start, end = int(slices[idx]), int(slices[idx + 1]) - value = value.narrow(cat_dim or 0, start, end - start) + value = narrow(value, cat_dim or 0, start, end - start) value = value.squeeze(0) if cat_dim is None else value if decrement and (incs.dim() > 1 or int(incs[idx]) != 0): value = value - incs[idx].to(value.device) diff --git a/torch_geometric/utils/select.py b/torch_geometric/utils/select.py index 210b127df2cc..54d0a3961f5d 100644 --- a/torch_geometric/utils/select.py +++ b/torch_geometric/utils/select.py @@ -4,6 +4,7 @@ from torch import Tensor from torch_geometric.utils.mask import mask_select +from torch_geometric.utils.sparse import is_torch_sparse_tensor def select(src: Union[Tensor, List[Any]], index_or_mask: Tensor, @@ -41,6 +42,11 @@ def narrow(src: Union[Tensor, List[Any]], dim: int, start: int, start (int): The starting dimension. length (int): The distance to the ending dimension. """ + if is_torch_sparse_tensor(src): + # TODO Sparse tensors in `torch.sparse` do not yet support `narrow`. + index = torch.arange(start, start + length, device=src.device) + return src.index_select(dim, index) + if isinstance(src, Tensor): return src.narrow(dim, start, length)