Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 8, 2023
1 parent 93d4d30 commit 304afb8
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 38 deletions.
72 changes: 69 additions & 3 deletions torch_geometric/nn/pool/connect/base.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,72 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.nn.pool.select import SelectOutput
from torch_geometric.utils.mixin import CastMixin


@dataclass(init=False)
class ConnectOutput(CastMixin):
r"""The output of the :class:`Connect` method, which holds the coarsened
graph structure, and optional pooled edge features and batch vectors.
Args:
edge_index (torch.Tensor): The edge indices of the cooarsened graph.
edge_attr (torch.Tensor, optional): The pooled edge features of the
coarsened graph. (default: :obj:`None`)
batch (torch.Tensor, optional): The pooled batch vector of the
coarsened graph. (default: :obj:`None`)
"""
edge_index: Tensor
edge_attr: Optional[Tensor] = None
batch: Optional[Tensor] = None

def __init__(
self,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
batch: Optional[Tensor] = None,
):
if edge_index.dim() != 2:
raise ValueError(f"Expected 'edge_index' to be two-dimensional "
f"(got {edge_index.dim()} dimensions)")

if edge_index.size(0) != 2:
raise ValueError(f"Expected 'edge_index' to have size '2' in the "
f"first dimension (got '{edge_index.size(0)}')")

if edge_attr is not None and edge_attr.size(0) != edge_index.size(1):
raise ValueError(f"Expected 'edge_index' and 'edge_attr' to "
f"hold the same number of edges (got "
f"{edge_index.size(1)} and {edge_attr.size(0)} "
f"edges)")

self.edge_index = edge_index
self.edge_attr = edge_attr
self.batch = batch


class Connect(torch.nn.Module):
r"""An abstract base class implementing custom edge connection operators.
r"""An abstract base class implementing custom edge connection operators as
described in the `"Understanding Pooling in Graph Neural Networks"
<https://arxiv.org/abs/1905.05178>`_ paper.
Specifically, :class:`Connect` determines for each pair of supernodes the
presence or abscene of an edge based on the existing edges between the
nodes in the two supernodes.
The operator also computes new coarsened edge features (if present).
The operator also computes pooled edge features and batch vectors
(if present).
"""
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
pass

def forward(
self,
cluster: SelectOutput,
select_output: SelectOutput,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
batch: Optional[Tensor] = None,
Expand All @@ -38,5 +84,25 @@ def forward(
"""
raise NotImplementedError

@staticmethod
def get_pooled_batch(
select_output: SelectOutput,
batch: Optional[Tensor],
) -> Optional[Tensor]:
r"""Returns the batch vector of the coarsened graph.
Args:
select_output (SelectOutput): The output of :class:`Select`.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example. (default: :obj:`None`)
"""
if batch is None:
return batch

out = torch.arange(select_output.num_clusters, device=batch.device)
return out.scatter_(0, select_output.cluster_index,
batch[select_output.node_index])

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
5 changes: 4 additions & 1 deletion torch_geometric/nn/pool/select/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from .base import Select, SelectOutput

__all__ = ['Select', 'SelectOutput']
__all__ = [
'Select',
'SelectOutput',
]
89 changes: 55 additions & 34 deletions torch_geometric/nn/pool/select/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,77 @@
import torch
from torch import Tensor

from torch_geometric.utils.mixin import CastMixin


@dataclass(init=False)
class SelectOutput(CastMixin):
r"""The output of the :class:`Select` method, which holds an assignment
from selected nodes to their respective cluster(s).
@dataclass
class SelectOutput:
"""
Output of the :meth:`Select.forward` method. Contains
a mapping from nodes to clusters, and the number of clusters.
Args:
node_index (Tensor): The indices of the selected nodes.
cluster_index (Tensor): The indices of the clusters each node is
assigned to. Same shape as :obj:`node_index`.
node_index (torch.Tensor): The indices of the selected nodes.
cluster_index (torch.Tensor): The indices of the clusters each node in
:obj:`node_index` is assigned to.
num_clusters (int): The number of clusters.
weight (Tensor, optional): A weight Tensor with values in range
`[0, 1]`, denoting assignment weight of a node to a cluster.
Same shape as :obj:`node_index`. (default: :obj:`None`).
weight (torch.Tensor, optional): A weight vector, denoting the strength
of the assignment of a node to its cluster. (default: :obj:`None`)
"""
node_index: Tensor
cluster_index: Tensor
num_clusters: int
weight: Optional[Tensor] = None

def __init__(
self,
node_index: Tensor,
cluster_index: Tensor,
num_clusters: int,
weight: Optional[Tensor] = None,
):
if node_index.dim() != 1:
raise ValueError(f"Expected 'node_index' to be one-dimensional "
f"(got {node_index.dim()} dimensions)")

if cluster_index.dim() != 1:
raise ValueError(f"Expected 'cluster_index' to be one-dimensional "
f"(got {cluster_index.dim()} dimensions)")

if node_index.numel() != cluster_index.numel():
raise ValueError(f"Expected 'node_index' and 'cluster_index' to "
f"hold the same number of values (got "
f"{node_index.numel()} and "
f"{cluster_index.numel()} values)")

if weight is not None and weight.dim() != 1:
raise ValueError(f"Expected 'weight' vector to be one-dimensional "
f"(got {weight.dim()} dimensions)")

if weight is not None and weight.numel() != node_index.numel():
raise ValueError(f"Expected 'weight' to hold {node_index.numel()} "
f"values (got {weight.numel()} values)")

self.node_index = node_index
self.cluster_index = cluster_index
self.num_clusters = num_clusters
self.weight = weight


class Select(torch.nn.Module):
r"""An abstract base class implementing custom node selections that map the
nodes of an input graph to supernodes of the pooled one.
r"""An abstract base class for implementing custom node selections as
described in the `"Understanding Pooling in Graph Neural Networks"
<https://arxiv.org/abs/1905.05178>`_ paper, which maps the nodes of an
input graph to supernodes in the coarsened graph.
Specifically, :class:`Select` returns a mapping
:math:`\mathbf{c} \in {\{ -1, \ldots, C - 1\}}^N`, which assigns each node
to one of :math:`C` super nodes.
In addition, :class:`Select` returns the number of super nodes.
Specifically, :class:`Select` returns a :class:`SelectOutput` output, which
holds a (sparse) mapping :math:`\mathbf{C} \in {[0, 1]}^{N \times C}` that
assigns selected nodes to one or more of :math:`C` super nodes.
"""
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
pass

def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
batch: Optional[Tensor] = None,
) -> SelectOutput:
r"""
Args:
x (torch.Tensor): The input node features.
edge_index (torch.Tensor): The edge indices.
edge_attr (torch.Tensor, optional): The edge features.
(default: :obj:`None`)
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each node to a specific graph. (default: :obj:`None`)
"""
def forward(self, *args, **kwargs) -> SelectOutput:
raise NotImplementedError

def __repr__(self) -> str:
Expand Down

0 comments on commit 304afb8

Please sign in to comment.