Make LinkX model jittable #6712

Merged: 6 commits, Feb 15, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added TorchScript support to the `LINKX` model ([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712))
- Added `torch.jit` examples for `example/film.py` and `example/gcn.py` ([#6602](https://github.com/pyg-team/pytorch_geometric/pull/6692))
- Added `Pad` transform ([#5940](https://github.com/pyg-team/pytorch_geometric/pull/5940))
- Added full batch mode to the inference benchmark ([#6631](https://github.com/pyg-team/pytorch_geometric/pull/6631))
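
As a quick sketch of what this enables (the sizes and graph below are illustrative, not taken from the PR), the model is scripted through the new `jittable` helper:

```python
import torch
from torch_geometric.nn import LINKX

# Illustrative configuration; any consistent sizes work.
model = LINKX(num_nodes=4, in_channels=16, hidden_channels=32,
              out_channels=8, num_layers=2)

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 3]])

# Select the Tensor-based forward overload, then compile with TorchScript.
jit = torch.jit.script(model.jittable('(OptTensor, Tensor, OptTensor) -> Tensor'))
out = jit(x, edge_index)  # shape: [4, 8]
```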
10 changes: 10 additions & 0 deletions test/nn/models/test_linkx.py
@@ -3,6 +3,7 @@
from torch_sparse import SparseTensor

from torch_geometric.nn import LINKX
from torch_geometric.testing import is_full_test


@pytest.mark.parametrize('num_edge_layers', [1, 2])
@@ -22,6 +23,15 @@ def test_linkx(num_edge_layers):
assert out.size() == (4, 8)
assert torch.allclose(out, model(x, adj1.t()), atol=1e-4)

if is_full_test():
t = '(OptTensor, Tensor, OptTensor) -> Tensor'
jit = torch.jit.script(model.jittable(t))
assert torch.allclose(jit(x, edge_index), out)

t = '(OptTensor, SparseTensor, OptTensor) -> Tensor'
jit = torch.jit.script(model.jittable(t))
assert torch.allclose(jit(x, adj1.t()), out)

out = model(None, edge_index)
assert out.size() == (4, 8)
assert torch.allclose(out, model(None, adj1.t()), atol=1e-4)
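
(For context: `is_full_test()` from `torch_geometric.testing` gates the TorchScript assertions above so they run only when the extended test suite is enabled, e.g. via the `FULL_TEST` environment flag.)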
86 changes: 80 additions & 6 deletions torch_geometric/nn/models/linkx.py
@@ -30,13 +30,28 @@ def reset_parameters(self):
a=math.sqrt(5))
inits.uniform(self.in_channels, self.bias)

def forward(self, edge_index: Adj,
edge_weight: OptTensor = None) -> Tensor:
@torch.jit._overload_method
def forward(self, edge_index, edge_weight=None):
# type: (SparseTensor, OptTensor) -> Tensor
pass

@torch.jit._overload_method
def forward(self, edge_index, edge_weight=None):
# type: (Tensor, OptTensor) -> Tensor
pass

def forward(
self,
edge_index: Adj,
edge_weight: OptTensor = None,
) -> Tensor:
# propagate_type: (weight: Tensor, edge_weight: OptTensor)
out = self.propagate(edge_index, weight=self.weight,
edge_weight=edge_weight, size=None)

if self.bias is not None:
out = out + self.bias

return out

def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor:
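
The `@torch.jit._overload_method` stubs above follow the standard TorchScript overload pattern: each decorated stub declares one accepted signature in a `# type:` comment, and the final undecorated `forward` holds the single shared implementation, which TorchScript checks against every stub. A minimal self-contained sketch of the pattern (the `Scale` module is hypothetical, not part of this PR):

```python
import torch
from torch import Tensor
from typing import Optional

class Scale(torch.nn.Module):
    @torch.jit._overload_method
    def forward(self, x, factor=None):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        pass

    @torch.jit._overload_method
    def forward(self, x, factor=None):
        # type: (Tensor, Optional[float]) -> Tensor
        pass

    def forward(self, x, factor=None):
        # Single implementation shared by both declared signatures.
        if factor is None:
            return x
        return x * factor
```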
@@ -102,10 +117,14 @@ def __init__(
self.num_edge_layers = num_edge_layers

self.edge_lin = SparseLinear(num_nodes, hidden_channels)

if self.num_edge_layers > 1:
self.edge_norm = BatchNorm1d(hidden_channels)
channels = [hidden_channels] * num_edge_layers
self.edge_mlp = MLP(channels, dropout=0., act_first=True)
else:
self.edge_norm = None
self.edge_mlp = None

channels = [in_channels] + [hidden_channels] * num_node_layers
self.node_mlp = MLP(channels, dropout=0., act_first=True)
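
(Initializing `edge_norm` and `edge_mlp` to `None` in the single-layer case gives both attributes a stable `Optional` module type, so the code below can refine them with `is not None` checks, which is how TorchScript narrows `Optional` types.)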
@@ -121,19 +140,35 @@ def __init__(
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.edge_lin.reset_parameters()
if self.num_edge_layers > 1:
if self.edge_norm is not None:
self.edge_norm.reset_parameters()
if self.edge_mlp is not None:
self.edge_mlp.reset_parameters()
self.node_mlp.reset_parameters()
self.cat_lin1.reset_parameters()
self.cat_lin2.reset_parameters()
self.final_mlp.reset_parameters()

def forward(self, x: OptTensor, edge_index: Adj,
edge_weight: OptTensor = None) -> Tensor:
@torch.jit._overload_method
def forward(self, x, edge_index, edge_weight=None):
# type: (OptTensor, SparseTensor, OptTensor) -> Tensor
pass

@torch.jit._overload_method
def forward(self, x, edge_index, edge_weight=None):
# type: (OptTensor, Tensor, OptTensor) -> Tensor
pass

def forward(
self,
x: OptTensor,
edge_index: Adj,
edge_weight: OptTensor = None,
) -> Tensor:
""""""
out = self.edge_lin(edge_index, edge_weight)
if self.num_edge_layers > 1:

if self.edge_norm is not None and self.edge_mlp is not None:
out = out.relu_()
out = self.edge_norm(out)
out = self.edge_mlp(out)
@@ -147,6 +182,45 @@ def forward(self, x: OptTensor, edge_index: Adj,

return self.final_mlp(out.relu_())

def jittable(self, typing: str) -> torch.nn.Module: # pragma: no cover
edge_index_type = typing.split(',')[1].strip()

class EdgeIndexJittable(torch.nn.Module):
def __init__(self, child):
super().__init__()
self.child = child

def reset_parameters(self):
self.child.reset_parameters()

def forward(self, x: Tensor, edge_index: Tensor,
edge_weight: OptTensor = None) -> Tensor:
return self.child(x, edge_index, edge_weight)

class SparseTensorJittable(torch.nn.Module):
def __init__(self, child):
super().__init__()
self.child = child

def reset_parameters(self):
self.child.reset_parameters()

def forward(self, x: Tensor, edge_index: SparseTensor,
edge_weight: OptTensor = None):
return self.child(x, edge_index, edge_weight)

if self.edge_lin.jittable is not None:
self.edge_lin = self.edge_lin.jittable()

if 'Tensor' == edge_index_type:
jittable_module = EdgeIndexJittable(self)
elif 'SparseTensor' == edge_index_type:
jittable_module = SparseTensorJittable(self)
else:
raise ValueError(f"Could not parse types '{typing}'")

return jittable_module

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '
f'in_channels={self.in_channels}, '
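
The two wrapper classes exist because TorchScript cannot dispatch a single `forward` over the `Adj` union of `Tensor` and `SparseTensor`, so `jittable` pins one concrete `edge_index` type per wrapper. A usage sketch for the `SparseTensor` path, mirroring the test above (sizes are illustrative):

```python
import torch
from torch_sparse import SparseTensor
from torch_geometric.nn import LINKX

model = LINKX(num_nodes=4, in_channels=16, hidden_channels=32,
              out_channels=8, num_layers=2)

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 3]])
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))

# Select the SparseTensor-based overload before scripting.
jit = torch.jit.script(model.jittable('(OptTensor, SparseTensor, OptTensor) -> Tensor'))
out = jit(torch.randn(4, 16), adj.t())  # shape: [4, 8]
```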
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/mlp.py
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
@@ -189,7 +189,7 @@ def forward(
self,
x: Tensor,
return_emb: NoneType = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
) -> Tensor:
r"""
Args:
x (torch.Tensor): The source tensor.
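
The narrowed return type matters for scripting: TorchScript resolves call results statically, so an `MLP.forward` annotated as returning `Union[Tensor, Tuple[Tensor, Tensor]]` would force every scripted caller, such as `LINKX`, to refine the result before treating it as a plain `Tensor`. A small illustration of the refinement burden the change avoids (hypothetical function, not from the diff):

```python
import torch
from torch import Tensor
from typing import Tuple, Union

@torch.jit.script
def caller(out: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tensor:
    # With a Union result, every scripted use site needs an explicit
    # isinstance refinement before the value can be used as a Tensor.
    if isinstance(out, Tensor):
        return out.relu()
    return out[0].relu()
```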