
added DyRep #27

Merged: 10 commits, merged on May 12, 2023
Changes from 5 commits
149 changes: 58 additions & 91 deletions in examples/edgeprediction/dyrep_LP.py
@@ -4,54 +4,35 @@
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py

Date:
- Apr. 21, 2023
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
"""

import os.path as osp
import numpy as np

import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
# from torch_geometric.nn import TGNMemory, TransformerConv
# from torch_geometric.nn import TGNMemory
from torch_geometric.nn import TransformerConv
from torch_geometric.nn.models.tgn import (
IdentityMessage,
LastAggregator,
LastNeighborLoader,
)
import math
import time

# internal imports
from models.tgn_dyrep import TGNMemory
from edgepred_utils import *



overall_start = time.time()

seed = 42
LR = 0.0001
batch_size = 200
K = 10 # for computing metrics@k
n_epoch = 1

memory_dim = time_dim = embedding_dim = 100


# set random seed for predictable experiments
set_random_seed(seed)

# set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

# data loading
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE')
dataset = JODIEDataset(path, name='wikipedia')
data = dataset[0]
@@ -60,20 +41,30 @@
# expensive memory transfer costs for mini-batches:
data = data.to(device)

# set the global parameters
LR = 0.0001
batch_size = 200
n_epoch = 50
seed = 123
memory_dim = time_dim = embedding_dim = 100

# set the seed
set_random_seed(seed)


# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# split the data
train_data, val_data, test_data = data.train_val_test_split(
val_ratio=0.15, test_ratio=0.15)

train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)



Owner: Should we move GraphAttentionEmbedding to the models folder as well, since it will be used in multiple example datasets?
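One way that refactor could look (a rough sketch only: the module path models/graph_embedding.py is hypothetical and not part of this PR; the class body mirrors the GraphAttentionEmbedding defined below, as in the upstream tgn.py example):

# models/graph_embedding.py (hypothetical location)
import torch
from torch_geometric.nn import TransformerConv


class GraphAttentionEmbedding(torch.nn.Module):
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels
        # Two-head attention over the sampled temporal neighborhood.
        self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,
                                    dropout=0.1, edge_dim=edge_dim)

    def forward(self, x, last_update, edge_index, t, msg):
        # Encode the time elapsed since each source node's last update and
        # concatenate it with the edge messages as edge features.
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
        return self.conv(x, edge_index, edge_attr)

Each example script would then import it instead of redefining it, e.g. from models.graph_embedding import GraphAttentionEmbedding.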

class GraphAttentionEmbedding(torch.nn.Module):
def __init__(self, in_channels, out_channels, msg_dim, time_enc):
super().__init__()
@@ -99,7 +90,7 @@ def __init__(self, in_channels):
def forward(self, z_src, z_dst):
h = self.lin_src(z_src) + self.lin_dst(z_dst)
h = h.relu()
return self.lin_final(h)
return self.lin_final(h).sigmoid()
Owner: A better approach might be to specify the non-linearity in the layer's arguments, e.g. act="sigmoid", and if act == "sigmoid", return self.lin_final(h).sigmoid().
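A minimal sketch of that suggestion (the act argument is hypothetical and not part of this PR; it assumes the imports already in this file, torch and torch.nn.Linear):

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, act=None):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)
        self.act = act  # None keeps raw logits; "sigmoid" squashes to [0, 1]

    def forward(self, z_src, z_dst):
        h = (self.lin_src(z_src) + self.lin_dst(z_dst)).relu()
        out = self.lin_final(h)
        # Apply the requested output non-linearity, if any.
        return out.sigmoid() if self.act == "sigmoid" else out

Usage would then stay a one-line change, e.g. link_pred = LinkPredictor(in_channels=embedding_dim, act="sigmoid").to(device).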



memory = TGNMemory(
@@ -109,8 +100,7 @@ def forward(self, z_src, z_dst):
time_dim,
message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
aggregator_module=LastAggregator(),
memory_updater_type='rnn', # for DyRep, the memory updater is an RNNCell.
use_destination_embedding_in_message=True # only for DyRep
memory_updater_type='rnn' # TGN: 'gru', JODIE & DyRep: 'rnn'
).to(device)
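# Background note (illustrative sketch, not code from this PR): for DyRep, the
# memory updater is an RNNCell (vs. TGN's GRUCell) and, per the
# `use_destination_embedding_in_message` flag above, the update message for a
# node is built from the *embedding* of the other endpoint rather than its raw
# memory, roughly
#     msg_i = cat([memory[i], z[j], raw_msg, time_enc(t - last_update[i])])
# which is why train() and test() below pass the freshly computed src/dst
# embeddings into memory.update_state().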

gnn = GraphAttentionEmbedding(
@@ -120,6 +110,7 @@ def forward(self, z_src, z_dst):
time_enc=memory.time_enc,
).to(device)


link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

optimizer = torch.optim.Adam(
@@ -131,6 +122,7 @@ def forward(self, z_src, z_dst):
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)



def train():
memory.train()
gnn.train()
@@ -156,25 +148,21 @@ def train():

# Get updated memory of all nodes involved in the computation.
z, last_update = memory(n_id)
z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
data.msg[e_id].to(device))

# pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
# neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])

# loss = criterion(pos_out, torch.ones_like(pos_out))
# loss += criterion(neg_out, torch.zeros_like(neg_out))

# Update memory and neighbor loader with ground-truth state.
memory.update_state(src, pos_dst, t, msg, z, assoc)

pos_out = link_pred(memory.memory[src], memory.memory[pos_dst])
neg_out = link_pred(memory.memory[src], memory.memory[neg_dst])
pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])

loss = criterion(pos_out, torch.ones_like(pos_out))
loss += criterion(neg_out, torch.zeros_like(neg_out))

# update the neighborhood loader
# Update memory with ground-truth state.
z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
data.msg[e_id].to(device))
src_embedding = z[assoc[src]].detach().clone()
pos_dst_embedding = z[assoc[pos_dst]].detach().clone()
memory.update_state(src, pos_dst, t, msg, src_embedding, pos_dst_embedding)

# Update the neighbor loader
neighbor_loader.insert(src, pos_dst)

loss.backward()
@@ -186,12 +174,12 @@ def train():


@torch.no_grad()
def test_one_pos_vs_one_neg(loader):
def test(loader):
memory.eval()
gnn.eval()
link_pred.eval()

torch.manual_seed(seed) # Ensure deterministic sampling across epochs. Note that negative edges for training are sampled at random.
torch.manual_seed(seed) # Ensure deterministic sampling across epochs.

aps, aucs = [], []
for batch in loader:
@@ -206,9 +194,7 @@ def test_one_pos_vs_one_neg(loader):
assoc[n_id] = torch.arange(n_id.size(0), device=device)

z, last_update = memory(n_id)
z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
data.msg[e_id].to(device))


pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])

@@ -220,57 +206,38 @@ def test_one_pos_vs_one_neg(loader):
aps.append(average_precision_score(y_true, y_pred))
aucs.append(roc_auc_score(y_true, y_pred))

memory.update_state(src, pos_dst, t, msg)
# update the memory
z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
data.msg[e_id].to(device))
src_embedding = z[assoc[src]]
pos_dst_embedding = z[assoc[pos_dst]]
memory.update_state(src, pos_dst, t, msg, src_embedding, pos_dst_embedding)

# update the neighborhood loader
neighbor_loader.insert(src, pos_dst)

perf_metrics = {'ap': float(torch.tensor(aps).mean()),
'auc': float(torch.tensor(aucs).mean()),
}

return perf_metrics



return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())



# Train & Validation
print("INFO: =========================================")
print("INFO: ===========*** DyRep model ***===========")
print("INFO: =========================================")
print("=============================================")
print("=============*** DyRep model ***=============")
print("=============================================")

for epoch in range(n_epoch):
# =========== Train & Validation
for epoch in range(1, n_epoch + 1):
start_epoch_train = time.time()
loss = train()
end_epoch_train = time.time()
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Elapsed Time (s): {end_epoch_train - start_epoch_train: .4f}')
val_perf_metrics = test_one_pos_vs_one_neg(val_loader)
val_ap, val_auc = val_perf_metrics['ap'], val_perf_metrics['auc']
val_ap, val_auc = test(val_loader)
print(f'\tVal AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')

# ============
# === TEST ===
# ============

DLP_EVAL_SETUP = 'one_vs_one' # 'one_vs_all': each positive edge vs. all relevant negative edges, 'any_vs_all': any positive edge with the same source vs. all relevant negative edges

start_test = time.time()
if DLP_EVAL_SETUP == 'any_vs_all':
perf_metrics_test = test_exh_any_pos_vs_all(test_loader)
elif DLP_EVAL_SETUP == 'one_vs_all':
perf_metrics_test = test_exh_one_pos_vs_all(test_loader)
elif DLP_EVAL_SETUP == 'one_vs_one':
perf_metrics_test = test_one_pos_vs_one_neg(test_loader)
else:
raise ValueError("Undefined test evaluation setup for dynamic link prediction!!!")

end_test = time.time()
print(f"INFO: Test Evaluation Setup: {DLP_EVAL_SETUP}")
print(f"INFO: >>> K={K}")
for perf_name, perf_value in perf_metrics_test.items():
print(f"\tTest: {perf_name}: {perf_value: .4f}")
print(f'Test: Elapsed Time (s): {end_test - start_test: .4f}')

overall_end = time.time()
print(f'Overall Elapsed Time (s): {overall_end - overall_start: .4f}')
print("INFO: =======================================")
# =========== Test
print("---------------------------------------------")
start_test_time = time.time()
test_ap, test_auc = test(test_loader)
end_test_time = time.time()
print("INFO: Final TEST Performance:")
print(f'\tTest AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}, Elapsed Time (s): {end_test_time - start_test_time: .4f}')
print("=============================================")