Skip to content

Commit

Permalink
allow 1D input to global_*_pool functions (#6504)
Browse files Browse the repository at this point in the history
Allow 1D input to global_*_pool functions

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
kgajdamo and rusty1s authored Jan 24, 2023
1 parent ecf4020 commit b10b465
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Allow 1D input to `global_*_pool` functions ([#6504](https://github.com/pyg-team/pytorch_geometric/pull/6504))
- Add information about dynamic shapes in `RGCNConv` ([#6482](https://github.com/pyg-team/pytorch_geometric/pull/6482))
- Fixed the use of types removed in `numpy 1.24.0` ([#6495](https://github.com/pyg-team/pytorch_geometric/pull/6495))
- Fixed keyword parameters in `examples/mnist_voxel_grid.py` ([#6478](https://github.com/pyg-team/pytorch_geometric/pull/6478))
Expand Down
18 changes: 12 additions & 6 deletions torch_geometric/nn/pool/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def global_add_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
dim = -1 if x.dim() == 1 else -2

if batch is None:
return x.sum(dim=-2, keepdim=x.dim() == 2)
return x.sum(dim=dim, keepdim=x.dim() <= 2)
size = int(batch.max().item() + 1) if size is None else size
return scatter(x, batch, dim=-2, dim_size=size, reduce='sum')
return scatter(x, batch, dim=dim, dim_size=size, reduce='sum')


def global_mean_pool(x: Tensor, batch: Optional[Tensor],
Expand All @@ -53,10 +55,12 @@ def global_mean_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
dim = -1 if x.dim() == 1 else -2

if batch is None:
return x.mean(dim=-2, keepdim=x.dim() == 2)
return x.mean(dim=dim, keepdim=x.dim() <= 2)
size = int(batch.max().item() + 1) if size is None else size
return scatter(x, batch, dim=-2, dim_size=size, reduce='mean')
return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')


def global_max_pool(x: Tensor, batch: Optional[Tensor],
Expand All @@ -80,7 +84,9 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
dim = -1 if x.dim() == 1 else -2

if batch is None:
return x.max(dim=-2, keepdim=x.dim() == 2)[0]
return x.max(dim=dim, keepdim=x.dim() <= 2)[0]
size = int(batch.max().item() + 1) if size is None else size
return scatter(x, batch, dim=-2, dim_size=size, reduce='max')
return scatter(x, batch, dim=dim, dim_size=size, reduce='max')

0 comments on commit b10b465

Please sign in to comment.