-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added DyRep #27
Merged
Merged
added DyRep #27
Changes from 5 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
176708d
added DyRep
3f07a21
added evaluator & negative sampler for link prediction
476a6a0
cleaned the utlity functions
c333700
distinct evaluation setting for TGN-wikipedia
0b2f5d7
remove small redundancies
2aa145a
added one-vs-many negative sampling
b9c6ee4
modularizing
e510116
cleaning and modularizing
43e593f
save&load negative samples from disk
3aa8696
Merge branch 'main' into farimah_tg_models
shenyangHuang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,54 +4,35 @@ | |
- https://github.com/twitter-research/tgn | ||
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py | ||
|
||
Date: | ||
- Apr. 21, 2023 | ||
Spec.: | ||
- Memory Updater: RNN | ||
- Embedding Module: ID | ||
- Message Function: ATTN | ||
""" | ||
|
||
import os.path as osp | ||
import numpy as np | ||
|
||
import torch | ||
from sklearn.metrics import average_precision_score, roc_auc_score | ||
from torch.nn import Linear | ||
|
||
from torch_geometric.datasets import JODIEDataset | ||
from torch_geometric.loader import TemporalDataLoader | ||
# from torch_geometric.nn import TGNMemory, TransformerConv | ||
# from torch_geometric.nn import TGNMemory | ||
from torch_geometric.nn import TransformerConv | ||
from torch_geometric.nn.models.tgn import ( | ||
IdentityMessage, | ||
LastAggregator, | ||
LastNeighborLoader, | ||
) | ||
import math | ||
import time | ||
|
||
# internal imports | ||
from models.tgn_dyrep import TGNMemory | ||
from edgepred_utils import * | ||
|
||
|
||
|
||
overall_start = time.time() | ||
|
||
seed = 42 | ||
LR = 0.0001 | ||
batch_size = 200 | ||
K = 10 # for computing metrics@k | ||
n_epoch = 1 | ||
|
||
memory_dim = time_dim = embedding_dim = 100 | ||
|
||
|
||
# set random seed for predictable experiments | ||
set_random_seed(seed) | ||
|
||
# set the device | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
# device = torch.device('cpu') | ||
|
||
# data loading | ||
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE') | ||
dataset = JODIEDataset(path, name='wikipedia') | ||
data = dataset[0] | ||
|
@@ -60,20 +41,30 @@ | |
# expensive memory transfer costs for mini-batches: | ||
data = data.to(device) | ||
|
||
# set the global parameters | ||
LR = 0.0001 | ||
batch_size = 200 | ||
n_epoch = 50 | ||
seed = 123 | ||
memory_dim = time_dim = embedding_dim = 100 | ||
|
||
# set the seed | ||
set_random_seed(seed) | ||
|
||
|
||
# Ensure to only sample actual destination nodes as negatives. | ||
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max()) | ||
|
||
# split the data | ||
train_data, val_data, test_data = data.train_val_test_split( | ||
val_ratio=0.15, test_ratio=0.15) | ||
|
||
train_loader = TemporalDataLoader(train_data, batch_size=batch_size) | ||
val_loader = TemporalDataLoader(val_data, batch_size=batch_size) | ||
test_loader = TemporalDataLoader(test_data, batch_size=batch_size) | ||
|
||
# neighhorhood sampler | ||
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device) | ||
|
||
|
||
|
||
class GraphAttentionEmbedding(torch.nn.Module): | ||
def __init__(self, in_channels, out_channels, msg_dim, time_enc): | ||
super().__init__() | ||
|
@@ -99,7 +90,7 @@ def __init__(self, in_channels): | |
def forward(self, z_src, z_dst): | ||
h = self.lin_src(z_src) + self.lin_dst(z_dst) | ||
h = h.relu() | ||
return self.lin_final(h) | ||
return self.lin_final(h).sigmoid() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. better approach might be to specify the non-linearity in the arguments for the layer |
||
|
||
|
||
memory = TGNMemory( | ||
|
@@ -109,8 +100,7 @@ def forward(self, z_src, z_dst): | |
time_dim, | ||
message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim), | ||
aggregator_module=LastAggregator(), | ||
memory_updater_type='rnn', # for DyRep, the momory updater is an RNNCell. | ||
use_destination_embedding_in_message=True # only for DyRep | ||
memory_updater_type='rnn' # TGN: 'gru', JODIE & DyRep: 'rnn' | ||
).to(device) | ||
|
||
gnn = GraphAttentionEmbedding( | ||
|
@@ -120,6 +110,7 @@ def forward(self, z_src, z_dst): | |
time_enc=memory.time_enc, | ||
).to(device) | ||
|
||
|
||
link_pred = LinkPredictor(in_channels=embedding_dim).to(device) | ||
|
||
optimizer = torch.optim.Adam( | ||
|
@@ -131,6 +122,7 @@ def forward(self, z_src, z_dst): | |
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device) | ||
|
||
|
||
|
||
def train(): | ||
memory.train() | ||
gnn.train() | ||
|
@@ -156,25 +148,21 @@ def train(): | |
|
||
# Get updated memory of all nodes involved in the computation. | ||
z, last_update = memory(n_id) | ||
z = gnn(z, last_update, edge_index, data.t[e_id].to(device), | ||
data.msg[e_id].to(device)) | ||
|
||
# pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]]) | ||
# neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]]) | ||
|
||
# loss = criterion(pos_out, torch.ones_like(pos_out)) | ||
# loss += criterion(neg_out, torch.zeros_like(neg_out)) | ||
|
||
# Update memory and neighbor loader with ground-truth state. | ||
memory.update_state(src, pos_dst, t, msg, z, assoc) | ||
|
||
pos_out = link_pred(memory.memory[src], memory.memory[pos_dst]) | ||
neg_out = link_pred(memory.memory[src], memory.memory[neg_dst]) | ||
pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]]) | ||
neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]]) | ||
|
||
loss = criterion(pos_out, torch.ones_like(pos_out)) | ||
loss += criterion(neg_out, torch.zeros_like(neg_out)) | ||
|
||
# update the neighborhood loader | ||
# Update memory with ground-truth state. | ||
z = gnn(z, last_update, edge_index, data.t[e_id].to(device), | ||
data.msg[e_id].to(device)) | ||
src_embedding = z[assoc[src]].detach().clone() | ||
pos_dst_embedding = z[assoc[pos_dst]].detach().clone() | ||
memory.update_state(src, pos_dst, t, msg, src_embedding, pos_dst_embedding) | ||
|
||
# Update the neighbor loader | ||
neighbor_loader.insert(src, pos_dst) | ||
|
||
loss.backward() | ||
|
@@ -186,12 +174,12 @@ def train(): | |
|
||
|
||
@torch.no_grad() | ||
def test_one_pos_vs_one_neg(loader): | ||
def test(loader): | ||
memory.eval() | ||
gnn.eval() | ||
link_pred.eval() | ||
|
||
torch.manual_seed(seed) # Ensure deterministic sampling across epochs. Note that random negative edges for training are selected as random. | ||
torch.manual_seed(seed) # Ensure deterministic sampling across epochs. | ||
|
||
aps, aucs = [], [] | ||
for batch in loader: | ||
|
@@ -206,9 +194,7 @@ def test_one_pos_vs_one_neg(loader): | |
assoc[n_id] = torch.arange(n_id.size(0), device=device) | ||
|
||
z, last_update = memory(n_id) | ||
z = gnn(z, last_update, edge_index, data.t[e_id].to(device), | ||
data.msg[e_id].to(device)) | ||
|
||
|
||
pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]]) | ||
neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]]) | ||
|
||
|
@@ -220,57 +206,38 @@ def test_one_pos_vs_one_neg(loader): | |
aps.append(average_precision_score(y_true, y_pred)) | ||
aucs.append(roc_auc_score(y_true, y_pred)) | ||
|
||
memory.update_state(src, pos_dst, t, msg) | ||
# update the memory | ||
z = gnn(z, last_update, edge_index, data.t[e_id].to(device), | ||
data.msg[e_id].to(device)) | ||
src_embedding = z[assoc[src]] | ||
pos_dst_embedding = z[assoc[pos_dst]] | ||
memory.update_state(src, pos_dst, t, msg, src_embedding, pos_dst_embedding) | ||
|
||
# update the neighborhood loader | ||
neighbor_loader.insert(src, pos_dst) | ||
|
||
perf_metrics = {'ap': float(torch.tensor(aps).mean()), | ||
'auc': float(torch.tensor(aucs).mean()), | ||
} | ||
|
||
return perf_metrics | ||
|
||
|
||
|
||
return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean()) | ||
|
||
|
||
|
||
# Train & Validation | ||
print("INFO: =========================================") | ||
print("INFO: ===========*** DyRep model ***===========") | ||
print("INFO: =========================================") | ||
print("=============================================") | ||
print("=============*** DyRep model ***=============") | ||
print("=============================================") | ||
|
||
for epoch in range(n_epoch): | ||
# =========== Train & Validation | ||
for epoch in range(1, n_epoch + 1): | ||
start_epoch_train = time.time() | ||
loss = train() | ||
end_epoch_train = time.time() | ||
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Elapsed Time (s): {end_epoch_train - start_epoch_train: .4f}') | ||
val_perf_metrics = test_one_pos_vs_one_neg(val_loader) | ||
val_ap, val_auc = val_perf_metrics['ap'], val_perf_metrics['auc'] | ||
val_ap, val_auc = test(val_loader) | ||
print(f'\tVal AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}') | ||
|
||
# ============ | ||
# === TEST === | ||
# ============ | ||
|
||
DLP_EVAL_SETUP = 'one_vs_one' # 'one_vs_all': each positive edge vs. all relevant negative edges, 'any_vs_all': any positive edges with the same source vs. all relevant negative edges | ||
|
||
start_test = time.time() | ||
if DLP_EVAL_SETUP == 'any_vs_all': | ||
perf_metrics_test = test_exh_any_pos_vs_all(test_loader) | ||
elif DLP_EVAL_SETUP == 'one_vs_all': | ||
perf_metrics_test = test_exh_one_pos_vs_all(test_loader) | ||
elif DLP_EVAL_SETUP == 'one_vs_one': | ||
perf_metrics_test = test_one_pos_vs_one_neg(test_loader) | ||
else: | ||
raise ValueError("Undefined test evaluation setup for dynamic link prediction!!!") | ||
|
||
end_test = time.time() | ||
print(f"INFO: Test Evaluation Setup: {DLP_EVAL_SETUP}") | ||
print(f"INFO: >>> K={K}") | ||
for perf_name, perf_value in perf_metrics_test.items(): | ||
print(f"\tTest: {perf_name}: {perf_value: .4f}") | ||
print(f'Test: Elapsed Time (s): {end_test - start_test: .4f}') | ||
|
||
overall_end = time.time() | ||
print(f'Overall Elapsed Time (s): {overall_end - overall_start: .4f}') | ||
print("INFO: =======================================") | ||
# =========== Test | ||
print("---------------------------------------------") | ||
start_test_time = time.time() | ||
test_ap, test_auc = test(test_loader) | ||
end_test_time = time.time() | ||
print("INFO: Final TEST Performance:") | ||
print(f'\tTest AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}, Elapsed Time (s): {end_test_time - start_test_time: .4f}') | ||
print("=============================================") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we move
GraphAttentionEmbedding
tomodels
folder as well? Since it will be used in multiple example datasets