Skip to content

Commit

Permalink
[Type Hints] nn.RENet (#5715)
Browse files Browse the repository at this point in the history
## Changes

- added type hints to `nn.RENet`
- added unit-test using TorchScript to `test_re_net.py`

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
TanveshT and rusty1s authored Oct 17, 2022
1 parent 311cae0 commit ee09efe
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
8 changes: 8 additions & 0 deletions test/nn/models/test_re_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_geometric.datasets.icews import EventDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import RENet
from torch_geometric.testing import is_full_test


class MyTestEventDataset(EventDataset):
Expand Down Expand Up @@ -50,6 +51,9 @@ def test_re_net():
model = RENet(dataset.num_nodes, dataset.num_rels, hidden_channels=16,
seq_len=4)

if is_full_test():
jit = torch.jit.export(model)

logits = torch.randn(6, 6)
y = torch.tensor([0, 1, 2, 3, 4, 5])

Expand All @@ -59,6 +63,10 @@ def test_re_net():

for data in loader:
log_prob_obj, log_prob_sub = model(data)
if is_full_test():
log_prob_obj_jit, log_prob_sub_jit = jit(data)
assert torch.allclose(log_prob_obj_jit, log_prob_obj)
assert torch.allclose(log_prob_sub_jit, log_prob_sub)
model.test(log_prob_obj, data.obj)
model.test(log_prob_sub, data.sub)

Expand Down
37 changes: 27 additions & 10 deletions torch_geometric/nn/models/re_net.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import math
from typing import Callable, List, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import GRU, Linear, Parameter
from torch_scatter import scatter_mean

from torch_geometric.data.data import Data


class RENet(torch.nn.Module):
r"""The Recurrent Event Network model from the `"Recurrent Event Network
Expand Down Expand Up @@ -40,8 +44,16 @@ class RENet(torch.nn.Module):
bias (bool, optional): If set to :obj:`False`, all layers will not
learn an additive bias. (default: :obj:`True`)
"""
def __init__(self, num_nodes, num_rels, hidden_channels, seq_len,
num_layers=1, dropout=0., bias=True):
def __init__(
self,
num_nodes: int,
num_rels: int,
hidden_channels: int,
seq_len: int,
num_layers: int = 1,
dropout: float = 0.,
bias: bool = True,
):
super().__init__()

self.num_nodes = num_nodes
Expand Down Expand Up @@ -73,7 +85,7 @@ def reset_parameters(self):
self.obj_lin.reset_parameters()

@staticmethod
def pre_transform(seq_len):
def pre_transform(seq_len: int) -> Callable:
r"""Precomputes history objects
.. math::
Expand All @@ -83,18 +95,23 @@ def pre_transform(seq_len):
:math:`k` denoting the sequence length :obj:`seq_len`.
"""
class PreTransform(object):
def __init__(self, seq_len):
def __init__(self, seq_len: int):
self.seq_len = seq_len
self.inc = 5000
self.t_last = 0
self.sub_hist = self.increase_hist_node_size([])
self.obj_hist = self.increase_hist_node_size([])

def increase_hist_node_size(self, hist):
def increase_hist_node_size(self, hist: List[int]) -> List[int]:
hist_inc = torch.zeros((self.inc, self.seq_len + 1, 0))
return hist + hist_inc.tolist()

def get_history(self, hist, node, rel):
def get_history(
self,
hist: List[int],
node: int,
rel: int,
) -> Tuple[Tensor, Tensor]:
hists, ts = [], []
for s in range(seq_len):
h = hist[node][s]
Expand All @@ -106,13 +123,13 @@ def get_history(self, hist, node, rel):
t = torch.cat(ts, dim=0)[r == rel]
return node, t

def step(self, hist):
def step(self, hist: List[int]) -> List[int]:
for i in range(len(hist)):
hist[i] = hist[i][1:]
hist[i].append([])
return hist

def __call__(self, data):
def __call__(self, data: Data) -> Data:
sub, rel, obj, t = data.sub, data.rel, data.obj, data.t

if max(sub, obj) + 1 > len(self.sub_hist): # pragma: no cover
Expand Down Expand Up @@ -142,7 +159,7 @@ def __repr__(self) -> str: # pragma: no cover

return PreTransform(seq_len)

def forward(self, data):
def forward(self, data: Data) -> Tuple[Tensor, Tensor]:
"""Given a :obj:`data` batch, computes the forward pass.
Args:
Expand Down Expand Up @@ -191,7 +208,7 @@ def forward(self, data):

return log_prob_obj, log_prob_sub

def test(self, logits, y):
def test(self, logits: Tensor, y: Tensor) -> Tensor:
"""Given ground-truth :obj:`y`, computes Mean Reciprocal Rank (MRR)
and Hits at 1/3/10."""

Expand Down

0 comments on commit ee09efe

Please sign in to comment.