Commit

Merge branch 'nn-mixhop' of github.com:xnuohz/pytorch_geometric into nn-mixhop

rusty1s committed Sep 14, 2023
2 parents 0c74bde + 8d122c3 commit 0b27322
Showing 7 changed files with 128 additions and 12 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -8,10 +8,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024))
- Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders ([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230))
- Added the `NeuralFingerprint` model for learning fingerprints of molecules ([#7919](https://github.com/pyg-team/pytorch_geometric/pull/7919))
- Added `SparseTensor` support to `WLConvContinuous`, `GeneralConv`, `PDNConv` and `ARMAConv` ([#8013](https://github.com/pyg-team/pytorch_geometric/pull/8013))
- Added `LCMAggregation`, an implementation of Learnable Commutative Monoids ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976), [#8023](https://github.com/pyg-team/pytorch_geometric/pull/8023))
- Added `LCMAggregation`, an implementation of Learnable Commutative Monoids ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976), [#8023](https://github.com/pyg-team/pytorch_geometric/pull/8023), [#8026](https://github.com/pyg-team/pytorch_geometric/pull/8026))
- Added a warning for isolated/non-existing node types in `HeteroData.validate()` ([#7995](https://github.com/pyg-team/pytorch_geometric/pull/7995))
- Added `utils.cumsum` implementation ([#7994](https://github.com/pyg-team/pytorch_geometric/pull/7994))
- Added the `BrcaTcga` dataset ([#7905](https://github.com/pyg-team/pytorch_geometric/pull/7905))
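A minimal usage sketch of the `BasicGNN`/`MLP` change referenced by [#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024) above, as exercised by the tests further down: `batch` and `batch_size` are new optional `forward` arguments that get forwarded to batch-aware normalization layers. The shapes and the `GraphNorm` choice here are illustrative, not taken from the commit:

```python
import torch
from torch_geometric.nn import GraphSAGE

# GraphNorm.forward accepts a `batch` vector, so the model forwards
# `batch`/`batch_size` through to every normalization layer.
x = torch.randn(6, 8)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 0, 3, 2, 5, 4]])
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs, three nodes each

model = GraphSAGE(8, 16, num_layers=2, norm='GraphNorm')
assert model.supports_norm_batch

out = model(x, edge_index, batch=batch)  # normalized per graph
assert out.size() == (6, 16)
```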
4 changes: 2 additions & 2 deletions test/nn/aggr/test_lcm.py
@@ -19,8 +19,8 @@ def test_lcm_aggregation_with_project():


def test_lcm_aggregation_without_project():
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])
x = torch.randn(5, 16)
index = torch.tensor([0, 1, 1, 2, 2])

aggr = LCMAggregation(16, 16, project=False)
assert str(aggr) == 'LCMAggregation(16, 16, project=False)'
22 changes: 22 additions & 0 deletions test/nn/models/test_basic_gnn.py
@@ -160,6 +160,28 @@ def test_one_layer_gnn(out_dim, jk):
assert model(x, edge_index).size() == (3, out_channels)


@pytest.mark.parametrize('norm', [
'BatchNorm',
'GraphNorm',
'InstanceNorm',
'LayerNorm',
])
def test_batch(norm):
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
batch = torch.tensor([0, 0, 1])

model = GraphSAGE(8, 16, num_layers=2, norm=norm)
assert model.supports_norm_batch == (norm != 'BatchNorm')

out = model(x, edge_index, batch=batch)
assert out.size() == (3, 16)

if model.supports_norm_batch:
with pytest.raises(RuntimeError, match="out of bounds"):
model(x, edge_index, batch=batch, batch_size=1)


@onlyOnline
@onlyNeighborSampler
@pytest.mark.parametrize('jk', [None, 'last'])
27 changes: 27 additions & 0 deletions test/nn/models/test_mlp.py
@@ -37,6 +37,33 @@ def test_mlp(norm, act_first, plain_last):
assert torch.allclose(mlp(x), out)


@pytest.mark.parametrize('norm', [
'BatchNorm',
'GraphNorm',
'InstanceNorm',
'LayerNorm',
])
def test_batch(norm):
x = torch.randn(3, 8)
batch = torch.tensor([0, 0, 1])

model = MLP(
8,
hidden_channels=16,
out_channels=32,
num_layers=2,
norm=norm,
)
assert model.supports_norm_batch == (norm != 'BatchNorm')

out = model(x, batch=batch)
assert out.size() == (3, 32)

if model.supports_norm_batch:
with pytest.raises(RuntimeError, match="out of bounds"):
model(x, batch=batch, batch_size=1)


def test_mlp_return_emb():
x = torch.randn(4, 16)

6 changes: 3 additions & 3 deletions torch_geometric/nn/aggr/lcm.py
@@ -91,15 +91,15 @@ def forward(
if x.size(0) % 2 == 1:
# This level of the tree has an odd number of nodes, so the
# remaining unmatched node gets moved to the next level.
x, remainder = x[:-1].contiguous(), x[-1:]
x, remainder = x[:-1], x[-1:]
else:
remainder = None

left_right = x.view(-1, 2, num_nodes, num_features)
right_left = left_right.flip(dims=[1])

left_right = left_right.view(-1, num_features)
right_left = right_left.view(-1, num_features)
left_right = left_right.reshape(-1, num_features)
right_left = right_left.reshape(-1, num_features)

# Execute the GRUCell for all (left, right) pairs in the current
# level of the tree in parallel:
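The `.contiguous()` removal and the `.view` → `.reshape` swaps above rely on standard PyTorch semantics: `.view` requires the target shape to be reachable without copying and raises a `RuntimeError` on stride-incompatible input, whereas `.reshape` returns a view when possible and silently falls back to a copy otherwise. A small self-contained illustration (shapes are hypothetical, not from the commit):

```python
import torch

t = torch.randn(2, 3, 4).transpose(0, 1)  # non-contiguous: strides (4, 12, 1)

try:
    t.view(-1, 4)  # raises: view size is not compatible with input's strides
except RuntimeError:
    flat = t.reshape(-1, 4)  # copies instead of failing
    assert flat.size() == (6, 4)
```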
50 changes: 46 additions & 4 deletions torch_geometric/nn/models/basic_gnn.py
@@ -1,4 +1,5 @@
import copy
import inspect
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union

import torch
@@ -64,6 +65,7 @@ class BasicGNN(torch.nn.Module):
"""
supports_edge_weight: Final[bool]
supports_edge_attr: Final[bool]
supports_norm_batch: Final[bool]

def __init__(
self,
@@ -129,6 +131,12 @@ def __init__(
)
if norm_layer is None:
norm_layer = torch.nn.Identity()

self.supports_norm_batch = False
if hasattr(norm_layer, 'forward'):
norm_params = inspect.signature(norm_layer.forward).parameters
self.supports_norm_batch = 'batch' in norm_params

for _ in range(num_layers - 1):
self.norms.append(copy.deepcopy(norm_layer))

@@ -173,10 +181,12 @@ def forward(  # noqa
edge_index,
edge_weight=None,
edge_attr=None,
batch=None,
batch_size=None,
num_sampled_nodes_per_hop=None,
num_sampled_edges_per_hop=None,
):
# type: (Tensor, Tensor, OptTensor, OptTensor, Optional[List[int]], Optional[List[int]]) -> Tensor # noqa
# type: (Tensor, Tensor, OptTensor, OptTensor, OptTensor, Optional[int], Optional[List[int]], Optional[List[int]]) -> Tensor # noqa
pass

@torch.jit._overload_method
@@ -185,10 +195,12 @@ def forward(  # noqa
edge_index,
edge_weight=None,
edge_attr=None,
batch=None,
batch_size=None,
num_sampled_nodes_per_hop=None,
num_sampled_edges_per_hop=None,
):
# type: (Tensor, SparseTensor, OptTensor, OptTensor, Optional[List[int]], Optional[List[int]]) -> Tensor # noqa
# type: (Tensor, SparseTensor, OptTensor, OptTensor, OptTensor, Optional[int], Optional[List[int]], Optional[List[int]]) -> Tensor # noqa
pass

def forward( # noqa
@@ -197,9 +209,11 @@ def forward(  # noqa
edge_index: Tensor, # TODO Support `SparseTensor` in type hint.
edge_weight: OptTensor = None,
edge_attr: OptTensor = None,
batch: OptTensor = None,
batch_size: Optional[int] = None,
num_sampled_nodes_per_hop: Optional[List[int]] = None,
num_sampled_edges_per_hop: Optional[List[int]] = None,
) -> Tensor:
):
r"""
Args:
x (torch.Tensor): The input node features.
@@ -208,6 +222,17 @@ def forward(  # noqa
supported by the underlying GNN layer). (default: :obj:`None`)
edge_attr (torch.Tensor, optional): The edge features (if supported
by the underlying GNN layer). (default: :obj:`None`)
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example.
Only needs to be passed in case the underlying normalization
layers require the :obj:`batch` information.
(default: :obj:`None`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given.
Only needs to be passed in case the underlying normalization
layers require the :obj:`batch` information.
(default: :obj:`None`)
num_sampled_nodes_per_hop (List[int], optional): The number of
sampled nodes per hop.
Useful in :class:`~torch_geometric.loader.NeighborLoader`
@@ -260,7 +285,10 @@ def forward(  # noqa
if i < self.num_layers - 1 or self.jk_mode is not None:
if self.act is not None and self.act_first:
x = self.act(x)
x = norm(x)
if self.supports_norm_batch:
x = norm(x, batch, batch_size)
else:
x = norm(x)
if self.act is not None and not self.act_first:
x = self.act(x)
x = self.dropout(x)
@@ -397,6 +425,8 @@ def forward(
edge_index: Tensor,
edge_weight: OptTensor = None,
edge_attr: OptTensor = None,
batch: OptTensor = None,
batch_size: Optional[int] = None,
num_sampled_nodes_per_hop: Optional[List[int]] = None,
num_sampled_edges_per_hop: Optional[List[int]] = None,
) -> Tensor:
@@ -405,6 +435,8 @@ def forward(
edge_index,
edge_weight,
edge_attr,
batch,
batch_size,
num_sampled_nodes_per_hop,
num_sampled_edges_per_hop,
)
@@ -426,6 +458,8 @@ def forward(
edge_index: SparseTensor,
edge_weight: OptTensor = None,
edge_attr: OptTensor = None,
batch: OptTensor = None,
batch_size: Optional[int] = None,
num_sampled_nodes_per_hop: Optional[List[int]] = None,
num_sampled_edges_per_hop: Optional[List[int]] = None,
) -> Tensor:
@@ -434,6 +468,8 @@ def forward(
edge_index,
edge_weight,
edge_attr,
batch,
batch_size,
num_sampled_nodes_per_hop,
num_sampled_edges_per_hop,
)
@@ -492,6 +528,7 @@ class GCN(BasicGNN):
"""
supports_edge_weight: Final[bool] = True
supports_edge_attr: Final[bool] = False
supports_norm_batch: Final[bool]

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
@@ -536,6 +573,7 @@ class GraphSAGE(BasicGNN):
"""
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = False
supports_norm_batch: Final[bool]

def init_conv(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int, **kwargs) -> MessagePassing:
@@ -577,6 +615,7 @@ class GIN(BasicGNN):
"""
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = False
supports_norm_batch: Final[bool]

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
@@ -635,6 +674,7 @@ class GAT(BasicGNN):
"""
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = True
supports_norm_batch: Final[bool]

def init_conv(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int, **kwargs) -> MessagePassing:
@@ -697,6 +737,7 @@ class PNA(BasicGNN):
"""
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = True
supports_norm_batch: Final[bool]

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
@@ -738,6 +779,7 @@ class EdgeCNN(BasicGNN):
"""
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = False
supports_norm_batch: Final[bool]

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
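The `supports_norm_batch` flag introduced above is computed from the normalization layer's `forward` signature alone. A standalone sketch of that detection, mirroring the `inspect` check added in `BasicGNN.__init__`:

```python
import inspect

import torch
from torch_geometric.nn.norm import BatchNorm, GraphNorm

def supports_norm_batch(norm_layer: torch.nn.Module) -> bool:
    # A layer is batch-aware iff its `forward` exposes a `batch` parameter.
    norm_params = inspect.signature(norm_layer.forward).parameters
    return 'batch' in norm_params

assert not supports_norm_batch(BatchNorm(16))  # per-feature statistics only
assert supports_norm_batch(GraphNorm(16))      # normalizes per graph
```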
28 changes: 26 additions & 2 deletions torch_geometric/nn/models/mlp.py
@@ -1,5 +1,6 @@
import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, Final, List, Optional, Union

import torch
import torch.nn.functional as F
@@ -71,6 +72,8 @@ class MLP(torch.nn.Module):
bias per layer. (default: :obj:`True`)
**kwargs (optional): Additional deprecated arguments of the MLP layer.
"""
supports_norm_batch: Final[bool]

def __init__(
self,
channel_list: Optional[Union[List[int], int]] = None,
@@ -160,6 +163,11 @@ def __init__(
norm_layer = Identity()
self.norms.append(norm_layer)

self.supports_norm_batch = False
if len(self.norms) > 0 and hasattr(self.norms[0], 'forward'):
norm_params = inspect.signature(self.norms[0].forward).parameters
self.supports_norm_batch = 'batch' in norm_params

self.reset_parameters()

@property
@@ -188,11 +196,24 @@ def reset_parameters(self):
def forward(
self,
x: Tensor,
batch: Optional[Tensor] = None,
batch_size: Optional[int] = None,
return_emb: NoneType = None,
) -> Tensor:
r"""
Args:
x (torch.Tensor): The source tensor.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example.
Only needs to be passed in case the underlying normalization
layers require the :obj:`batch` information.
(default: :obj:`None`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given.
Only needs to be passed in case the underlying normalization
layers require the :obj:`batch` information.
(default: :obj:`None`)
return_emb (bool, optional): If set to :obj:`True`, will
additionally return the embeddings before execution of the
final output layer. (default: :obj:`False`)
@@ -206,7 +227,10 @@ def forward(
x = lin(x)
if self.act is not None and self.act_first:
x = self.act(x)
x = norm(x)
if self.supports_norm_batch:
x = norm(x, batch, batch_size)
else:
x = norm(x)
if self.act is not None and not self.act_first:
x = self.act(x)
x = F.dropout(x, p=self.dropout[i], training=self.training)
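As in `BasicGNN`, the new `MLP.forward` arguments are optional and only consumed when `supports_norm_batch` holds. A usage sketch mirroring the test above (the `GraphNorm` choice is illustrative):

```python
import torch
from torch_geometric.nn import MLP

x = torch.randn(3, 8)
batch = torch.tensor([0, 0, 1])  # two examples

mlp = MLP(8, hidden_channels=16, out_channels=32, num_layers=2,
          norm='GraphNorm')
assert mlp.supports_norm_batch

out = mlp(x, batch=batch)  # statistics computed per example
assert out.size() == (3, 32)
```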
