Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 8, 2023
1 parent 304afb8 commit d59f37e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
3 changes: 2 additions & 1 deletion torch_geometric/nn/pool/connect/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import Connect
from .base import Connect, ConnectOutput

__all__ = [
'Connect',
'ConnectOutput',
]
9 changes: 4 additions & 5 deletions torch_geometric/nn/pool/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def __init__(


class Connect(torch.nn.Module):
r"""An abstract base class implementing custom edge connection operators as
described in the `"Understanding Pooling in Graph Neural Networks"
<https://arxiv.org/abs/1905.05178>`_ paper.
r"""An abstract base class for implementing custom edge connection
operators as described in the `"Understanding Pooling in Graph Neural
Networks" <https://arxiv.org/abs/1905.05178>`_ paper.
Specifically, :class:`Connect` determines for each pair of supernodes the
presence or abscene of an edge based on the existing edges between the
Expand All @@ -73,8 +73,7 @@ def forward(
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
cluster (SelectOutput): The output of `Select`, with a mapping from
nodes to clusters.
select_output (SelectOutput): The output of :class:`Select`.
edge_index (torch.Tensor): The edge indices.
edge_attr (torch.Tensor, optional): The edge features.
(default: :obj:`None`)
Expand Down

0 comments on commit d59f37e

Please sign in to comment.