
Add support for mean and max pool in SAGEConv
wilcoln committed Apr 9, 2022
1 parent 3f0019f commit a804b4a
Showing 1 changed file with 23 additions and 4 deletions.
torch_geometric/nn/conv/sage_conv.py
@@ -20,6 +20,11 @@ class SAGEConv(MessagePassing):
     .. math::
         \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
         \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j
+        \text{ mean }
+
+        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
+        \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{W}_3 \cdot \mathbf{x}_j
+        \text{ mean pool }

     Args:
         in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
@@ -28,7 +33,8 @@ class SAGEConv(MessagePassing):
             dimensionalities.
         out_channels (int): Size of each output sample.
         aggr (string, optional): The aggregation scheme to use
-            (:obj:`"mean"`, :obj:`"max"`, :obj:`"lstm"`).
+            (:obj:`"mean"`, :obj:`"max"`, :obj:`"lstm"`,
+            :obj:`"mean_pool"`, :obj:`"max_pool"`).
             (default: :obj:`"mean"`)
         normalize (bool, optional): If set to :obj:`True`, output features
             will be :math:`\ell_2`-normalized, *i.e.*,
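For context, a minimal usage sketch of the aggregation options this commit adds; the channel sizes and the toy graph below are illustrative, not taken from the commit:

    import torch
    from torch_geometric.nn import SAGEConv

    # Hypothetical toy graph: 4 nodes with 16 features each, 3 edges.
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

    # 'mean_pool' transforms each neighbor with a learnable Linear + ReLU
    # before mean-aggregation (second equation in the docstring above).
    conv = SAGEConv(16, 32, aggr='mean_pool')
    out = conv(x, edge_index)  # shape: [4, 32]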
@@ -62,20 +68,27 @@ def __init__(
         bias: bool = True,
         **kwargs,
     ):
-        kwargs['aggr'] = aggr if aggr != 'lstm' else None
+        kwargs['aggr'] = aggr if aggr not in {'lstm', 'mean_pool',
+                                              'max_pool'} else None
         super().__init__(**kwargs)

         self.in_channels = in_channels
         self.out_channels = out_channels
         self.normalize = normalize
         self.root_weight = root_weight
+        self.pool = None

         if isinstance(in_channels, int):
             in_channels = (in_channels, in_channels)

         if self.aggr is None:
-            self.fuse = False  # No "fused" message_and_aggregate.
-            self.lstm = LSTM(in_channels[0], in_channels[0], batch_first=True)
+            if 'pool' in aggr:  # pool case
+                self.aggr = aggr[:-5]
+                self.pool = Linear(in_channels[0], in_channels[0], bias=True)
+            else:  # lstm case
+                self.fuse = False  # No "fused" message_and_aggregate.
+                self.lstm = LSTM(in_channels[0], in_channels[0],
+                                 batch_first=True)

         self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
         if self.root_weight:
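A note on the suffix handling above: aggr[:-5] drops the 5-character '_pool' suffix to recover the base reducer. A quick illustrative check, not part of the commit:

    # '_pool' is 5 characters long, so slicing it off leaves the reducer name.
    assert 'mean_pool'[:-5] == 'mean'
    assert 'max_pool'[:-5] == 'max'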
@@ -89,6 +102,8 @@ def reset_parameters(self):
         self.lin_l.reset_parameters()
         if self.root_weight:
             self.lin_r.reset_parameters()
+        if self.pool is not None:
+            self.pool.reset_parameters()

     def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                 size: Size = None) -> Tensor:
@@ -120,6 +135,9 @@ def message_and_aggregate(self, adj_t: SparseTensor,
     def aggregate(self, x: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
                   dim_size: Optional[int] = None) -> Tensor:
         if self.aggr is not None:
+            if self.pool is not None:
+                x = F.relu(self.pool(x))
+
             return scatter(x, index, dim=self.node_dim, dim_size=dim_size,
                            reduce=self.aggr)

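Functionally, the pooled branch of aggregate amounts to the standalone sketch below. It assumes scatter comes from torch_scatter, consistent with the call above, and the name pooled_aggregate is illustrative, not the library's:

    import torch.nn.functional as F
    from torch import Tensor
    from torch.nn import Linear
    from torch_scatter import scatter

    def pooled_aggregate(x: Tensor, index: Tensor, pool: Linear,
                         dim_size: int, reduce: str = 'mean') -> Tensor:
        # W3-transform plus nonlinearity, applied per neighbor message ...
        x = F.relu(pool(x))
        # ... then reduce the messages per target node ('mean' or 'max').
        return scatter(x, index, dim=0, dim_size=dim_size, reduce=reduce)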
@@ -137,5 +155,6 @@ def aggregate(self, x: Tensor, index: Tensor, ptr: Optional[Tensor] = None,

     def __repr__(self) -> str:
         aggr = self.aggr if self.aggr is not None else 'lstm'
+        aggr = aggr if self.pool is None else f'{aggr}_pool'
         return (f'{self.__class__.__name__}({self.in_channels}, '
                 f'{self.out_channels}, aggr={aggr})')

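As a quick sanity check of the __repr__ change, a hypothetical session (output inferred from the code above):

    conv = SAGEConv(16, 32, aggr='max_pool')
    print(conv)  # SAGEConv(16, 32, aggr=max_pool)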