Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type Hints] DenseSAGEConv and DenseGCNConv #5664

Merged
merged 13 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Improved type hint support ([#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669))
Expand Down
5 changes: 5 additions & 0 deletions test/nn/dense/test_dense_gcn_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

from torch_geometric.nn import DenseGCNConv, GCNConv
from torch_geometric.testing import is_full_test


def test_dense_gcn_conv():
Expand Down Expand Up @@ -38,6 +39,10 @@ def test_dense_gcn_conv():
dense_out = dense_conv(x, adj, mask)
assert dense_out.size() == (2, 3, channels)

if is_full_test():
jit = torch.jit.script(dense_conv)
assert torch.allclose(jit(x, adj, mask), dense_out)

assert dense_out[1, 2].abs().sum().item() == 0
dense_out = dense_out.view(6, channels)[:-1]
assert torch.allclose(sparse_out, dense_out, atol=1e-04)
Expand Down
5 changes: 5 additions & 0 deletions test/nn/dense/test_dense_sage_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

from torch_geometric.nn import DenseSAGEConv, SAGEConv
from torch_geometric.testing import is_full_test


def test_dense_sage_conv():
Expand Down Expand Up @@ -38,6 +39,10 @@ def test_dense_sage_conv():
dense_out = dense_conv(x, adj, mask)
assert dense_out.size() == (2, 3, channels)

if is_full_test():
jit = torch.jit.script(dense_conv)
assert torch.allclose(jit(x, adj, mask), dense_out)

assert dense_out[1, 2].abs().sum().item() == 0
dense_out = dense_out.view(6, channels)[:-1]
assert torch.allclose(sparse_out, dense_out, atol=1e-04)
Expand Down
13 changes: 11 additions & 2 deletions torch_geometric/nn/dense/dense_gcn_conv.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor


class DenseGCNConv(torch.nn.Module):
r"""See :class:`torch_geometric.nn.conv.GCNConv`.
"""
def __init__(self, in_channels, out_channels, improved=False, bias=True):
def __init__(
self,
in_channels: int,
out_channels: int,
improved: bool = False,
bias: bool = True,
):
super().__init__()

self.in_channels = in_channels
Expand All @@ -29,7 +37,8 @@ def reset_parameters(self):
self.lin.reset_parameters()
zeros(self.bias)

def forward(self, x, adj, mask=None, add_loop=True):
def forward(self, x: Tensor, adj: Tensor, mask: OptTensor = None,
add_loop: bool = True) -> Tensor:
r"""
Args:
x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
Expand Down
16 changes: 13 additions & 3 deletions torch_geometric/nn/dense/dense_sage_conv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear

from torch_geometric.typing import OptTensor


class DenseSAGEConv(torch.nn.Module):
r"""See :class:`torch_geometric.nn.conv.SAGEConv`.
Expand All @@ -14,7 +17,13 @@ class DenseSAGEConv(torch.nn.Module):
use :class:`torch_geometric.nn.dense.DenseGraphConv` instead.

"""
def __init__(self, in_channels, out_channels, normalize=False, bias=True):
def __init__(
self,
in_channels: int,
out_channels: int,
normalize: bool = False,
bias: bool = True,
):
super().__init__()

self.in_channels = in_channels
Expand All @@ -30,7 +39,8 @@ def reset_parameters(self):
self.lin_rel.reset_parameters()
self.lin_root.reset_parameters()

def forward(self, x, adj, mask=None):
def forward(self, x: Tensor, adj: Tensor,
mask: OptTensor = None) -> Tensor:
r"""
Args:
x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
Expand All @@ -54,7 +64,7 @@ def forward(self, x, adj, mask=None):
out = self.lin_rel(out) + self.lin_root(x)

if self.normalize:
out = F.normalize(out, p=2, dim=-1)
out = F.normalize(out, p=2.0, dim=-1)

if mask is not None:
out = out * mask.view(B, N, 1).to(x.dtype)
Expand Down