Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pooling: Update base classes #7307

Merged
merged 6 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions test/nn/pool/test_pooling_base.py

This file was deleted.

94 changes: 0 additions & 94 deletions torch_geometric/nn/pool/base.py

This file was deleted.

3 changes: 2 additions & 1 deletion torch_geometric/nn/pool/connect/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import Connect
from .base import Connect, ConnectOutput

__all__ = [
'Connect',
'ConnectOutput',
]
76 changes: 72 additions & 4 deletions torch_geometric/nn/pool/connect/base.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,79 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.nn.pool.select import SelectOutput
from torch_geometric.utils.mixin import CastMixin


@dataclass(init=False)
class ConnectOutput(CastMixin):
r"""The output of the :class:`Connect` method, which holds the coarsened
graph structure, and optional pooled edge features and batch vectors.

Args:
edge_index (torch.Tensor): The edge indices of the cooarsened graph.
edge_attr (torch.Tensor, optional): The pooled edge features of the
coarsened graph. (default: :obj:`None`)
batch (torch.Tensor, optional): The pooled batch vector of the
coarsened graph. (default: :obj:`None`)
"""
edge_index: Tensor
edge_attr: Optional[Tensor] = None
batch: Optional[Tensor] = None

def __init__(
self,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
batch: Optional[Tensor] = None,
):
if edge_index.dim() != 2:
raise ValueError(f"Expected 'edge_index' to be two-dimensional "
f"(got {edge_index.dim()} dimensions)")

if edge_index.size(0) != 2:
raise ValueError(f"Expected 'edge_index' to have size '2' in the "
f"first dimension (got '{edge_index.size(0)}')")

if edge_attr is not None and edge_attr.size(0) != edge_index.size(1):
raise ValueError(f"Expected 'edge_index' and 'edge_attr' to "
f"hold the same number of edges (got "
f"{edge_index.size(1)} and {edge_attr.size(0)} "
f"edges)")

self.edge_index = edge_index
self.edge_attr = edge_attr
self.batch = batch


class Connect(torch.nn.Module):
r"""An abstract base class implementing custom edge connection operators.
r"""An abstract base class for implementing custom edge connection
operators as described in the `"Understanding Pooling in Graph Neural
Networks" <https://arxiv.org/abs/1905.05178>`_ paper.

Specifically, :class:`Connect` determines for each pair of supernodes the
presence or abscene of an edge based on the existing edges between the
nodes in the two supernodes.
The operator also computes new coarsened edge features (if present).
The operator also computes pooled edge features and batch vectors
(if present).
"""
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
pass

def forward(
self,
cluster: Tensor,
select_output: SelectOutput,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
batch: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
cluster (torch.Tensor): The mapping from nodes to supernodes.
select_output (SelectOutput): The output of :class:`Select`.
edge_index (torch.Tensor): The edge indices.
edge_attr (torch.Tensor, optional): The edge features.
(default: :obj:`None`)
Expand All @@ -35,5 +83,25 @@ def forward(
"""
raise NotImplementedError

@staticmethod
def get_pooled_batch(
select_output: SelectOutput,
batch: Optional[Tensor],
) -> Optional[Tensor]:
r"""Returns the batch vector of the coarsened graph.

Args:
select_output (SelectOutput): The output of :class:`Select`.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example. (default: :obj:`None`)
"""
if batch is None:
return batch

out = torch.arange(select_output.num_clusters, device=batch.device)
return out.scatter_(0, select_output.cluster_index,
batch[select_output.node_index])

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
3 changes: 2 additions & 1 deletion torch_geometric/nn/pool/select/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import Select
from .base import Select, SelectOutput

__all__ = [
'Select',
'SelectOutput',
]
90 changes: 66 additions & 24 deletions torch_geometric/nn/pool/select/base.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,80 @@
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.utils.mixin import CastMixin


@dataclass(init=False)
class SelectOutput(CastMixin):
r"""The output of the :class:`Select` method, which holds an assignment
from selected nodes to their respective cluster(s).

Args:
node_index (torch.Tensor): The indices of the selected nodes.
cluster_index (torch.Tensor): The indices of the clusters each node in
:obj:`node_index` is assigned to.
num_clusters (int): The number of clusters.
weight (torch.Tensor, optional): A weight vector, denoting the strength
of the assignment of a node to its cluster. (default: :obj:`None`)
"""
node_index: Tensor
cluster_index: Tensor
num_clusters: int
weight: Optional[Tensor] = None

def __init__(
self,
node_index: Tensor,
cluster_index: Tensor,
num_clusters: int,
weight: Optional[Tensor] = None,
):
if node_index.dim() != 1:
raise ValueError(f"Expected 'node_index' to be one-dimensional "
f"(got {node_index.dim()} dimensions)")

if cluster_index.dim() != 1:
raise ValueError(f"Expected 'cluster_index' to be one-dimensional "
f"(got {cluster_index.dim()} dimensions)")

if node_index.numel() != cluster_index.numel():
raise ValueError(f"Expected 'node_index' and 'cluster_index' to "
f"hold the same number of values (got "
f"{node_index.numel()} and "
f"{cluster_index.numel()} values)")

if weight is not None and weight.dim() != 1:
raise ValueError(f"Expected 'weight' vector to be one-dimensional "
f"(got {weight.dim()} dimensions)")

if weight is not None and weight.numel() != node_index.numel():
raise ValueError(f"Expected 'weight' to hold {node_index.numel()} "
f"values (got {weight.numel()} values)")

self.node_index = node_index
self.cluster_index = cluster_index
self.num_clusters = num_clusters
self.weight = weight


class Select(torch.nn.Module):
r"""An abstract base class implementing custom node selections that map the
nodes of an input graph to supernodes of the pooled one.
r"""An abstract base class for implementing custom node selections as
described in the `"Understanding Pooling in Graph Neural Networks"
<https://arxiv.org/abs/1905.05178>`_ paper, which maps the nodes of an
input graph to supernodes in the coarsened graph.

Specifically, :class:`Select` returns a mapping
:math:`\mathbf{c} \in {\{ -1, \ldots, C - 1\}}^N`, which assigns each node
to one of :math:`C` super nodes.
In addition, :class:`Select` returns the number of super nodes.
Specifically, :class:`Select` returns a :class:`SelectOutput` output, which
holds a (sparse) mapping :math:`\mathbf{C} \in {[0, 1]}^{N \times C}` that
assigns selected nodes to one or more of :math:`C` super nodes.
"""
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
pass

def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
batch: Optional[Tensor] = None,
) -> Tuple[Tensor, int]:
r"""
Args:
x (torch.Tensor): The input node features.
edge_index (torch.Tensor): The edge indices.
edge_attr (torch.Tensor, optional): The edge features.
(default: :obj:`None`)
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each node to a specific graph. (default: :obj:`None`)
"""
def forward(self, *args, **kwargs) -> SelectOutput:
raise NotImplementedError

def __repr__(self) -> str:
Expand Down