[Type Hints] DenseSAGEConv and DenseGCNConv (#5664)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
3 people authored Oct 13, 2022
1 parent 75bcddd commit 03e8414
Showing 5 changed files with 35 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
 - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
 - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
-- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669))
+- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669))
 - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
 - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
 - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
5 changes: 5 additions & 0 deletions test/nn/dense/test_dense_gcn_conv.py
@@ -1,6 +1,7 @@
 import torch

 from torch_geometric.nn import DenseGCNConv, GCNConv
+from torch_geometric.testing import is_full_test


 def test_dense_gcn_conv():
@@ -38,6 +39,10 @@ def test_dense_gcn_conv():
     dense_out = dense_conv(x, adj, mask)
     assert dense_out.size() == (2, 3, channels)

+    if is_full_test():
+        jit = torch.jit.script(dense_conv)
+        assert torch.allclose(jit(x, adj, mask), dense_out)
+
     assert dense_out[1, 2].abs().sum().item() == 0
     dense_out = dense_out.view(6, channels)[:-1]
     assert torch.allclose(sparse_out, dense_out, atol=1e-04)
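The gated block above only runs in full-test mode (`is_full_test()`), but what it checks is plain TorchScript compilation of the dense layer. A minimal stand-alone sketch of the same idea, with illustrative channel counts and shapes that are not taken from the test itself:

import torch

from torch_geometric.nn import DenseGCNConv

conv = DenseGCNConv(in_channels=16, out_channels=32)
x = torch.randn(2, 3, 16)                   # node features: [batch_size, num_nodes, in_channels]
adj = torch.rand(2, 3, 3)                   # dense adjacency: [batch_size, num_nodes, num_nodes]
mask = torch.tensor([[True, True, True],
                     [True, True, False]])  # valid-node mask: [batch_size, num_nodes]

out = conv(x, adj, mask)                    # -> shape [2, 3, 32]

# With the annotated `forward`, the module can be compiled by TorchScript:
jit = torch.jit.script(conv)
assert torch.allclose(jit(x, adj, mask), out)

The `DenseSAGEConv` test below adds the identical check for that layer.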
5 changes: 5 additions & 0 deletions test/nn/dense/test_dense_sage_conv.py
@@ -1,6 +1,7 @@
 import torch

 from torch_geometric.nn import DenseSAGEConv, SAGEConv
+from torch_geometric.testing import is_full_test


 def test_dense_sage_conv():
@@ -38,6 +39,10 @@ def test_dense_sage_conv():
     dense_out = dense_conv(x, adj, mask)
     assert dense_out.size() == (2, 3, channels)

+    if is_full_test():
+        jit = torch.jit.script(dense_conv)
+        assert torch.allclose(jit(x, adj, mask), dense_out)
+
     assert dense_out[1, 2].abs().sum().item() == 0
     dense_out = dense_out.view(6, channels)[:-1]
     assert torch.allclose(sparse_out, dense_out, atol=1e-04)
13 changes: 11 additions & 2 deletions torch_geometric/nn/dense/dense_gcn_conv.py
@@ -1,14 +1,22 @@
 import torch
+from torch import Tensor
 from torch.nn import Parameter

 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import zeros
+from torch_geometric.typing import OptTensor


 class DenseGCNConv(torch.nn.Module):
     r"""See :class:`torch_geometric.nn.conv.GCNConv`.
     """
-    def __init__(self, in_channels, out_channels, improved=False, bias=True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        improved: bool = False,
+        bias: bool = True,
+    ):
         super().__init__()

         self.in_channels = in_channels
@@ -29,7 +37,8 @@ def reset_parameters(self):
         self.lin.reset_parameters()
         zeros(self.bias)

-    def forward(self, x, adj, mask=None, add_loop=True):
+    def forward(self, x: Tensor, adj: Tensor, mask: OptTensor = None,
+                add_loop: bool = True) -> Tensor:
         r"""
         Args:
             x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
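`OptTensor` is PyG's alias for an optional tensor (`torch_geometric.typing` defines `OptTensor = Optional[Tensor]`), so the annotated signature above is equivalent to spelling the optional out by hand. A sketch of the equivalence, not the actual source:

from typing import Optional

from torch import Tensor

# Equivalent to the `mask: OptTensor = None` annotation used in the diff:
def forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None,
            add_loop: bool = True) -> Tensor:
    ...

The explicit annotations are what make scripting possible: TorchScript infers every unannotated argument as `Tensor`, so passing `None` for `mask` would otherwise fail under `torch.jit.script`.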
16 changes: 13 additions & 3 deletions torch_geometric/nn/dense/dense_sage_conv.py
@@ -1,7 +1,10 @@
 import torch
 import torch.nn.functional as F
+from torch import Tensor
 from torch.nn import Linear

+from torch_geometric.typing import OptTensor
+

 class DenseSAGEConv(torch.nn.Module):
     r"""See :class:`torch_geometric.nn.conv.SAGEConv`.
@@ -14,7 +17,13 @@ class DenseSAGEConv(torch.nn.Module):
         use :class:`torch_geometric.nn.dense.DenseGraphConv` instead.
     """
-    def __init__(self, in_channels, out_channels, normalize=False, bias=True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        normalize: bool = False,
+        bias: bool = True,
+    ):
         super().__init__()

         self.in_channels = in_channels
@@ -30,7 +39,8 @@ def reset_parameters(self):
         self.lin_rel.reset_parameters()
         self.lin_root.reset_parameters()

-    def forward(self, x, adj, mask=None):
+    def forward(self, x: Tensor, adj: Tensor,
+                mask: OptTensor = None) -> Tensor:
         r"""
         Args:
             x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
@@ -54,7 +64,7 @@ def forward(self, x, adj, mask=None):
         out = self.lin_rel(out) + self.lin_root(x)

         if self.normalize:
-            out = F.normalize(out, p=2, dim=-1)
+            out = F.normalize(out, p=2.0, dim=-1)

         if mask is not None:
             out = out * mask.view(B, N, 1).to(x.dtype)
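The `p=2` to `p=2.0` change looks cosmetic but is presumably part of the same scripting effort: `F.normalize` annotates `p` as a `float`, and script mode is generally stricter than eager execution about matching that annotation with an `int` literal. A minimal sketch of the scripted module with normalization enabled, using illustrative shapes:

import torch

from torch_geometric.nn import DenseSAGEConv

conv = DenseSAGEConv(in_channels=16, out_channels=32, normalize=True)
x = torch.randn(2, 3, 16)     # node features: [batch_size, num_nodes, in_channels]
adj = torch.ones(2, 3, 3)     # binary dense adjacency: [batch_size, num_nodes, num_nodes]

jit = torch.jit.script(conv)  # compiles now that `forward` is fully annotated
out = jit(x, adj)             # shape [2, 3, 32]; each output vector is approximately unit-length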
