Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Code Coverage] models/tgn.py #6662

Merged
merged 5 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613)
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
78 changes: 78 additions & 0 deletions test/nn/models/test_tgn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch

from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
IdentityMessage,
LastAggregator,
LastNeighborLoader,
)


def test_tgn():
memory_dim = 16
time_dim = 16
src = torch.tensor([0, 1, 0, 2, 0, 3, 1, 4, 2, 3])
dst = torch.tensor([1, 2, 1, 1, 3, 2, 4, 3, 3, 4])
t = torch.arange(10)
msg = torch.randn(10, 16)
data = TemporalData(src=src, dst=dst, t=t, msg=msg)
loader = TemporalDataLoader(data, batch_size=5)

neighbor_loader = LastNeighborLoader(data.num_nodes, size=3)
assert neighbor_loader.cur_e_id == 0
assert neighbor_loader.e_id.size() == (data.num_nodes, 3)

memory = TGNMemory(
num_nodes=data.num_nodes,
raw_msg_dim=data.msg.size(-1),
memory_dim=memory_dim,
time_dim=time_dim,
message_module=IdentityMessage(data.msg.size(-1), memory_dim,
time_dim),
aggregator_module=LastAggregator(),
)
assert memory.memory.size() == (data.num_nodes, memory_dim)
assert memory.last_update.size() == (data.num_nodes, )

# Test during TGNMemory training:
memory.train()
for i, batch in enumerate(loader):
n_id = torch.cat([batch.src, batch.dst]).unique()
n_id, edge_index, e_id = neighbor_loader(n_id)
z, last_update = memory(n_id)
memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
neighbor_loader.insert(batch.src, batch.dst)
if i == 0:
assert n_id.size(0) == 4
assert edge_index.numel() == 0
assert e_id.numel() == 0
assert z.size() == (n_id.size(0), memory_dim)
assert torch.sum(last_update) == 0
else:
assert n_id.size(0) == 5
assert edge_index.numel() == 12
assert e_id.numel() == 6
assert z.size() == (n_id.size(0), memory_dim)
assert torch.equal(last_update, torch.tensor([4, 3, 3, 4, 0]))

# Test after TGNMemory training:
memory.eval()
all_n_id = torch.arange(data.num_nodes)
z, last_update = memory(all_n_id)
assert z.size() == (data.num_nodes, memory_dim)
assert torch.equal(last_update, torch.tensor([4, 6, 8, 9, 9]))

post_src = torch.tensor([3, 4])
post_dst = torch.tensor([4, 3])
post_t = torch.tensor([10, 10])
post_msg = torch.randn(2, 16)
memory.update_state(post_src, post_dst, post_t, post_msg)
post_z, post_last_update = memory(all_n_id)
assert torch.allclose(z[0:3], post_z[0:3])
assert torch.equal(post_last_update, torch.tensor([4, 6, 8, 10, 10]))
zechengz marked this conversation as resolved.
Show resolved Hide resolved

memory.reset_state()
assert memory.memory.sum() == 0
assert memory.last_update.sum() == 0
34 changes: 20 additions & 14 deletions torch_geometric/nn/models/tgn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Callable, Tuple
from typing import Callable, Dict, Tuple

import torch
from torch import Tensor
Expand All @@ -8,6 +8,8 @@
from torch_geometric.nn.inits import zeros
from torch_geometric.utils import scatter

TGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]]


class TGNMemory(torch.nn.Module):
r"""The Temporal Graph Network (TGN) memory model from the
Expand Down Expand Up @@ -90,7 +92,8 @@ def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:

return memory, last_update

def update_state(self, src, dst, t, raw_msg):
def update_state(self, src: Tensor, dst: Tensor, t: Tensor,
raw_msg: Tensor):
"""Updates the memory with newly encountered interactions
:obj:`(src, dst, t, raw_msg)`."""
n_id = torch.cat([src, dst]).unique()
Expand All @@ -111,12 +114,12 @@ def __reset_message_store__(self):
self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}
self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}

def __update_memory__(self, n_id):
def __update_memory__(self, n_id: Tensor):
memory, last_update = self.__get_updated_memory__(n_id)
self.memory[n_id] = memory
self.last_update[n_id] = last_update

def __get_updated_memory__(self, n_id):
def __get_updated_memory__(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
self.__assoc__[n_id] = torch.arange(n_id.size(0), device=n_id.device)

# Compute messages (src -> dst).
Expand All @@ -142,13 +145,15 @@ def __get_updated_memory__(self, n_id):

return memory, last_update

def __update_msg_store__(self, src, dst, t, raw_msg, msg_store):
def __update_msg_store__(self, src: Tensor, dst: Tensor, t: Tensor,
raw_msg: Tensor, msg_store: TGNMessageStoreType):
n_id, perm = src.sort()
n_id, count = n_id.unique_consecutive(return_counts=True)
for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])

def __compute_msg__(self, n_id, msg_store, msg_module):
def __compute_msg__(self, n_id: Tensor, msg_store: TGNMessageStoreType,
msg_module: Callable):
data = [msg_store[i] for i in n_id.tolist()]
src, dst, t, raw_msg = list(zip(*data))
src = torch.cat(src, dim=0)
Expand Down Expand Up @@ -177,12 +182,13 @@ def __init__(self, raw_msg_dim: int, memory_dim: int, time_dim: int):
super().__init__()
self.out_channels = raw_msg_dim + 2 * memory_dim + time_dim

def forward(self, z_src, z_dst, raw_msg, t_enc):
def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor,
t_enc: Tensor):
return torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1)


class LastAggregator(torch.nn.Module):
def forward(self, msg, index, t, dim_size):
def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
from torch_scatter import scatter_max
_, argmax = scatter_max(t, index, dim=0, dim_size=dim_size)
out = msg.new_zeros((dim_size, msg.size(-1)))
Expand All @@ -192,20 +198,20 @@ def forward(self, msg, index, t, dim_size):


class MeanAggregator(torch.nn.Module):
def forward(self, msg, index, t, dim_size):
def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
return scatter(msg, index, dim=0, dim_size=dim_size, reduce='mean')


class TimeEncoder(torch.nn.Module):
def __init__(self, out_channels):
def __init__(self, out_channels: int):
super().__init__()
self.out_channels = out_channels
self.lin = Linear(1, out_channels)

def reset_parameters(self):
self.lin.reset_parameters()

def forward(self, t):
def forward(self, t: Tensor) -> Tensor:
return self.lin(t.view(-1, 1)).cos()


Expand All @@ -222,7 +228,7 @@ def __init__(self, num_nodes: int, size: int, device=None):

self.reset_state()

def __call__(self, n_id):
def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
neighbors = self.neighbors[n_id]
nodes = n_id.view(-1, 1).repeat(1, self.size)
e_id = self.e_id[n_id]
Expand All @@ -238,8 +244,8 @@ def __call__(self, n_id):

return n_id, torch.stack([neighbors, nodes]), e_id

def insert(self, src, dst):
# Inserts newly encountered interactions into an ever growing
def insert(self, src: Tensor, dst: Tensor):
# Inserts newly encountered interactions into an ever-growing
# (undirected) temporal graph.

# Collect central nodes, their neighbors and the current event ids.
Expand Down