
[Code Coverage] models/tgn.py (#6662)
Add a test for the TGN model.
Update type annotations in `torch_geometric/nn/models/tgn.py`.
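
For context, here is a minimal sketch (not part of the commit) of the usage pattern the new test exercises, with shapes and module choices taken from the test below:

```python
import torch

from torch_geometric.data import TemporalData
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator

# Ten timestamped interactions among five nodes, each carrying a 16-dim
# raw message.
data = TemporalData(
    src=torch.tensor([0, 1, 0, 2, 0, 3, 1, 4, 2, 3]),
    dst=torch.tensor([1, 2, 1, 1, 3, 2, 4, 3, 3, 4]),
    t=torch.arange(10),
    msg=torch.randn(10, 16),
)

memory = TGNMemory(
    num_nodes=data.num_nodes,
    raw_msg_dim=16,
    memory_dim=16,
    time_dim=16,
    message_module=IdentityMessage(16, 16, 16),
    aggregator_module=LastAggregator(),
)

z, last_update = memory(torch.arange(data.num_nodes))  # read memory state
memory.update_state(data.src, data.dst, data.t, data.msg)  # ingest events
```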
zechengz authored Feb 11, 2023
1 parent 6e62bb2 commit 9136c83
Showing 3 changed files with 99 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -81,7 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613))
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
-- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664))
+- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributes in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
78 changes: 78 additions & 0 deletions test/nn/models/test_tgn.py
@@ -0,0 +1,78 @@
import torch

from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)


def test_tgn():
    memory_dim = 16
    time_dim = 16
    src = torch.tensor([0, 1, 0, 2, 0, 3, 1, 4, 2, 3])
    dst = torch.tensor([1, 2, 1, 1, 3, 2, 4, 3, 3, 4])
    t = torch.arange(10)
    msg = torch.randn(10, 16)
    data = TemporalData(src=src, dst=dst, t=t, msg=msg)
    loader = TemporalDataLoader(data, batch_size=5)

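    # The neighbor loader keeps only the `size` most recent neighbors of
    # each node: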
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=3)
    assert neighbor_loader.cur_e_id == 0
    assert neighbor_loader.e_id.size() == (data.num_nodes, 3)

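    # TGNMemory maintains a memory vector and a last-update timestamp for
    # every node; IdentityMessage simply concatenates
    # [z_src, z_dst, raw_msg, t_enc] into a single message.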
    memory = TGNMemory(
        num_nodes=data.num_nodes,
        raw_msg_dim=data.msg.size(-1),
        memory_dim=memory_dim,
        time_dim=time_dim,
        message_module=IdentityMessage(data.msg.size(-1), memory_dim,
                                       time_dim),
        aggregator_module=LastAggregator(),
    )
    assert memory.memory.size() == (data.num_nodes, memory_dim)
    assert memory.last_update.size() == (data.num_nodes, )

    # Test during TGNMemory training:
    memory.train()
    for i, batch in enumerate(loader):
        n_id = torch.cat([batch.src, batch.dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        z, last_update = memory(n_id)
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)
        if i == 0:
            assert n_id.size(0) == 4
            assert edge_index.numel() == 0
            assert e_id.numel() == 0
            assert z.size() == (n_id.size(0), memory_dim)
            assert torch.sum(last_update) == 0
        else:
            assert n_id.size(0) == 5
            assert edge_index.numel() == 12
            assert e_id.numel() == 6
            assert z.size() == (n_id.size(0), memory_dim)
            assert torch.equal(last_update, torch.tensor([4, 3, 3, 4, 0]))

    # Test after TGNMemory training:
    memory.eval()
    all_n_id = torch.arange(data.num_nodes)
    z, last_update = memory(all_n_id)
    assert z.size() == (data.num_nodes, memory_dim)
    # Each entry is the timestamp of the corresponding node's last event:
    assert torch.equal(last_update, torch.tensor([4, 6, 8, 9, 9]))

    post_src = torch.tensor([3, 4])
    post_dst = torch.tensor([4, 3])
    post_t = torch.tensor([10, 10])
    post_msg = torch.randn(2, 16)
    memory.update_state(post_src, post_dst, post_t, post_msg)
    post_z, post_last_update = memory(all_n_id)
    # Nodes 0-2 do not participate in the new events, so their memory is
    # unchanged:
    assert torch.allclose(z[0:3], post_z[0:3])
    assert torch.equal(post_last_update, torch.tensor([4, 6, 8, 10, 10]))

    memory.reset_state()
    assert memory.memory.sum() == 0
    assert memory.last_update.sum() == 0
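The new test can be run on its own with, for example, `pytest test/nn/models/test_tgn.py`.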
34 changes: 20 additions & 14 deletions torch_geometric/nn/models/tgn.py
@@ -1,5 +1,5 @@
import copy
-from typing import Callable, Tuple
+from typing import Callable, Dict, Tuple

import torch
from torch import Tensor
@@ -8,6 +8,8 @@
from torch_geometric.nn.inits import zeros
from torch_geometric.utils import scatter

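# A message store maps each node id to its most recent batch of
# (src, dst, t, raw_msg) interactions: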
+TGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]]


class TGNMemory(torch.nn.Module):
r"""The Temporal Graph Network (TGN) memory model from the
Expand Down Expand Up @@ -90,7 +92,8 @@ def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:

return memory, last_update

def update_state(self, src, dst, t, raw_msg):
def update_state(self, src: Tensor, dst: Tensor, t: Tensor,
raw_msg: Tensor):
"""Updates the memory with newly encountered interactions
:obj:`(src, dst, t, raw_msg)`."""
n_id = torch.cat([src, dst]).unique()
@@ -111,12 +114,12 @@ def __reset_message_store__(self):
        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}
        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}

-    def __update_memory__(self, n_id):
+    def __update_memory__(self, n_id: Tensor):
        memory, last_update = self.__get_updated_memory__(n_id)
        self.memory[n_id] = memory
        self.last_update[n_id] = last_update

-    def __get_updated_memory__(self, n_id):
+    def __get_updated_memory__(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        self.__assoc__[n_id] = torch.arange(n_id.size(0), device=n_id.device)

        # Compute messages (src -> dst).
@@ -142,13 +145,15 @@ def __get_updated_memory__(self, n_id):

        return memory, last_update

-    def __update_msg_store__(self, src, dst, t, raw_msg, msg_store):
+    def __update_msg_store__(self, src: Tensor, dst: Tensor, t: Tensor,
+                             raw_msg: Tensor, msg_store: TGNMessageStoreType):
        n_id, perm = src.sort()
        n_id, count = n_id.unique_consecutive(return_counts=True)
        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])

-    def __compute_msg__(self, n_id, msg_store, msg_module):
+    def __compute_msg__(self, n_id: Tensor, msg_store: TGNMessageStoreType,
+                        msg_module: Callable):
        data = [msg_store[i] for i in n_id.tolist()]
        src, dst, t, raw_msg = list(zip(*data))
        src = torch.cat(src, dim=0)
@@ -177,12 +182,13 @@ def __init__(self, raw_msg_dim: int, memory_dim: int, time_dim: int):
        super().__init__()
        self.out_channels = raw_msg_dim + 2 * memory_dim + time_dim

-    def forward(self, z_src, z_dst, raw_msg, t_enc):
+    def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor,
+                t_enc: Tensor):
        return torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1)


class LastAggregator(torch.nn.Module):
-    def forward(self, msg, index, t, dim_size):
+    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
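        # `scatter_max` over `t` selects, for each destination index, the
        # position of the most recent message.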
        from torch_scatter import scatter_max
        _, argmax = scatter_max(t, index, dim=0, dim_size=dim_size)
        out = msg.new_zeros((dim_size, msg.size(-1)))
@@ -192,20 +198,20 @@ def forward(self, msg, index, t, dim_size):


class MeanAggregator(torch.nn.Module):
-    def forward(self, msg, index, t, dim_size):
+    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        return scatter(msg, index, dim=0, dim_size=dim_size, reduce='mean')


class TimeEncoder(torch.nn.Module):
-    def __init__(self, out_channels):
+    def __init__(self, out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        self.lin = Linear(1, out_channels)

    def reset_parameters(self):
        self.lin.reset_parameters()

-    def forward(self, t):
+    def forward(self, t: Tensor) -> Tensor:
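        # A learnable time encoding: the cosine of a linear projection of
        # the (scalar) timestamps.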
        return self.lin(t.view(-1, 1)).cos()


@@ -222,7 +228,7 @@ def __init__(self, num_nodes: int, size: int, device=None):

        self.reset_state()

-    def __call__(self, n_id):
+    def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        neighbors = self.neighbors[n_id]
        nodes = n_id.view(-1, 1).repeat(1, self.size)
        e_id = self.e_id[n_id]
@@ -238,8 +244,8 @@ def __call__(self, n_id):

        return n_id, torch.stack([neighbors, nodes]), e_id

-    def insert(self, src, dst):
-        # Inserts newly encountered interactions into an ever growing
+    def insert(self, src: Tensor, dst: Tensor):
+        # Inserts newly encountered interactions into an ever-growing
        # (undirected) temporal graph.

        # Collect central nodes, their neighbors and the current event ids.
