Skip to content

Commit

Permalink
to_nested_tensor and from_nested_tensor functionality (#6329)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 3, 2023
1 parent 78bba36 commit 7d2c9df
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.3.0] - 2023-MM-DD
### Added
- Added `to_nested_tensor` and `from_nested_tensor` functionality ([#6329](https://github.com/pyg-team/pytorch_geometric/pull/6329))
- Added the `GPSConv` Graph Transformer layer ([#6326](https://github.com/pyg-team/pytorch_geometric/pull/6326))
- Added `networkit` conversion utilities ([#6321](https://github.com/pyg-team/pytorch_geometric/pull/6321))
- Added global dataset attribute access via `dataset.{attr_name}` ([#6319](https://github.com/pyg-team/pytorch_geometric/pull/6319))
Expand Down
34 changes: 34 additions & 0 deletions test/utils/test_nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch

from torch_geometric.utils import from_nested_tensor, to_nested_tensor


def test_to_nested_tensor():
x = torch.randn(5, 4, 3)

out = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1]))
out = out.to_padded_tensor(padding=0)
assert out.size() == (2, 3, 4, 3)
assert torch.allclose(out[0, :2], x[0:2])
assert torch.allclose(out[1, :3], x[2:5])

out = to_nested_tensor(x, ptr=torch.tensor([0, 2, 5]))
out = out.to_padded_tensor(padding=0)
assert out.size() == (2, 3, 4, 3)
assert torch.allclose(out[0, :2], x[0:2])
assert torch.allclose(out[1, :3], x[2:5])

out = to_nested_tensor(x)
out = out.to_padded_tensor(padding=0)
assert out.size() == (1, 5, 4, 3)
assert torch.allclose(out[0], x)


def test_from_nested_tensor():
x = torch.randn(5, 4, 3)

nested = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1]))
out, batch = from_nested_tensor(nested)

assert torch.equal(x, out)
assert batch.tolist() == [0, 0, 1, 1, 1]
3 changes: 3 additions & 0 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .mask import mask_select, index_to_mask, mask_to_index
from .to_dense_batch import to_dense_batch
from .to_dense_adj import to_dense_adj
from .nested import to_nested_tensor, from_nested_tensor
from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor,
to_torch_coo_tensor)
from .spmm import spmm
Expand Down Expand Up @@ -77,6 +78,8 @@
'mask_to_index',
'to_dense_batch',
'to_dense_adj',
'to_nested_tensor',
'from_nested_tensor',
'dense_to_sparse',
'is_torch_sparse_tensor',
'is_sparse',
Expand Down
77 changes: 77 additions & 0 deletions torch_geometric/utils/nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.utils import scatter


def to_nested_tensor(
x: Tensor,
batch: Optional[Tensor] = None,
ptr: Optional[Tensor] = None,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Given a contiguous batch of tensors
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}`
(with :math:`N_i` indicating the number of elements in example :math:`i`),
creates a `nested PyTorch tensor
<https://pytorch.org/docs/stable/nested.html>`__.
Reverse operation of :meth:`from_nested_tensor`.
Args:
x (torch.Tensor): The input tensor
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}`.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
element to a specific example. Must be ordered.
(default: :obj:`None`)
ptr (torch.Tensor, optional): Alternative representation of
:obj:`batch` in compressed format. (default: :obj:`None`)
batch_size (int, optional) The batch size :math:`B`.
(default: :obj:`None`)
"""
if ptr is not None:
sizes = ptr[1:] - ptr[:-1]
sizes: List[int] = sizes.tolist()
xs = list(torch.split(x, sizes, dim=0))
elif batch is not None:
sizes = scatter(torch.ones_like(batch), batch, dim_size=batch_size)
sizes: List[int] = sizes.tolist()
xs = list(torch.split(x, sizes, dim=0))
else:
xs = [x]

# This currently copies the data, although `x` is already contiguous.
# Sadly, there does not exist any (public) API to preven this :(
return torch.nested.as_nested_tensor(xs)


def from_nested_tensor(x: Tensor) -> Tuple[Tensor, Tensor]:
r"""Given a `nested PyTorch tensor
<https://pytorch.org/docs/stable/nested.html>`__, creates a contiguous
batch of tensors
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}` and a
batch vector, which assigns each element to a specific example.
Reverse operation of :meth:`to_nested_tensor`.
Args:
x (torch.Tensor): The nested input tensor. The size of nested tensors
need to match except for the first dimension.
"""
if not x.is_nested:
raise ValueError("Input tensor is not nested")

sizes = x._nested_tensor_size()
for dim, (a, b) in enumerate(zip(sizes.t()[1:], sizes[0, 1:])):
if not torch.equal(a, b.expand_as(a)):
raise ValueError(f"Not all nested tensors have the same size in "
f"dimension {dim + 1}")

batch = torch.arange(x.size(0), device=x.device)
batch = batch.repeat_interleave(sizes[:, 0].to(batch.device))

out = torch.tensor(x.contiguous().storage())
out = out.view(batch.numel(), *sizes[0, 1:].tolist())

return out, batch

0 comments on commit 7d2c9df

Please sign in to comment.