From b10b465ff1752becc39897edf786b250d5e1a519 Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Tue, 24 Jan 2023 18:53:20 +0100 Subject: [PATCH] allow 1D input to `global_*_pool` functions (#6504) Allow 1D input to global_*_pool functions Co-authored-by: Matthias Fey --- CHANGELOG.md | 1 + torch_geometric/nn/pool/glob.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14dec4979bf8..39f2e8348fc2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Allow 1D input to `global_*_pool` functions ([#6504](https://github.com/pyg-team/pytorch_geometric/pull/6504)) - Add information about dynamic shapes in `RGCNConv` ([#6482](https://github.com/pyg-team/pytorch_geometric/pull/6482)) - Fixed the use of types removed in `numpy 1.24.0` ([#6495](https://github.com/pyg-team/pytorch_geometric/pull/6495)) - Fixed keyword parameters in `examples/mnist_voxel_grid.py` ([#6478](https://github.com/pyg-team/pytorch_geometric/pull/6478)) diff --git a/torch_geometric/nn/pool/glob.py b/torch_geometric/nn/pool/glob.py index 17ea54653987..2f2771084daf 100644 --- a/torch_geometric/nn/pool/glob.py +++ b/torch_geometric/nn/pool/glob.py @@ -26,10 +26,12 @@ def global_add_pool(x: Tensor, batch: Optional[Tensor], size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ + dim = -1 if x.dim() == 1 else -2 + if batch is None: - return x.sum(dim=-2, keepdim=x.dim() == 2) + return x.sum(dim=dim, keepdim=x.dim() <= 2) size = int(batch.max().item() + 1) if size is None else size - return scatter(x, batch, dim=-2, dim_size=size, reduce='sum') + return scatter(x, batch, dim=dim, dim_size=size, reduce='sum') def global_mean_pool(x: Tensor, batch: Optional[Tensor], @@ -53,10 +55,12 @@ def global_mean_pool(x: Tensor, batch: Optional[Tensor], size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ + dim = -1 if x.dim() == 1 else -2 + if batch is None: - return x.mean(dim=-2, keepdim=x.dim() == 2) + return x.mean(dim=dim, keepdim=x.dim() <= 2) size = int(batch.max().item() + 1) if size is None else size - return scatter(x, batch, dim=-2, dim_size=size, reduce='mean') + return scatter(x, batch, dim=dim, dim_size=size, reduce='mean') def global_max_pool(x: Tensor, batch: Optional[Tensor], @@ -80,7 +84,9 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor], size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ + dim = -1 if x.dim() == 1 else -2 + if batch is None: - return x.max(dim=-2, keepdim=x.dim() == 2)[0] + return x.max(dim=dim, keepdim=x.dim() <= 2)[0] size = int(batch.max().item() + 1) if size is None else size - return scatter(x, batch, dim=-2, dim_size=size, reduce='max') + return scatter(x, batch, dim=dim, dim_size=size, reduce='max')