Make LinkX model jittable (#6712)
This PR makes the `LINKX` model jittable, which also improves its performance.

The original `LINKX` model could not be scripted with TorchScript for several reasons:
1. The second parameter `edge_index` of `LINKX.forward` is typed as `Adj`, a union of `Tensor` and `SparseTensor`, which TorchScript cannot dispatch on directly. So we need to add `@torch.jit._overload_method` overloads for this method.
2. The `edge_norm` and `edge_mlp` members of the `LINKX` model are not always initialized, so we modify the initialization logic to assign `None` explicitly.
3. After adding `@torch.jit._overload_method` to `forward`, we need a wrapper module around the model. Otherwise, TorchScript cannot find the `forward` function.
4. The return type of `MLP.forward` is a `Union`, which causes a type error in TorchScript whenever the return value of `MLP` is used as a `Tensor`.

There might be a better way to make the model jittable, but I haven't found one yet.
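
A minimal usage sketch of the new `jittable` API, mirroring the added test below (the tensor shapes and edge index are illustrative):

```python
import torch

from torch_geometric.nn import LINKX

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

model = LINKX(num_nodes=4, in_channels=16, hidden_channels=32,
              out_channels=8, num_layers=2)

# Select the `forward` overload to script via its type string; `jittable`
# returns a wrapper module that TorchScript can compile:
t = '(OptTensor, Tensor, OptTensor) -> Tensor'
jit = torch.jit.script(model.jittable(t))

out = jit(x, edge_index)  # Same result as eager `model(x, edge_index)`.
assert out.size() == (4, 8)
```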

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
ftxj and rusty1s authored Feb 15, 2023
1 parent 844fc10 commit 34668c7
Showing 4 changed files with 93 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added TorchScript support to the `LINKX` model ([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712))
- Added `torch.jit` examples for `example/film.py` and `example/gcn.py`([#6602](https://github.com/pyg-team/pytorch_geometric/pull/6692))
- Added `Pad` transform ([#5940](https://github.com/pyg-team/pytorch_geometric/pull/5940))
- Added full batch mode to the inference benchmark ([#6631](https://github.com/pyg-team/pytorch_geometric/pull/6631))
10 changes: 10 additions & 0 deletions test/nn/models/test_linkx.py
@@ -3,6 +3,7 @@
from torch_sparse import SparseTensor

from torch_geometric.nn import LINKX
from torch_geometric.testing import is_full_test


@pytest.mark.parametrize('num_edge_layers', [1, 2])
@@ -22,6 +23,15 @@ def test_linkx(num_edge_layers):
assert out.size() == (4, 8)
assert torch.allclose(out, model(x, adj1.t()), atol=1e-4)

if is_full_test():
t = '(OptTensor, Tensor, OptTensor) -> Tensor'
jit = torch.jit.script(model.jittable(t))
assert torch.allclose(jit(x, edge_index), out)

t = '(OptTensor, SparseTensor, OptTensor) -> Tensor'
jit = torch.jit.script(model.jittable(t))
assert torch.allclose(jit(x, adj1.t()), out)

out = model(None, edge_index)
assert out.size() == (4, 8)
assert torch.allclose(out, model(None, adj1.t()), atol=1e-4)
86 changes: 80 additions & 6 deletions torch_geometric/nn/models/linkx.py
@@ -30,13 +30,28 @@ def reset_parameters(self):
a=math.sqrt(5))
inits.uniform(self.in_channels, self.bias)

def forward(self, edge_index: Adj,
edge_weight: OptTensor = None) -> Tensor:
@torch.jit._overload_method
def forward(self, edge_index, edge_weight=None):
# type: (SparseTensor, OptTensor) -> Tensor
pass

@torch.jit._overload_method
def forward(self, edge_index, edge_weight=None):
# type: (Tensor, OptTensor) -> Tensor
pass

def forward(
self,
edge_index: Adj,
edge_weight: OptTensor = None,
) -> Tensor:
# propagate_type: (weight: Tensor, edge_weight: OptTensor)
out = self.propagate(edge_index, weight=self.weight,
edge_weight=edge_weight, size=None)

if self.bias is not None:
out = out + self.bias

return out

def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor:
@@ -102,10 +117,14 @@ def __init__(
self.num_edge_layers = num_edge_layers

self.edge_lin = SparseLinear(num_nodes, hidden_channels)

if self.num_edge_layers > 1:
self.edge_norm = BatchNorm1d(hidden_channels)
channels = [hidden_channels] * num_edge_layers
self.edge_mlp = MLP(channels, dropout=0., act_first=True)
else:
self.edge_norm = None
self.edge_mlp = None

channels = [in_channels] + [hidden_channels] * num_node_layers
self.node_mlp = MLP(channels, dropout=0., act_first=True)
@@ -121,19 +140,35 @@ def __init__(
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.edge_lin.reset_parameters()
if self.num_edge_layers > 1:
if self.edge_norm is not None:
self.edge_norm.reset_parameters()
if self.edge_mlp is not None:
self.edge_mlp.reset_parameters()
self.node_mlp.reset_parameters()
self.cat_lin1.reset_parameters()
self.cat_lin2.reset_parameters()
self.final_mlp.reset_parameters()

def forward(self, x: OptTensor, edge_index: Adj,
edge_weight: OptTensor = None) -> Tensor:
@torch.jit._overload_method
def forward(self, x, edge_index, edge_weight=None):
# type: (OptTensor, SparseTensor, OptTensor) -> Tensor
pass

@torch.jit._overload_method
def forward(self, x, edge_index, edge_weight=None):
# type: (OptTensor, Tensor, OptTensor) -> Tensor
pass

def forward(
self,
x: OptTensor,
edge_index: Adj,
edge_weight: OptTensor = None,
) -> Tensor:
""""""
out = self.edge_lin(edge_index, edge_weight)
if self.num_edge_layers > 1:

if self.edge_norm is not None and self.edge_mlp is not None:
out = out.relu_()
out = self.edge_norm(out)
out = self.edge_mlp(out)
@@ -147,6 +182,45 @@ def forward(self, x: OptTensor, edge_index: Adj,

return self.final_mlp(out.relu_())

def jittable(self, typing: str) -> torch.nn.Module: # pragma: no cover
edge_index_type = typing.split(',')[1].strip()

class EdgeIndexJittable(torch.nn.Module):
def __init__(self, child):
super().__init__()
self.child = child

def reset_parameters(self):
self.child.reset_parameters()

def forward(self, x: Tensor, edge_index: Tensor,
edge_weight: OptTensor = None) -> Tensor:
return self.child(x, edge_index, edge_weight)

class SparseTensorJittable(torch.nn.Module):
def __init__(self, child):
super().__init__()
self.child = child

def reset_parameters(self):
self.child.reset_parameters()

def forward(self, x: Tensor, edge_index: SparseTensor,
edge_weight: OptTensor = None):
return self.child(x, edge_index, edge_weight)

if self.edge_lin.jittable is not None:
self.edge_lin = self.edge_lin.jittable()

if 'Tensor' == edge_index_type:
jittable_module = EdgeIndexJittable(self)
elif 'SparseTensor' == edge_index_type:
jittable_module = SparseTensorJittable(self)
else:
raise ValueError(f"Could not parse types '{typing}'")

return jittable_module

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '
f'in_channels={self.in_channels}, '
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/mlp.py
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
@@ -189,7 +189,7 @@ def forward(
self,
x: Tensor,
return_emb: NoneType = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
) -> Tensor:
r"""
Args:
x (torch.Tensor): The source tensor.
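
To illustrate change 4 above, here is a small sketch (not part of this PR; the module names are hypothetical) of why a `Union` return type causes trouble in TorchScript: callers cannot treat the value as a `Tensor` without an explicit type refinement first.

```python
from typing import Tuple, Union

import torch
from torch import Tensor


class Inner(torch.nn.Module):
    # A `Union` return type like the old `MLP.forward` signature:
    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        return x


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x: Tensor) -> Tensor:
        out = self.inner(x)
        # TorchScript rejects the next line: `out` is a `Union`, so `Tensor`
        # methods such as `relu()` cannot be called on it without an
        # `isinstance(out, Tensor)` refinement first.
        return out.relu()


# torch.jit.script(Outer())  # Fails to compile for the reason above.
```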
