From 78d6d64a4d0d8d9df8d38cd9bb4fba87bc429ebf Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Sat, 5 Aug 2023 10:30:48 +0200
Subject: [PATCH 1/2] update

---
 torch_geometric/utils/unbatch.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/torch_geometric/utils/unbatch.py b/torch_geometric/utils/unbatch.py
index 62efcbf07bb0..0ba85c7e1135 100644
--- a/torch_geometric/utils/unbatch.py
+++ b/torch_geometric/utils/unbatch.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 
 import torch
 from torch import Tensor
@@ -6,7 +6,12 @@
 from torch_geometric.utils import degree
 
 
-def unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
+def unbatch(
+    src: Tensor,
+    batch: Tensor,
+    dim: int = 0,
+    batch_size: Optional[int] = None,
+) -> List[Tensor]:
     r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
     :obj:`dim`.
 
@@ -17,6 +22,7 @@ def unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
             entry in :obj:`src` to a specific example. Must be ordered.
         dim (int, optional): The dimension along which to split the :obj:`src`
             tensor. (default: :obj:`0`)
+        batch_size (int, optional): The batch size. (default: :obj:`None`)
 
     :rtype: :class:`List[Tensor]`
 
@@ -27,11 +33,15 @@ def unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
         >>> unbatch(src, batch)
         (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
     """
-    sizes = degree(batch, dtype=torch.long).tolist()
+    sizes = degree(batch, batch_size, dtype=torch.long).tolist()
     return src.split(sizes, dim)
 
 
-def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
+def unbatch_edge_index(
+    edge_index: Tensor,
+    batch: Tensor,
+    batch_size: Optional[int] = None,
+) -> List[Tensor]:
     r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.
 
     Args:
@@ -39,6 +49,7 @@ def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
         batch (LongTensor): The batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
             node to a specific example. Must be ordered.
+        batch_size (int, optional): The batch size. (default: :obj:`None`)
 
     :rtype: :class:`List[Tensor]`
 
@@ -53,10 +64,10 @@ def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
         tensor([[0, 1, 1, 2],
                 [1, 0, 2, 1]]))
     """
-    deg = degree(batch, dtype=torch.int64)
+    deg = degree(batch, batch_size, dtype=torch.long)
     ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0)
 
     edge_batch = batch[edge_index[0]]
     edge_index = edge_index - ptr[edge_batch]
-    sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist()
+    sizes = degree(edge_batch, batch_size, dtype=torch.long).cpu().tolist()
     return edge_index.split(sizes, dim=1)

From 5f46c7c1363e00496f2c4107ff9c114703f7c790 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Sat, 5 Aug 2023 10:31:56 +0200
Subject: [PATCH 2/2] update

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0cedaf7923b0..fd17cdfe9b16 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added a `batch_size` argument to `unbatch` functionalities ([#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851))
 - Added a distributed example using `graphlearn-for-pytorch` ([#7402](https://github.com/pyg-team/pytorch_geometric/pull/7402))
 - Integrate `neg_sampling_ratio` into `TemporalDataLoader` ([#7644](https://github.com/pyg-team/pytorch_geometric/pull/7644))
 - Added `faiss`-based `KNNIndex` classes for L2 or maximum inner product search ([#7842](https://github.com/pyg-team/pytorch_geometric/pull/7842))
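
For context, a minimal usage sketch of the two functions with the new `batch_size` argument. The `src`/`batch` tensors and the first expected output come from the docstring in the patch; the `edge_index` tensor, the explicit `batch_size=3`/`batch_size=4` calls, and the comments about size inference and trailing empty examples are illustrative assumptions, not part of the patch itself.

```python
import torch

from torch_geometric.utils import unbatch, unbatch_edge_index

# Three examples with 3, 2, and 2 nodes (taken from the unbatch docstring).
src = torch.arange(7)
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])

# A hypothetical edge_index in which every example is a single undirected edge.
edge_index = torch.tensor([[0, 1, 3, 4, 5, 6],
                           [1, 0, 4, 3, 6, 5]])

# Without batch_size, the number of examples is inferred from the batch vector
# inside degree().
print(unbatch(src, batch))
# (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))

# Passing the known batch size explicitly yields the same result for B = 3 ...
print(unbatch(src, batch, batch_size=3))
print(unbatch_edge_index(edge_index, batch, batch_size=3))
# -> three local edge_index tensors, each equal to [[0, 1], [1, 0]]

# ... while a larger batch_size keeps trailing node-less examples as empty
# tensors instead of dropping them.
print(unbatch(src, batch, batch_size=4))
# (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]), tensor([], dtype=torch.int64))
```

Presumably the main benefit of the explicit argument is that `degree` no longer has to derive the number of examples from `batch` itself, which helps when the batch size is already known on the caller side (e.g. avoiding an extra reduction over `batch` on the GPU).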