Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow 1D input to global_*_pool functions #6504

Merged
merged 6 commits into from
Jan 24, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions torch_geometric/nn/pool/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def global_add_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
dim = -1 if x.dim() == 1 else -2

if batch is None:
return x.sum(dim=-2, keepdim=x.dim() == 2)
return x.sum(dim=dim, keepdim=x.dim() == 2 or x.dim() == 1)
size = int(batch.max().item() + 1) if size is None else size
return scatter(x, batch, dim=-2, dim_size=size, reduce='sum')
return scatter(x, batch, dim=dim, dim_size=size, reduce='sum')


def global_mean_pool(x: Tensor, batch: Optional[Tensor],
Expand All @@ -53,10 +55,12 @@ def global_mean_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
dim = -1 if x.dim() == 1 else -2

if batch is None:
return x.mean(dim=-2, keepdim=x.dim() == 2)
return x.mean(dim=dim, keepdim=x.dim() == 2 or x.dim() == 1)
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
size = int(batch.max().item() + 1) if size is None else size
return scatter(x, batch, dim=-2, dim_size=size, reduce='mean')
return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')


def global_max_pool(x: Tensor, batch: Optional[Tensor],
Expand All @@ -80,7 +84,9 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
dim = -1 if x.dim() == 1 else -2

if batch is None:
return x.max(dim=-2, keepdim=x.dim() == 2)[0]
return x.max(dim=dim, keepdim=x.dim() == 2 or x.dim() == 1)[0]
size = int(batch.max().item() + 1) if size is None else size
return scatter(x, batch, dim=-2, dim_size=size, reduce='max')
return scatter(x, batch, dim=dim, dim_size=size, reduce='max')