-
Notifications
You must be signed in to change notification settings - Fork 2
/
finetune.py
115 lines (87 loc) · 5.44 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import pandas as pd
import os.path as osp
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch.utils.data import DataLoader
from torch_geometric import seed_everything
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected, add_self_loops, scatter
# custom modules
from maskgae.model import MaskGAE, DegreeDecoder, EdgeDecoder, GNNEncoder
from maskgae.mask import MaskEdge, MaskPath
def train_linkpred(model, data, args, device="cpu"):
    """Run the self-supervised link-prediction training loop.

    Builds an Adam optimizer over ``model.parameters()`` (learning rate
    ``args.lr``, weight decay ``args.weight_decay``), then for each of
    ``args.epochs`` epochs puts the model in train mode and calls
    ``model.train_epoch`` once.

    Args:
        model: Object exposing ``parameters()``, ``train()`` and
            ``train_epoch(data, optimizer, alpha=..., batch_size=...)``
            (presumably a ``MaskGAE`` instance -- confirm at the call site).
        data: Graph object with a ``.to(device)`` method (presumably a
            ``torch_geometric.data.Data``).
        args: Namespace providing ``lr``, ``weight_decay``, ``epochs``,
            ``alpha`` and ``batch_size``.
        device: Device that ``data`` is moved to before training.

    Returns:
        list: Whatever ``model.train_epoch`` returned for each epoch, in
        epoch order (empty when ``args.epochs`` is 0).
    """
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    # Move the graph once up front instead of on every epoch.
    # NOTE(review): the model itself is not moved to `device` here; presumably
    # the caller does that -- confirm.
    data = data.to(device)

    losses = []
    for _epoch in tqdm(range(1, 1 + args.epochs)):
        model.train()
        loss = model.train_epoch(
            data,
            optimizer,
            alpha=args.alpha,
            batch_size=args.batch_size,
        )
        # Keep the per-epoch result so callers can inspect the training curve.
        losses.append(loss)
    return losses
# Command-line options for fine-tuning. Help strings use argparse's
# "%(default)s" placeholder so the advertised default always matches the
# actual ``default=`` value (several hard-coded help defaults had drifted).
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="data/icdm2023_session1_test", help="path to the data directory. ")
parser.add_argument("--layer", type=str, default="gcn", help="GNN layer. (default: %(default)s)")
parser.add_argument("--encoder_activation", type=str, default="elu", help="Activation function for GNN encoder. (default: %(default)s)")
parser.add_argument('--encoder_channels', type=int, default=128, help='Channels of GNN encoder layers. (default: %(default)s)')
parser.add_argument('--hidden_channels', type=int, default=64, help='Channels of hidden representation. (default: %(default)s)')
parser.add_argument('--decoder_channels', type=int, default=32, help='Channels of decoder layers. (default: %(default)s)')
parser.add_argument('--encoder_layers', type=int, default=2, help='Number of layers for encoder. (default: %(default)s)')
parser.add_argument('--decoder_layers', type=int, default=2, help='Number of layers for decoders. (default: %(default)s)')
parser.add_argument('--encoder_dropout', type=float, default=0.2, help='Dropout probability of encoder. (default: %(default)s)')
parser.add_argument('--decoder_dropout', type=float, default=0.0, help='Dropout probability of decoder. (default: %(default)s)')
parser.add_argument('--alpha', type=float, default=0., help='loss weight for degree prediction. (default: %(default)s)')
parser.add_argument('--lr', type=float, d