diff --git a/torch_geometric/nn/pool/connect/__init__.py b/torch_geometric/nn/pool/connect/__init__.py index aabde59575b4..52b45639fc63 100644 --- a/torch_geometric/nn/pool/connect/__init__.py +++ b/torch_geometric/nn/pool/connect/__init__.py @@ -1,5 +1,6 @@ -from .base import Connect +from .base import Connect, ConnectOutput __all__ = [ 'Connect', + 'ConnectOutput', ] diff --git a/torch_geometric/nn/pool/connect/base.py b/torch_geometric/nn/pool/connect/base.py index 5aabf0dbe780..1b528e29a9be 100644 --- a/torch_geometric/nn/pool/connect/base.py +++ b/torch_geometric/nn/pool/connect/base.py @@ -50,9 +50,9 @@ def __init__( class Connect(torch.nn.Module): - r"""An abstract base class implementing custom edge connection operators as - described in the `"Understanding Pooling in Graph Neural Networks" - `_ paper. + r"""An abstract base class for implementing custom edge connection + operators as described in the `"Understanding Pooling in Graph Neural + Networks" `_ paper. Specifically, :class:`Connect` determines for each pair of supernodes the presence or abscene of an edge based on the existing edges between the @@ -73,8 +73,7 @@ def forward( ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: - cluster (SelectOutput): The output of `Select`, with a mapping from - nodes to clusters. + select_output (SelectOutput): The output of :class:`Select`. edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`)