diff --git a/CHANGELOG.md b/CHANGELOG.md
index c3fcb9d69f73..a988db773875 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added TorchScript support to the `LINKX` model ([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712))
 - Added `torch.jit` examples for `example/film.py` and `example/gcn.py`([#6602](https://github.com/pyg-team/pytorch_geometric/pull/6692))
 - Added `Pad` transform ([#5940](https://github.com/pyg-team/pytorch_geometric/pull/5940))
 - Added full batch mode to the inference benchmark ([#6631](https://github.com/pyg-team/pytorch_geometric/pull/6631))
diff --git a/test/nn/models/test_linkx.py b/test/nn/models/test_linkx.py
index 5c592c3afcaa..d0ed377bc26a 100644
--- a/test/nn/models/test_linkx.py
+++ b/test/nn/models/test_linkx.py
@@ -3,6 +3,7 @@
 from torch_sparse import SparseTensor
 
 from torch_geometric.nn import LINKX
+from torch_geometric.testing import is_full_test
 
 
 @pytest.mark.parametrize('num_edge_layers', [1, 2])
@@ -22,6 +23,15 @@ def test_linkx(num_edge_layers):
     assert out.size() == (4, 8)
     assert torch.allclose(out, model(x, adj1.t()), atol=1e-4)
 
+    if is_full_test():
+        t = '(OptTensor, Tensor, OptTensor) -> Tensor'
+        jit = torch.jit.script(model.jittable(t))
+        assert torch.allclose(jit(x, edge_index), out)
+
+        t = '(OptTensor, SparseTensor, OptTensor) -> Tensor'
+        jit = torch.jit.script(model.jittable(t))
+        assert torch.allclose(jit(x, adj1.t()), out)
+
     out = model(None, edge_index)
     assert out.size() == (4, 8)
     assert torch.allclose(out, model(None, adj1.t()), atol=1e-4)
diff --git a/torch_geometric/nn/models/linkx.py b/torch_geometric/nn/models/linkx.py
index aa16d33e4502..5804f532c3ba 100644
--- a/torch_geometric/nn/models/linkx.py
+++ b/torch_geometric/nn/models/linkx.py
@@ -30,13 +30,28 @@ def reset_parameters(self):
                               a=math.sqrt(5))
         inits.uniform(self.in_channels, self.bias)
 
-    def forward(self, edge_index: Adj,
-                edge_weight: OptTensor = None) -> Tensor:
+    @torch.jit._overload_method
+    def forward(self, edge_index, edge_weight=None):
+        # type: (SparseTensor, OptTensor) -> Tensor
+        pass
+
+    @torch.jit._overload_method
+    def forward(self, edge_index, edge_weight=None):
+        # type: (Tensor, OptTensor) -> Tensor
+        pass
+
+    def forward(
+        self,
+        edge_index: Adj,
+        edge_weight: OptTensor = None,
+    ) -> Tensor:
         # propagate_type: (weight: Tensor, edge_weight: OptTensor)
         out = self.propagate(edge_index, weight=self.weight,
                              edge_weight=edge_weight, size=None)
+
         if self.bias is not None:
             out = out + self.bias
+
         return out
 
     def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor:
@@ -102,10 +117,14 @@ def __init__(
         self.num_edge_layers = num_edge_layers
 
         self.edge_lin = SparseLinear(num_nodes, hidden_channels)
+
         if self.num_edge_layers > 1:
             self.edge_norm = BatchNorm1d(hidden_channels)
             channels = [hidden_channels] * num_edge_layers
             self.edge_mlp = MLP(channels, dropout=0., act_first=True)
+        else:
+            self.edge_norm = None
+            self.edge_mlp = None
 
         channels = [in_channels] + [hidden_channels] * num_node_layers
         self.node_mlp = MLP(channels, dropout=0., act_first=True)
@@ -121,19 +140,35 @@ def __init__(
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
         self.edge_lin.reset_parameters()
-        if self.num_edge_layers > 1:
+        if self.edge_norm is not None:
             self.edge_norm.reset_parameters()
+        if self.edge_mlp is not None:
             self.edge_mlp.reset_parameters()
         self.node_mlp.reset_parameters()
         self.cat_lin1.reset_parameters()
         self.cat_lin2.reset_parameters()
         self.final_mlp.reset_parameters()
 
-    def forward(self, x: OptTensor, edge_index: Adj,
-                edge_weight: OptTensor = None) -> Tensor:
+    @torch.jit._overload_method
+    def forward(self, x, edge_index, edge_weight=None):
+        # type: (OptTensor, SparseTensor, OptTensor) -> Tensor
+        pass
+
+    @torch.jit._overload_method
+    def forward(self, x, edge_index, edge_weight=None):
+        # type: (OptTensor, Tensor, OptTensor) -> Tensor
+        pass
+
+    def forward(
+        self,
+        x: OptTensor,
+        edge_index: Adj,
+        edge_weight: OptTensor = None,
+    ) -> Tensor:
         """"""
         out = self.edge_lin(edge_index, edge_weight)
-        if self.num_edge_layers > 1:
+
+        if self.edge_norm is not None and self.edge_mlp is not None:
             out = out.relu_()
             out = self.edge_norm(out)
             out = self.edge_mlp(out)
@@ -147,6 +182,45 @@ def forward(self, x: OptTensor, edge_index: Adj,
 
         return self.final_mlp(out.relu_())
 
+    def jittable(self, typing: str) -> torch.nn.Module:  # pragma: no cover
+        edge_index_type = typing.split(',')[1].strip()
+
+        class EdgeIndexJittable(torch.nn.Module):
+            def __init__(self, child):
+                super().__init__()
+                self.child = child
+
+            def reset_parameters(self):
+                self.child.reset_parameters()
+
+            def forward(self, x: Tensor, edge_index: Tensor,
+                        edge_weight: OptTensor = None) -> Tensor:
+                return self.child(x, edge_index, edge_weight)
+
+        class SparseTensorJittable(torch.nn.Module):
+            def __init__(self, child):
+                super().__init__()
+                self.child = child
+
+            def reset_parameters(self):
+                self.child.reset_parameters()
+
+            def forward(self, x: Tensor, edge_index: SparseTensor,
+                        edge_weight: OptTensor = None):
+                return self.child(x, edge_index, edge_weight)
+
+        if self.edge_lin.jittable is not None:
+            self.edge_lin = self.edge_lin.jittable()
+
+        if 'Tensor' == edge_index_type:
+            jittable_module = EdgeIndexJittable(self)
+        elif 'SparseTensor' == edge_index_type:
+            jittable_module = SparseTensorJittable(self)
+        else:
+            raise ValueError(f"Could not parse types '{typing}'")
+
+        return jittable_module
+
     def __repr__(self) -> str:
         return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '
                 f'in_channels={self.in_channels}, '
diff --git a/torch_geometric/nn/models/mlp.py b/torch_geometric/nn/models/mlp.py
index 26356b09e341..402a6c087363 100644
--- a/torch_geometric/nn/models/mlp.py
+++ b/torch_geometric/nn/models/mlp.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -189,7 +189,7 @@ def forward(
         self,
         x: Tensor,
         return_emb: NoneType = None,
-    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    ) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
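
For reference, a minimal usage sketch of the new `jittable` entry point, mirroring the added test in `test/nn/models/test_linkx.py`; the toy graph and model hyperparameters below are illustrative, not part of the patch:

```python
import torch

from torch_geometric.nn import LINKX

# Illustrative toy input: 4 nodes with 16 features each and 3 edges.
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

model = LINKX(num_nodes=4, in_channels=16, hidden_channels=32,
              out_channels=8, num_layers=2)
out = model(x, edge_index)  # eager-mode reference output

# Select the `Tensor` overload of `forward` via the typing string,
# then compile the returned wrapper module with TorchScript.
t = '(OptTensor, Tensor, OptTensor) -> Tensor'
jit = torch.jit.script(model.jittable(t))

assert torch.allclose(jit(x, edge_index), out)
```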