Skip to content

Commit

Permalink
Add batch_size argument to unbatch functionalities (#7851)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Aug 5, 2023
1 parent a2bc6c3 commit ca5311c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a `batch_size` argument to `unbatch` functionalities ([#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851))
- Added a distributed example using `graphlearn-for-pytorch` ([#7402](https://github.com/pyg-team/pytorch_geometric/pull/7402))
- Integrated `neg_sampling_ratio` into `TemporalDataLoader` ([#7644](https://github.com/pyg-team/pytorch_geometric/pull/7644))
- Added `faiss`-based `KNNIndex` classes for L2 or maximum inner product search ([#7842](https://github.com/pyg-team/pytorch_geometric/pull/7842))
Expand Down
23 changes: 17 additions & 6 deletions torch_geometric/utils/unbatch.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import List
from typing import List, Optional

import torch
from torch import Tensor

from torch_geometric.utils import degree


def unbatch(
    src: Tensor,
    batch: Tensor,
    dim: int = 0,
    batch_size: Optional[int] = None,
) -> List[Tensor]:
    r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
    :obj:`dim`.

    Args:
        src (Tensor): The source tensor.
        batch (LongTensor): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            entry in :obj:`src` to a specific example. Must be ordered.
        dim (int, optional): The dimension along which to split the :obj:`src`
            tensor. (default: :obj:`0`)
        batch_size (int, optional): The batch size. (default: :obj:`None`)

    :rtype: :class:`List[Tensor]`

    Example:

        >>> src = torch.arange(7)
        >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])
        >>> unbatch(src, batch)
        (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
    """
    # Count the number of entries per example; passing `batch_size` lets
    # `degree` skip inferring the number of examples from `batch.max()`.
    sizes = degree(batch, batch_size, dtype=torch.long).tolist()
    return src.split(sizes, dim)


def unbatch_edge_index(
    edge_index: Tensor,
    batch: Tensor,
    batch_size: Optional[int] = None,
) -> List[Tensor]:
    r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.

    Args:
        edge_index (Tensor): The edge_index tensor. Must be ordered.
        batch (LongTensor): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered.
        batch_size (int, optional): The batch size. (default: :obj:`None`)

    :rtype: :class:`List[Tensor]`

    Example:

        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6],
        ...                            [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]])
        >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])
        >>> unbatch_edge_index(edge_index, batch)
        (tensor([[0, 1, 1, 2, 2, 3],
                [1, 0, 2, 1, 3, 2]]),
        tensor([[0, 1, 1, 2],
                [1, 0, 2, 1]]))
    """
    # Number of nodes per example; `batch_size` avoids inferring it.
    deg = degree(batch, batch_size, dtype=torch.long)
    # Exclusive prefix sum = index of each example's first node, used to
    # shift global node indices back to example-local indices.
    ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0)

    # Assign each edge to the example of its source node (valid since
    # edges never cross example boundaries in a batched graph).
    edge_batch = batch[edge_index[0]]
    edge_index = edge_index - ptr[edge_batch]
    sizes = degree(edge_batch, batch_size, dtype=torch.long).cpu().tolist()
    return edge_index.split(sizes, dim=1)

0 comments on commit ca5311c

Please sign in to comment.