Add usage of the disable_dynamic_shapes decorator in aggregation layers.
piotrchmiel committed Jun 7, 2023
1 parent 0f0e0da commit a9d82a9
Showing 8 changed files with 17 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -94,6 +94,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added an optional `batch_size` and `max_num_nodes` arguments to `MemPooling` layer ([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239))
- Fixed training issues of the GraphGPS example ([#7377](https://github.com/pyg-team/pytorch_geometric/pull/7377))
- Allowed `CaptumExplainer` to be called multiple times in a row ([#7391](https://github.com/pyg-team/pytorch_geometric/pull/7391))
- Added usage of the `disable_dynamic_shapes` decorator in aggregation layers ([#7534](https://github.com/pyg-team/pytorch_geometric/pull/7534))


### Removed

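For context (not part of the diff): the decorator is expected to enforce its required keyword arguments only while the `disable_dynamic_shapes` experimental option is enabled, so that output shapes are fixed up front instead of being inferred from the data. Below is a minimal sketch, assuming the option can be toggled with the `torch_geometric.experimental_mode` context manager (verify the toggle API against your installed PyG version).

```python
# Minimal sketch (illustration only, not part of this commit).
# Assumes `torch_geometric.experimental_mode` toggles the experimental
# 'disable_dynamic_shapes' option; check torch_geometric.experimental in
# your installed version.
import torch

import torch_geometric
from torch_geometric.nn.aggr import SumAggregation

x = torch.randn(6, 16)                    # 6 elements with 16 features each
index = torch.tensor([0, 0, 1, 1, 2, 2])  # 3 groups

aggr = SumAggregation()

with torch_geometric.experimental_mode('disable_dynamic_shapes'):
    # With the option enabled, `Aggregation.__call__` (decorated in base.py
    # below with required_args=['dim_size']) demands an explicit `dim_size`,
    # so the output shape no longer depends on the contents of `index`.
    out = aggr(x, index, dim_size=3)

print(out.shape)  # torch.Size([3, 16])
```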
3 changes: 3 additions & 0 deletions torch_geometric/nn/aggr/base.py
@@ -3,6 +3,7 @@
import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.utils import scatter, segment, to_dense_batch


@@ -58,6 +59,7 @@ class Aggregation(torch.nn.Module):
- **output:** graph features :math:`(*, |\mathcal{G}|, F_{out})` or
node features :math:`(*, |\mathcal{V}|, F_{out})`
"""
@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
def forward(
self,
x: Tensor,
@@ -91,6 +93,7 @@ def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
pass

@disable_dynamic_shapes(required_args=['dim_size'])
def __call__(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2, **kwargs) -> Tensor:
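For illustration only (not part of the commit), the same pattern applied in a user-defined aggregation: the decorator names the keyword arguments that must be supplied whenever the experimental option is active. The `PaddedMeanAggregation` class below is hypothetical, and it assumes the `Aggregation.to_dense_batch` helper used by the built-in layers.

```python
# Hypothetical example (not in this commit): a custom aggregation that opts
# into static shapes the same way the built-in layers in this diff do.
from typing import Optional

from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation


class PaddedMeanAggregation(Aggregation):
    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:
        # Pad every group to `max_num_elements` rows so the intermediate
        # tensor shape stays fixed across batches, then average over the
        # non-padded entries.
        x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                      max_num_elements=max_num_elements)
        return x.sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1)
```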
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/gmt.py
@@ -3,6 +3,7 @@
import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.aggr.utils import (
PoolingByMultiheadAttention,
@@ -65,6 +66,7 @@ def reset_parameters(self):
encoder.reset_parameters()
self.pma2.reset_parameters()

@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
def forward(
self,
x: Tensor,
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/gru.py
@@ -3,6 +3,7 @@
from torch import Tensor
from torch.nn import GRU

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation


@@ -29,6 +30,7 @@ def __init__(self, in_channels: int, out_channels: int, **kwargs):
def reset_parameters(self):
self.gru.reset_parameters()

@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
def forward(
self,
x: Tensor,
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/lstm.py
@@ -3,6 +3,7 @@
from torch import Tensor
from torch.nn import LSTM

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation


@@ -29,6 +30,7 @@ def __init__(self, in_channels: int, out_channels: int, **kwargs):
def reset_parameters(self):
self.lstm.reset_parameters()

@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
def forward(
self,
x: Tensor,
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/mlp.py
@@ -2,6 +2,7 @@

from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation


@@ -22,6 +23,7 @@ class MLPAggregation(Aggregation):
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.models.MLP`.
"""
@disable_dynamic_shapes(required_args=['max_num_elements'])
def __init__(self, in_channels: int, out_channels: int,
max_num_elements: int, **kwargs):
super().__init__()
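A hedged usage sketch for the decorated `MLPAggregation` constructor shown above: `max_num_elements` fixes the padded group size at construction time, and extra kwargs are forwarded to `torch_geometric.nn.models.MLP` (here `num_layers=1` is an assumed minimal configuration, not something prescribed by this commit).

```python
# Usage sketch (illustration only): MLPAggregation pads each group to
# `max_num_elements` rows internally, so together with an explicit `dim_size`
# all shapes are known up front. `num_layers=1` is just a minimal MLP config.
import torch

from torch_geometric.nn.aggr import MLPAggregation

x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 2, 2])

aggr = MLPAggregation(in_channels=16, out_channels=32, max_num_elements=4,
                      num_layers=1)

out = aggr(x, index, dim_size=3)
print(out.shape)  # expected: torch.Size([3, 32])
```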
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/set_transformer.py
@@ -3,6 +3,7 @@
import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.aggr.utils import (
PoolingByMultiheadAttention,
@@ -73,6 +74,7 @@ def reset_parameters(self):
for decoder in self.decoders:
decoder.reset_parameters()

@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
def forward(
self,
x: Tensor,
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/sort.py
@@ -3,6 +3,7 @@
import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation


@@ -20,6 +21,7 @@ def __init__(self, k: int):
super().__init__()
self.k = k

@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
def forward(
self,
x: Tensor,
