Merge pull request #94 from JuliaGast/main
cleaning up tgn training script
shenyangHuang authored Jul 9, 2024
2 parents 1bf3c14 + d49c5ba commit 170f60c
Showing 12 changed files with 420 additions and 772 deletions.
75 changes: 39 additions & 36 deletions examples/linkproppred/tgbl-coin/tgn.py
@@ -186,12 +186,14 @@ def test(loader, neg_sampler, split_mode):
 
 # Start...
 start_overall = timeit.default_timer()
+DATA = "tgbl-coin"
+
 
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-coin"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -232,41 +234,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
 
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
@@ -292,6 +259,42 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
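The change is the same in all three scripts: `DATA` moves above `get_args()` so that `args.data = DATA` is assigned before the arguments are printed, and the neighbor loader, memory, GNN, link predictor, optimizer, loss, and `assoc` buffer move from module level into the per-run loop, so each repetition trains a freshly initialized model rather than continuing from the previous run's weights. A minimal sketch of the resulting loop structure, using illustrative values and a stand-in model rather than the actual TGB modules:

import torch

# Illustrative values; the real scripts read these from get_args().
NUM_RUNS, SEED, LR = 5, 1, 1e-4

for run_idx in range(NUM_RUNS):
    # distinct but reproducible seed per run
    torch.manual_seed(run_idx + SEED)

    # re-created each iteration: parameters and optimizer state start fresh
    model = torch.nn.Linear(16, 1)  # stand-in for memory / gnn / link_pred
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = torch.nn.BCEWithLogitsLoss()

    # ... train and evaluate this run's model ...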
75 changes: 38 additions & 37 deletions examples/linkproppred/tgbl-comment/tgn.py
@@ -186,12 +186,13 @@ def test(loader, neg_sampler, split_mode):
 
 # Start...
 start_overall = timeit.default_timer()
+DATA = "tgbl-comment"
 
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-comment"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -232,42 +233,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
-
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
 print("==========================================================")
@@ -292,6 +257,42 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
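A detail carried over unchanged in the block above: the parameters of the three modules are combined with a set union before being passed to Adam. Since the GNN receives `time_enc=memory.time_enc`, the time encoder's weights appear in both `memory.parameters()` and `gnn.parameters()`; the union hands Adam a single copy instead of a duplicate. A self-contained sketch of the pattern with stand-in modules (the sharing via `shared` mimics `memory.time_enc`):

import torch

# `shared` plays the role of memory.time_enc, which the GNN reuses.
shared = torch.nn.Linear(4, 4)
memory = torch.nn.ModuleDict({'time_enc': shared})
gnn = torch.nn.ModuleDict({'time_enc': shared, 'attn': torch.nn.Linear(4, 4)})
link_pred = torch.nn.Linear(4, 1)

params = set(memory.parameters()) | set(gnn.parameters()) | set(link_pred.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

# The union keeps one copy of the shared weights: 6 unique tensors here,
# where naive concatenation would list 8 (PyTorch warns on duplicates).
print(len(params))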
73 changes: 37 additions & 36 deletions examples/linkproppred/tgbl-flight/tgn.py
@@ -186,12 +186,13 @@ def test(loader, neg_sampler, split_mode):
 
 # Start...
 start_overall = timeit.default_timer()
+DATA = "tgbl-flight"
 
 # ========== set parameters...
 args, _ = get_args()
+args.data = DATA
 print("INFO: Arguments:", args)
 
-DATA = "tgbl-flight"
 LR = args.lr
 BATCH_SIZE = args.bs
 K_VALUE = args.k_value
@@ -232,41 +233,6 @@ def test(loader, neg_sampler, split_mode):
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
-# neighborhood sampler
-neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
-
-# define the model end-to-end
-memory = TGNMemory(
-    data.num_nodes,
-    data.msg.size(-1),
-    MEM_DIM,
-    TIME_DIM,
-    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
-    aggregator_module=LastAggregator(),
-).to(device)
-
-gnn = GraphAttentionEmbedding(
-    in_channels=MEM_DIM,
-    out_channels=EMB_DIM,
-    msg_dim=data.msg.size(-1),
-    time_enc=memory.time_enc,
-).to(device)
-
-link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
-
-model = {'memory': memory,
-         'gnn': gnn,
-         'link_pred': link_pred}
-
-optimizer = torch.optim.Adam(
-    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
-    lr=LR,
-)
-criterion = torch.nn.BCEWithLogitsLoss()
-
-# Helper vector to map global node indices to local ones.
-assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
-
 
 print("==========================================================")
 print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
@@ -292,6 +258,41 @@ def test(loader, neg_sampler, split_mode):
     torch.manual_seed(run_idx + SEED)
     set_random_seed(run_idx + SEED)
 
+    # neighborhood sampler
+    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
+
+    # define the model end-to-end
+    memory = TGNMemory(
+        data.num_nodes,
+        data.msg.size(-1),
+        MEM_DIM,
+        TIME_DIM,
+        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
+        aggregator_module=LastAggregator(),
+    ).to(device)
+
+    gnn = GraphAttentionEmbedding(
+        in_channels=MEM_DIM,
+        out_channels=EMB_DIM,
+        msg_dim=data.msg.size(-1),
+        time_enc=memory.time_enc,
+    ).to(device)
+
+    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
+
+    model = {'memory': memory,
+             'gnn': gnn,
+             'link_pred': link_pred}
+
+    optimizer = torch.optim.Adam(
+        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
+        lr=LR,
+    )
+    criterion = torch.nn.BCEWithLogitsLoss()
+
+    # Helper vector to map global node indices to local ones.
+    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
+
     # define an early stopper
     save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
     save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
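Re-creating the model per run matters for TGN in particular because two of these components are stateful: `TGNMemory` holds per-node memory and message stores, and `LastNeighborLoader` accumulates the temporal neighborhood as edges stream in. Rebuilding them inside the loop guarantees that nothing leaks between runs. A short sketch of those stateful pieces, assuming the PyTorch Geometric implementations these scripts build on (dimensions are illustrative):

from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (IdentityMessage, LastAggregator,
                                           LastNeighborLoader)

num_nodes, msg_dim, mem_dim, time_dim = 100, 4, 8, 8

memory = TGNMemory(
    num_nodes, msg_dim, mem_dim, time_dim,
    message_module=IdentityMessage(msg_dim, mem_dim, time_dim),
    aggregator_module=LastAggregator(),
)
neighbor_loader = LastNeighborLoader(num_nodes, size=10)

# Equivalent hygiene without reconstructing the objects: both expose
# reset_state(), which clears node memory / stored messages and forgets
# previously inserted edges, respectively.
memory.reset_state()
neighbor_loader.reset_state()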
(The remaining 9 changed files are not shown here.)
