diff --git a/examples/linkproppred/tgbl-coin/tgn.py b/examples/linkproppred/tgbl-coin/tgn.py
index 22c6ec98..e100618f 100644
--- a/examples/linkproppred/tgbl-coin/tgn.py
+++ b/examples/linkproppred/tgbl-coin/tgn.py
@@ -186,12 +186,14 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "tgbl-coin"
+
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-coin"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -232,41 +234,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
 
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
@@ -292,6 +259,42 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/tgbl-comment/tgn.py b/examples/linkproppred/tgbl-comment/tgn.py
index b077534d..9f7a28ce 100644
--- a/examples/linkproppred/tgbl-comment/tgn.py
+++ b/examples/linkproppred/tgbl-comment/tgn.py
@@ -186,12 +186,13 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "tgbl-comment"
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-comment"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -232,42 +233,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
-
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
 print("==========================================================")
@@ -292,6 +257,42 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/tgbl-flight/tgn.py b/examples/linkproppred/tgbl-flight/tgn.py
index 6772d48e..b1aa621a 100644
--- a/examples/linkproppred/tgbl-flight/tgn.py
+++ b/examples/linkproppred/tgbl-flight/tgn.py
@@ -186,12 +186,13 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "tgbl-flight"
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-flight"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -232,41 +233,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
 
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
@@ -292,6 +258,41 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/tgbl-lastfm/tgn.py b/examples/linkproppred/tgbl-lastfm/tgn.py
index a31986b9..be1731bf 100644
--- a/examples/linkproppred/tgbl-lastfm/tgn.py
+++ b/examples/linkproppred/tgbl-lastfm/tgn.py
@@ -183,12 +183,14 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "tgbl-lastfm"
+
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-lastfm"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -230,42 +232,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighhorhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
-
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
 print("==========================================================")
@@ -295,6 +261,41 @@ def test(loader, neg_sampler, split_mode):
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
     early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                      tolerance=TOLERANCE, patience=PATIENCE)
+
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
 
     # ==================================================== Train & Validation
     # loading the validation negative samples
diff --git a/examples/linkproppred/tgbl-review/tgn.py b/examples/linkproppred/tgbl-review/tgn.py
index 8791b601..f03504e2 100644
--- a/examples/linkproppred/tgbl-review/tgn.py
+++ b/examples/linkproppred/tgbl-review/tgn.py
@@ -186,12 +186,14 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "tgbl-review"
+
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-review"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -232,42 +234,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
-
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
 print("==========================================================")
@@ -292,6 +258,43 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/tgbl-subreddit/tgn.py b/examples/linkproppred/tgbl-subreddit/tgn.py
index a5c3dc99..5d36ccb8 100644
--- a/examples/linkproppred/tgbl-subreddit/tgn.py
+++ b/examples/linkproppred/tgbl-subreddit/tgn.py
@@ -184,12 +184,13 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "tgbl-subreddit"
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-subreddit"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -231,42 +232,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighhorhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
-
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
 print("==========================================================")
@@ -291,6 +256,42 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/tgbl-wiki/tgn.py b/examples/linkproppred/tgbl-wiki/tgn.py
index 95b04bf4..edadce6a 100644
--- a/examples/linkproppred/tgbl-wiki/tgn.py
+++ b/examples/linkproppred/tgbl-wiki/tgn.py
@@ -185,12 +185,13 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "tgbl-wiki"
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-wiki"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -231,42 +232,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
-
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
 print("==========================================================")
@@ -291,6 +256,42 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/tgbl-wiki/tgn_mem.py b/examples/linkproppred/tgbl-wiki/tgn_mem.py
deleted file mode 100644
index 66b5236d..00000000
--- a/examples/linkproppred/tgbl-wiki/tgn_mem.py
+++ /dev/null
@@ -1,371 +0,0 @@
-"""
-Dynamic Link Prediction with a TGN model with Early Stopping
-Reference:
-    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
-
-command for an example run:
-    python examples/linkproppred/tgbl-wiki/tgn.py --data "tgbl-wiki" --num_run 1 --seed 1
-"""
-
-import math
-import timeit
-
-import os
-import os.path as osp
-from pathlib import Path
-import numpy as np
-
-import torch
-from sklearn.metrics import average_precision_score, roc_auc_score
-from torch.nn import Linear
-
-from torch_geometric.datasets import JODIEDataset
-from torch_geometric.loader import TemporalDataLoader
-
-from torch_geometric.nn import TransformerConv
-
-# internal imports
-from tgb.utils.utils import get_args, set_random_seed, save_results
-from tgb.linkproppred.evaluate import Evaluator
-from modules.decoder import LinkPredictor
-from modules.emb_module import GraphAttentionEmbedding
-from modules.msg_func import IdentityMessage
-from modules.msg_agg import LastAggregator
-from modules.neighbor_loader import LastNeighborLoader
-from modules.memory_module import TGNMemory
-from modules.early_stopping import EarlyStopMonitor
-from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
-
-
-# ==========
-# ========== Define helper function...
-# ==========
-
-def train():
-    r"""
-    Training procedure for TGN model
-    This function uses some objects that are globally defined in the current scrips
-
-    Parameters:
-        None
-    Returns:
-        None
-
-    """
-
-    model['memory'].train()
-    model['gnn'].train()
-    model['link_pred'].train()
-
-    model['memory'].reset_state()  # Start with a fresh memory.
-    neighbor_loader.reset_state()  # Start with an empty graph.
-
-    total_loss = 0
-    for batch in train_loader:
-        batch = batch.to(device)
-        optimizer.zero_grad()
-
-        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
-
-        # Sample negative destination nodes.
-        neg_dst = torch.randint(
-            min_dst_idx,
-            max_dst_idx + 1,
-            (src.size(0),),
-            dtype=torch.long,
-            device=device,
-        )
-
-        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
-        n_id, edge_index, e_id = neighbor_loader(n_id)
-        assoc[n_id] = torch.arange(n_id.size(0), device=device)
-
-        # Get updated memory of all nodes involved in the computation.
-        z, last_update = model['memory'](n_id)
-        z = model['gnn'](
-            z,
-            last_update,
-            edge_index,
-            data.t[e_id].to(device),
-            data.msg[e_id].to(device),
-        )
-
-        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
-        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])
-
-        loss = criterion(pos_out, torch.ones_like(pos_out))
-        loss += criterion(neg_out, torch.zeros_like(neg_out))
-
-        # Update memory and neighbor loader with ground-truth state.
-        model['memory'].update_state(src, pos_dst, t, msg)
-        neighbor_loader.insert(src, pos_dst)
-
-        loss.backward()
-        optimizer.step()
-        model['memory'].detach()
-        total_loss += float(loss) * batch.num_events
-
-    return total_loss / train_data.num_events
-
-
-@torch.no_grad()
-def test(loader, neg_sampler, split_mode):
-    r"""
-    Evaluated the dynamic link prediction
-    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
-
-    Parameters:
-        loader: an object containing positive attributes of the positive edges of the evaluation set
-        neg_sampler: an object that gives the negative edges corresponding to each positive edge
-        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
-    Returns:
-        perf_metric: the result of the performance evaluation
-    """
-    model['memory'].eval()
-    model['gnn'].eval()
-    model['link_pred'].eval()
-
-    perf_list = []
-
-    for pos_batch in loader:
-        pos_src, pos_dst, pos_t, pos_msg = (
-            pos_batch.src,
-            pos_batch.dst,
-            pos_batch.t,
-            pos_batch.msg,
-        )
-
-        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
-
-        for idx, neg_batch in enumerate(neg_batch_list):
-            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
-            dst = torch.tensor(
-                np.concatenate(
-                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
-                    axis=0,
-                ),
-                device=device,
-            )
-
-            n_id = torch.cat([src, dst]).unique()
-            n_id, edge_index, e_id = neighbor_loader(n_id)
-            assoc[n_id] = torch.arange(n_id.size(0), device=device)
-
-            # Get updated memory of all nodes involved in the computation.
-            z, last_update = model['memory'](n_id)
-            z = model['gnn'](
-                z,
-                last_update,
-                edge_index,
-                data.t[e_id].to(device),
-                data.msg[e_id].to(device),
-            )
-
-            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
-
-            # compute MRR
-            input_dict = {
-                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
-                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
-                "eval_metric": [metric],
-            }
-            perf_list.append(evaluator.eval(input_dict)[metric])
-
-        # Update memory and neighbor loader with ground-truth state.
-        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
-        neighbor_loader.insert(pos_src, pos_dst)
-
-    perf_metrics = float(torch.tensor(perf_list).mean())
-
-    return perf_metrics
-
-# ==========
-# ==========
-# ==========
-
-
-# Start...
-start_overall = timeit.default_timer()
-
-# ========== set parameters...
-args, _ = get_args()
-print("INFO: Arguments:", args)
-
-DATA = "tgbl-wiki"
-LR = args.lr
-BATCH_SIZE = args.bs
-K_VALUE = args.k_value
-NUM_EPOCH = args.num_epoch
-SEED = args.seed
-MEM_DIM = args.mem_dim
-TIME_DIM = args.time_dim
-EMB_DIM = args.emb_dim
-TOLERANCE = args.tolerance
-PATIENCE = args.patience
-NUM_RUNS = args.num_run
-NUM_NEIGHBORS = 10
-
-
-MODEL_NAME = 'TGN'
-# ==========
-
-# set the device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# data loading
-dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
-train_mask = dataset.train_mask
-val_mask = dataset.val_mask
-test_mask = dataset.test_mask
-data = dataset.get_TemporalData()
-data = data.to(device)
-metric = dataset.eval_metric
-
-train_data = data[train_mask]
-val_data = data[val_mask]
-test_data = data[test_mask]
-
-train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
-val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
-test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
-
-# Ensure to only sample actual destination nodes as negatives.
-min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
-
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
-
-print("==========================================================")
-print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
-print("==========================================================")
-
-evaluator = Evaluator(name=DATA)
-neg_sampler = dataset.negative_sampler
-
-# for saving the results...
-results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
-if not osp.exists(results_path):
-    os.mkdir(results_path)
-    print('INFO: Create directory {}'.format(results_path))
-Path(results_path).mkdir(parents=True, exist_ok=True)
-results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
-
-for run_idx in range(NUM_RUNS):
-    print('-------------------------------------------------------------------------------')
-    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
-    start_run = timeit.default_timer()
-
-    # set the seed for deterministic results...
-    torch.manual_seed(run_idx + SEED)
-    set_random_seed(run_idx + SEED)
-
-    # define an early stopper
-    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
-    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
-    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
-                                     tolerance=TOLERANCE, patience=PATIENCE)
-
-    # ==================================================== Train & Validation
-    # loading the validation negative samples
-    dataset.load_val_ns()
-
-    val_perf_list = []
-    start_train_val = timeit.default_timer()
-    for epoch in range(1, NUM_EPOCH + 1):
-        # training
-        start_epoch_train = timeit.default_timer()
-        loss = train()
-        print(
-            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
-        )
-
-        #! checking GPU usage
-        free_mem, total_mem = torch.cuda.mem_get_info()
-        print ("--------------GPU memory usage-----------")
-        print ("there are ", free_mem, " free memory")
-        print ("there are ", total_mem, " total available memory")
-        print ("there are ", total_mem - free_mem, " used memory")
-        print ("--------------GPU memory usage-----------")
-
-
-
-
-        # validation
-        start_val = timeit.default_timer()
-        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
-        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
-        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
-        val_perf_list.append(perf_metric_val)
-
-        # check for early stopping
-        if early_stopper.step_check(perf_metric_val, model):
-            break
-
-    train_val_time = timeit.default_timer() - start_train_val
-    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")
-
-    # ==================================================== Test
-    # first, load the best model
-    early_stopper.load_checkpoint(model)
-
-    # loading the test negative samples
-    dataset.load_test_ns()
-
-    # final testing
-    start_test = timeit.default_timer()
-    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
-
-    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
-    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
-    test_time = timeit.default_timer() - start_test
-    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
-
-    save_results({'model': MODEL_NAME,
-                  'data': DATA,
-                  'run': run_idx,
-                  'seed': SEED,
-                  f'val {metric}': val_perf_list,
-                  f'test {metric}': perf_metric_test,
-                  'test_time': test_time,
-                  'tot_train_val_time': train_val_time
-                  },
-    results_filename)
-
-    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
-    print('-------------------------------------------------------------------------------')
-
-print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
-print("==============================================================")
diff --git a/examples/linkproppred/thgl-forum/tgn.py b/examples/linkproppred/thgl-forum/tgn.py
index f81b342e..0758a503 100644
--- a/examples/linkproppred/thgl-forum/tgn.py
+++ b/examples/linkproppred/thgl-forum/tgn.py
@@ -194,11 +194,13 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "thgl-forum"
+
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "thgl-forum"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -273,41 +275,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighhorhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
 
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
@@ -333,6 +300,41 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/thgl-github/tgn.py b/examples/linkproppred/thgl-github/tgn.py
index 7d7e2ee7..bc5fe315 100644
--- a/examples/linkproppred/thgl-github/tgn.py
+++ b/examples/linkproppred/thgl-github/tgn.py
@@ -197,12 +197,14 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "thgl-github"
+
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "thgl-github"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -281,41 +283,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighhorhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
 
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
@@ -345,6 +312,41 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
     # # define an early stopper
     # save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     # save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/thgl-myket/tgn.py b/examples/linkproppred/thgl-myket/tgn.py
index 5fe85d26..48bb8ada 100644
--- a/examples/linkproppred/thgl-myket/tgn.py
+++ b/examples/linkproppred/thgl-myket/tgn.py
@@ -193,12 +193,14 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "thgl-myket"
+
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "thgl-myket"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -274,41 +276,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighhorhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
 
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
@@ -334,6 +301,41 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
diff --git a/examples/linkproppred/thgl-software/tgn.py b/examples/linkproppred/thgl-software/tgn.py
index cb162d8d..fab787e7 100644
--- a/examples/linkproppred/thgl-software/tgn.py
+++ b/examples/linkproppred/thgl-software/tgn.py
@@ -193,12 +193,14 @@ def test(loader, neg_sampler, split_mode):
 # Start...
 start_overall = timeit.default_timer()
 
+DATA = "thgl-software"
+
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "thgl-software"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -273,41 +275,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighhorhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
 
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
@@ -333,6 +300,41 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
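
Note on the pattern these patches apply: every script now (1) pins DATA before argument parsing and overrides args.data, so the printed arguments always match the dataset the script actually loads, and (2) constructs the neighbor loader, TGN memory, GNN, link predictor, and optimizer inside the per-run loop, so each run starts from freshly initialized weights and optimizer state instead of inheriting trained state from the previous run. Below is a minimal sketch of the per-run re-initialization pattern; `build_model()` is a hypothetical stand-in for the TGN-specific setup above, not part of these scripts:

```python
import torch

def build_model(device: torch.device):
    # Fresh parameters and a fresh optimizer on every call: nothing is
    # shared across runs, so run k+1 cannot inherit trained state from run k.
    model = torch.nn.Linear(16, 1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    return model, optimizer

NUM_RUNS, SEED = 3, 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for run_idx in range(NUM_RUNS):
    torch.manual_seed(run_idx + SEED)       # per-run seed, as in the scripts
    model, optimizer = build_model(device)  # re-created inside the loop
    # ... train / validate / test this run, then discard the model ...
```

Had the construction stayed at module level (the removed lines), only the seed would have changed between runs, and every run after the first would have continued from the previous run's trained weights.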