PNAConv: Customize activation function (#5262)
* PNA custom activation function

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
3 people authored Aug 23, 2022
1 parent 7b6e199 commit 8bcc77c
Showing 2 changed files with 27 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
 - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
 ### Changed
+- Allow customization of the activation function within `PNAConv` ([#5262](https://github.com/pyg-team/pytorch_geometric/pull/5262))
 - Do not fill `InMemoryDataset` cache on `dataset.num_features` ([#5264](https://github.com/pyg-team/pytorch_geometric/pull/5264))
 - Changed tests relying on `dblp` datasets to instead use synthetic data ([#5250](https://github.com/pyg-team/pytorch_geometric/pull/5250))
 - Fixed a bug for the initialization of activation function examples in `custom_graphgym` ([#5243](https://github.com/pyg-team/pytorch_geometric/pull/5243))
35 changes: 26 additions & 9 deletions torch_geometric/nn/conv/pna_conv.py
@@ -1,12 +1,13 @@
-from typing import List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from torch import Tensor
-from torch.nn import ModuleList, ReLU, Sequential
+from torch.nn import ModuleList, Sequential
 
 from torch_geometric.nn.aggr import DegreeScalerAggregation
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
+from torch_geometric.nn.resolver import activation_resolver
 from torch_geometric.typing import Adj, OptTensor
 from torch_geometric.utils import degree

@@ -70,6 +71,11 @@ class PNAConv(MessagePassing):
             aggregation (default: :obj:`1`).
         divide_input (bool, optional): Whether the input features should
             be split between towers or not (default: :obj:`False`).
+        act (str or Callable, optional): Pre- and post-layer activation
+            function to use. (default: :obj:`"relu"`)
+        act_kwargs (Dict[str, Any], optional): Arguments passed to the
+            respective activation function defined by :obj:`act`.
+            (default: :obj:`None`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.

@@ -80,11 +86,22 @@ class PNAConv(MessagePassing):
           edge features :math:`(|\mathcal{E}|, D)` *(optional)*
         - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
     """
-    def __init__(self, in_channels: int, out_channels: int,
-                 aggregators: List[str], scalers: List[str], deg: Tensor,
-                 edge_dim: Optional[int] = None, towers: int = 1,
-                 pre_layers: int = 1, post_layers: int = 1,
-                 divide_input: bool = False, **kwargs):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        aggregators: List[str],
+        scalers: List[str],
+        deg: Tensor,
+        edge_dim: Optional[int] = None,
+        towers: int = 1,
+        pre_layers: int = 1,
+        post_layers: int = 1,
+        divide_input: bool = False,
+        act: Union[str, Callable, None] = "relu",
+        act_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
 
         aggr = DegreeScalerAggregation(aggregators, scalers, deg)
         super().__init__(aggr=aggr, node_dim=0, **kwargs)

@@ -110,14 +127,14 @@ def __init__(self, in_channels: int, out_channels: int,
         for _ in range(towers):
             modules = [Linear((3 if edge_dim else 2) * self.F_in, self.F_in)]
             for _ in range(pre_layers - 1):
-                modules += [ReLU()]
+                modules += [activation_resolver(act, **(act_kwargs or {}))]
                 modules += [Linear(self.F_in, self.F_in)]
             self.pre_nns.append(Sequential(*modules))
 
             in_channels = (len(aggregators) * len(scalers) + 1) * self.F_in
             modules = [Linear(in_channels, self.F_out)]
             for _ in range(post_layers - 1):
-                modules += [ReLU()]
+                modules += [activation_resolver(act, **(act_kwargs or {}))]
                 modules += [Linear(self.F_out, self.F_out)]
             self.post_nns.append(Sequential(*modules))

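The replaced lines above swap the hard-coded ReLU() for whatever module activation_resolver returns, so each pre- and post-layer gets a freshly constructed activation on every loop iteration. A rough sketch of that mechanism (hedged: the exact set of accepted names is up to the resolver, which maps strings and keyword arguments onto the matching torch.nn activation classes; per the new docstring a Callable is accepted as well):

from torch_geometric.nn.resolver import activation_resolver

act = activation_resolver('relu')                            # -> torch.nn.ReLU()
act = activation_resolver('leaky_relu', negative_slope=0.2)  # -> torch.nn.LeakyReLU(negative_slope=0.2)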

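A minimal usage sketch of the new arguments (hedged: the channel sizes, degree histogram, and toy graph below are made up for illustration; only act and act_kwargs come from this commit):

import torch
from torch_geometric.nn import PNAConv

deg = torch.tensor([0, 10, 20, 5])  # hypothetical in-degree histogram

conv = PNAConv(
    in_channels=16, out_channels=32,
    aggregators=['mean', 'min', 'max', 'std'],
    scalers=['identity', 'amplification', 'attenuation'],
    deg=deg,
    pre_layers=2, post_layers=2,         # > 1, so the activation is actually inserted
    act='leaky_relu',                    # resolved via activation_resolver
    act_kwargs={'negative_slope': 0.2},  # forwarded to the resolved activation
)

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out = conv(x, edge_index)  # -> shape [4, 32]

Note that with the defaults pre_layers=1 and post_layers=1 no activation is inserted at all, since the loops in the diff only add one for layer counts greater than one.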