diff --git a/torch_geometric/nn/pool/glob.py b/torch_geometric/nn/pool/glob.py index 17ea54653987..75792fadbb86 100644 --- a/torch_geometric/nn/pool/glob.py +++ b/torch_geometric/nn/pool/glob.py @@ -26,10 +26,12 @@ def global_add_pool(x: Tensor, batch: Optional[Tensor], size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ + dim = -1 if x.dim() == 1 else -2 + if batch is None: - return x.sum(dim=-2, keepdim=x.dim() == 2) + return x.sum(dim=dim, keepdim=x.dim() == 2 or x.dim() == 1) size = int(batch.max().item() + 1) if size is None else size - return scatter(x, batch, dim=-2, dim_size=size, reduce='sum') + return scatter(x, batch, dim=dim, dim_size=size, reduce='sum') def global_mean_pool(x: Tensor, batch: Optional[Tensor], @@ -53,10 +55,12 @@ def global_mean_pool(x: Tensor, batch: Optional[Tensor], size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ + dim = -1 if x.dim() == 1 else -2 + if batch is None: - return x.mean(dim=-2, keepdim=x.dim() == 2) + return x.mean(dim=dim, keepdim=x.dim() == 2 or x.dim() == 1) size = int(batch.max().item() + 1) if size is None else size - return scatter(x, batch, dim=-2, dim_size=size, reduce='mean') + return scatter(x, batch, dim=dim, dim_size=size, reduce='mean') def global_max_pool(x: Tensor, batch: Optional[Tensor], @@ -80,7 +84,9 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor], size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) """ + dim = -1 if x.dim() == 1 else -2 + if batch is None: - return x.max(dim=-2, keepdim=x.dim() == 2)[0] + return x.max(dim=dim, keepdim=x.dim() == 2 or x.dim() == 1)[0] size = int(batch.max().item() + 1) if size is None else size - return scatter(x, batch, dim=-2, dim_size=size, reduce='max') + return scatter(x, batch, dim=dim, dim_size=size, reduce='max')