Skip to content

Commit

Permalink
Add TorchScript support for RECT_L (#6727)
Browse files Browse the repository at this point in the history
This PR adds TorchScript support for the `RECT_L` model.

The failure mode and our solution for the original code are very similar to
PR [#6721](#6721),
except that this model uses `torch.jit.export` on the `embed` and
`get_semantic_labels` methods. An additional failure was caused by the
`@torch.no_grad` decorator, which produced an error message I could not understand.

Adding TorchScript support will bring a lot of extra code and reduce the
code readability. I will consider how to do better in another PR.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
ftxj and rusty1s authored Feb 17, 2023
1 parent a0ffd6f commit 392f501
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added TorchScript support to the `RECT_L` model ([#6727](https://github.com/pyg-team/pytorch_geometric/pull/6727))
- Added TorchScript support to the `Node2Vec` model ([#6726](https://github.com/pyg-team/pytorch_geometric/pull/6726))
- Added `utils.to_edge_index` to convert sparse tensors to edge indices and edge attributes ([#6728](https://github.com/pyg-team/pytorch_geometric/issues/6728))
- Fixed expected data format in `PolBlogs` dataset ([#6714](https://github.com/pyg-team/pytorch_geometric/issues/6714))
Expand Down
26 changes: 22 additions & 4 deletions test/nn/models/test_rect.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,42 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import RECT_L
from torch_geometric.testing import is_full_test


def test_rect():
    """Smoke-test RECT_L: `forward`, `embed`, `get_semantic_labels`, and the
    TorchScript-compiled variants for both dense `edge_index` and
    `SparseTensor` adjacency inputs.
    """
    x = torch.randn(6, 8)
    y = torch.tensor([1, 0, 0, 2, 1, 1])
    edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]])
    adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6))
    mask = torch.randint(0, 2, (6, ), dtype=torch.bool)

    model = RECT_L(8, 16)
    assert str(model) == 'RECT_L(8, 16)'

    # Forward pass must agree between edge_index and SparseTensor inputs:
    out = model(x, edge_index)
    assert out.size() == (6, 8)
    assert torch.allclose(out, model(x, adj.t()))

    # Test `embed`:
    embed_out = model.embed(x, edge_index)
    assert embed_out.size() == (6, 16)
    assert torch.allclose(embed_out, model.embed(x, adj.t()))

    # Test `get_semantic_labels` (one row per masked node):
    labels_out = model.get_semantic_labels(x, y, mask)
    assert labels_out.size() == (int(mask.sum()), 8)

    if is_full_test():
        # Scripted model with dense edge_index input:
        t = '(Tensor, Tensor, OptTensor) -> Tensor'
        jit = torch.jit.script(model.jittable(t))
        assert torch.allclose(jit(x, edge_index), out)
        assert torch.allclose(embed_out, jit.embed(x, edge_index))
        assert torch.allclose(labels_out, jit.get_semantic_labels(x, y, mask))

        # Scripted model with SparseTensor input:
        t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        jit = torch.jit.script(model.jittable(t))
        assert torch.allclose(jit(x, adj.t()), out)
        assert torch.allclose(embed_out, jit.embed(x, adj.t()))
        assert torch.allclose(labels_out, jit.get_semantic_labels(x, y, mask))
95 changes: 87 additions & 8 deletions torch_geometric/nn/models/rect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn import Linear

from torch_geometric.nn import GCNConv
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import scatter


Expand Down Expand Up @@ -48,25 +48,104 @@ def reset_parameters(self):
self.lin.reset_parameters()
torch.nn.init.xavier_uniform_(self.lin.weight.data)

    # TorchScript overload stubs: they declare the two accepted signatures
    # (SparseTensor adjacency vs. dense edge_index); the `# type:` comments
    # are parsed by the TorchScript compiler and must not be altered.
    @torch.jit._overload_method
    def forward(self, x, edge_index, edge_weight=None):
        # type: (Tensor, SparseTensor, OptTensor) -> Tensor
        pass

    @torch.jit._overload_method
    def forward(self, x, edge_index, edge_weight=None):
        # type: (Tensor, Tensor, OptTensor) -> Tensor
        pass

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Run a GCN convolution, dropout, and the final linear projection.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity, either a dense ``edge_index``
                tensor or a :class:`SparseTensor` adjacency.
            edge_weight: Optional per-edge weights.
        """
        x = self.conv(x, edge_index, edge_weight)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

@torch.no_grad()
@torch.jit._overload_method
def embed(self, x, edge_index, edge_weight=None):
# type: (Tensor, SparseTensor, OptTensor) -> Tensor
pass

@torch.jit._overload_method
def embed(self, x, edge_index, edge_weight=None):
# type: (Tensor, Tensor, OptTensor) -> Tensor
pass

def embed(self, x: Tensor, edge_index: Adj,
edge_weight: OptTensor = None) -> Tensor:
return self.conv(x, edge_index, edge_weight)
with torch.no_grad():
return self.conv(x, edge_index, edge_weight)

@torch.no_grad()
def get_semantic_labels(self, x: Tensor, y: Tensor,
mask: Tensor) -> Tensor:
"""Replaces the original labels by their class-centers."""
y = y[mask]
mean = scatter(x[mask], y, dim=0, reduce='mean')
return mean[y]
r"""Replaces the original labels by their class-centers."""
with torch.no_grad():
y = y[mask]
mean = scatter(x[mask], y, dim=0, reduce='mean')
return mean[y]

    def jittable(self, typing: str) -> torch.nn.Module:  # pragma: no cover
        """Return a TorchScript-compatible wrapper around this model.

        ``typing`` is a signature string such as
        ``'(Tensor, Tensor, OptTensor) -> Tensor'``; its second argument type
        selects whether the wrapper accepts a dense ``edge_index`` tensor or
        a :class:`SparseTensor` adjacency. The wrapper exposes ``forward``,
        ``embed`` and ``get_semantic_labels`` (the latter two via
        ``@torch.jit.export`` so scripting keeps them callable).

        Raises:
            ValueError: If the second argument type in ``typing`` is neither
                ``Tensor`` nor ``SparseTensor``.
        """
        # Extract the adjacency type name from the signature string.
        edge_index_type = typing.split(',')[1].strip()

        # Wrapper for dense `edge_index` (Tensor) adjacency input.
        class EdgeIndexJittable(torch.nn.Module):
            def __init__(self, child):
                super().__init__()
                self.child = child

            def reset_parameters(self):
                self.child.reset_parameters()

            def forward(self, x: Tensor, edge_index: Tensor,
                        edge_weight: OptTensor = None) -> Tensor:
                return self.child(x, edge_index, edge_weight)

            # Exported so the method survives torch.jit.script compilation.
            @torch.jit.export
            def embed(self, x: Tensor, edge_index: Tensor,
                      edge_weight: OptTensor = None) -> Tensor:
                return self.child.embed(x, edge_index, edge_weight)

            @torch.jit.export
            def get_semantic_labels(self, x: Tensor, y: Tensor,
                                    mask: Tensor) -> Tensor:
                return self.child.get_semantic_labels(x, y, mask)

        # Wrapper for `SparseTensor` adjacency input.
        class SparseTensorJittable(torch.nn.Module):
            def __init__(self, child):
                super().__init__()
                self.child = child

            def reset_parameters(self):
                self.child.reset_parameters()

            def forward(self, x: Tensor, edge_index: SparseTensor,
                        edge_weight: OptTensor = None):
                return self.child(x, edge_index, edge_weight)

            @torch.jit.export
            def embed(self, x: Tensor, edge_index: SparseTensor,
                      edge_weight: OptTensor = None) -> Tensor:
                return self.child.embed(x, edge_index, edge_weight)

            @torch.jit.export
            def get_semantic_labels(self, x: Tensor, y: Tensor,
                                    mask: Tensor) -> Tensor:
                return self.child.get_semantic_labels(x, y, mask)

        # NOTE(review): `self.conv.jittable` is a bound method, so this
        # `is not None` check is always True — presumably a guard for convs
        # without `jittable` was intended; confirm upstream.
        if self.conv.jittable is not None:
            self.conv = self.conv.jittable()

        # Select the wrapper matching the requested adjacency type.
        if 'Tensor' == edge_index_type:
            jittable_module = EdgeIndexJittable(self)
        elif 'SparseTensor' == edge_index_type:
            jittable_module = SparseTensorJittable(self)
        else:
            raise ValueError(f"Could not parse types '{typing}'")

        return jittable_module

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
Expand Down

0 comments on commit 392f501

Please sign in to comment.