diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0cedaf7923b0..fd17cdfe9b16 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added a `batch_size` argument to `unbatch` functionalities ([#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851))
 - Added a distributed example using `graphlearn-for-pytorch` ([#7402](https://github.com/pyg-team/pytorch_geometric/pull/7402))
 - Integrate `neg_sampling_ratio` into `TemporalDataLoader` ([#7644](https://github.com/pyg-team/pytorch_geometric/pull/7644))
 - Added `faiss`-based `KNNINdex` classes for L2 or maximum inner product search ([#7842](https://github.com/pyg-team/pytorch_geometric/pull/7842))
diff --git a/torch_geometric/utils/unbatch.py b/torch_geometric/utils/unbatch.py
index 62efcbf07bb0..0ba85c7e1135 100644
--- a/torch_geometric/utils/unbatch.py
+++ b/torch_geometric/utils/unbatch.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 
 import torch
 from torch import Tensor
@@ -6,7 +6,12 @@
 from torch_geometric.utils import degree
 
 
-def unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
+def unbatch(
+    src: Tensor,
+    batch: Tensor,
+    dim: int = 0,
+    batch_size: Optional[int] = None,
+) -> List[Tensor]:
     r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
     :obj:`dim`.
 
@@ -17,6 +22,7 @@ def unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
             entry in :obj:`src` to a specific example. Must be ordered.
         dim (int, optional): The dimension along which to split the :obj:`src`
             tensor. (default: :obj:`0`)
+        batch_size (int, optional): The batch size. (default: :obj:`None`)
 
     :rtype: :class:`List[Tensor]`
 
@@ -27,11 +33,15 @@ def unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
         >>> unbatch(src, batch)
         (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
     """
-    sizes = degree(batch, dtype=torch.long).tolist()
+    sizes = degree(batch, batch_size, dtype=torch.long).tolist()
     return src.split(sizes, dim)
 
 
-def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
+def unbatch_edge_index(
+    edge_index: Tensor,
+    batch: Tensor,
+    batch_size: Optional[int] = None,
+) -> List[Tensor]:
     r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.
 
     Args:
@@ -39,6 +49,7 @@
         batch (LongTensor): The batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
             node to a specific example. Must be ordered.
+        batch_size (int, optional): The batch size. (default: :obj:`None`)
 
     :rtype: :class:`List[Tensor]`
 
@@ -53,10 +64,10 @@
         tensor([[0, 1, 1, 2],
                 [1, 0, 2, 1]]))
     """
-    deg = degree(batch, dtype=torch.int64)
+    deg = degree(batch, batch_size, dtype=torch.long)
     ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0)
 
     edge_batch = batch[edge_index[0]]
     edge_index = edge_index - ptr[edge_batch]
-    sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist()
+    sizes = degree(edge_batch, batch_size, dtype=torch.long).cpu().tolist()
     return edge_index.split(sizes, dim=1)
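
A minimal usage sketch of the extended signatures; the tensor values below are illustrative only and not part of the diff. Supplying `batch_size` forwards it to `degree`, so the number of examples is not re-derived from `batch.max() + 1`, which can avoid a device-to-host synchronization when `batch` lives on the GPU.

```python
import torch
from torch_geometric.utils import unbatch, unbatch_edge_index

src = torch.arange(7)
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])  # three examples, ordered

# batch_size=3 is passed through to `degree` as the number of examples.
outs = unbatch(src, batch, batch_size=3)
# (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))

# Edges are ordered by example: nodes 0-2 belong to example 0, nodes 3-4 to example 1.
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
node_batch = torch.tensor([0, 0, 0, 1, 1])
edge_outs = unbatch_edge_index(edge_index, node_batch, batch_size=2)
# (tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), tensor([[0, 1], [1, 0]]))
```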