-
Notifications
You must be signed in to change notification settings - Fork 403
/
Copy pathmlp.py
206 lines (161 loc) · 7.22 KB
/
mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
from logger import Logger
class LinkPredictor(torch.nn.Module):
    """MLP that scores a candidate edge from the features of its two endpoints.

    The endpoint feature vectors are combined by element-wise product and
    passed through ``num_layers`` linear layers, with ReLU and dropout between
    consecutive layers; a final sigmoid maps the output into [0, 1].

    Args:
        in_channels: size of each endpoint feature vector.
        hidden_channels: width of every hidden layer (unused when
            ``num_layers == 1``).
        out_channels: size of the output per edge.
        num_layers: total number of linear layers; must be at least 1.
        dropout: dropout probability applied after each hidden layer.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')

        self.lins = torch.nn.ModuleList()
        if num_layers == 1:
            # A single layer maps the input straight to the output; building
            # an input->hidden->output pair here would yield two layers.
            self.lins.append(torch.nn.Linear(in_channels, out_channels))
        else:
            self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.lins.append(
                    torch.nn.Linear(hidden_channels, hidden_channels))
            self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialize the weights of every linear layer in place."""
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        """Return sigmoid scores for the endpoint feature batches x_i, x_j."""
        x = x_i * x_j
        # Hidden layers: linear -> ReLU -> dropout (dropout only in training).
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)
def train(predictor, x, split_edge, optimizer, batch_size):
    """Run one training epoch and return the mean loss per positive edge.

    Positive edges are read from ``split_edge['train']['edge']`` and visited
    in shuffled mini-batches of ``batch_size``. Each batch is paired with an
    equally shaped batch of negative pairs whose node ids are drawn uniformly
    from ``[0, x.size(0))``; no filtering is done, so a sampled negative may
    coincide with a real edge. The loss is an explicit binary cross-entropy on
    the predictor outputs, which are expected to lie in [0, 1].

    Args:
        predictor: module called as ``predictor(x_src, x_dst)``.
        x: node feature matrix, indexed by node id along dim 0.
        split_edge: mapping with ``['train']['edge']`` — presumably an integer
            tensor of shape (num_edges, 2); TODO confirm against the dataset.
        optimizer: optimizer stepping the predictor's parameters.
        batch_size: number of positive edges per step.

    Returns:
        Sum over batches of ``batch_loss * batch_len``, divided by the total
        number of positive edges visited (raises ZeroDivisionError if there
        are none).
    """
    eps = 1e-15  # keeps log() finite when a score saturates at 0 or 1
    device = x.device
    num_nodes = x.size(0)

    predictor.train()
    positives = split_edge['train']['edge'].to(device)
    batches = DataLoader(range(positives.size(0)), batch_size, shuffle=True)

    loss_sum, seen = 0, 0
    for idx in batches:
        optimizer.zero_grad()

        pair = positives[idx].t()
        pos_scores = predictor(x[pair[0]], x[pair[1]])
        pos_term = -torch.log(pos_scores + eps).mean()

        # Uniform random negatives with the same shape as the positive batch.
        fake = torch.randint(0, num_nodes, pair.size(), dtype=torch.long,
                             device=device)
        neg_scores = predictor(x[fake[0]], x[fake[1]])
        neg_term = -torch.log(1 - neg_scores + eps).mean()

        batch_loss = pos_term + neg_term
        batch_loss.backward()
        optimizer.step()

        batch_len = pos_scores.size(0)
        loss_sum += batch_loss.item() * batch_len
        seen += batch_len

    return loss_sum / seen
def _predict(predictor, x, edges, batch_size):
    """Score every edge in ``edges`` in mini-batches of ``batch_size``.

    ``edges`` is indexed as ``edges[perm].t()`` and rows 0/1 of that are used
    as endpoint ids, so it is presumably a (num_edges, 2) id tensor.
    Returns a flat 1-D CPU tensor of all predictor outputs (one score per edge
    when the predictor emits a single column).
    """
    edges = edges.to(x.device)
    preds = []
    for perm in DataLoader(range(edges.size(0)), batch_size):
        edge = edges[perm].t()
        out = predictor(x[edge[0]], x[edge[1]])
        # reshape(-1) rather than squeeze(): squeeze() collapses a final batch
        # holding a single edge to a 0-d tensor, which torch.cat rejects.
        preds.append(out.reshape(-1).cpu())
    return torch.cat(preds, dim=0)


@torch.no_grad()
def test(predictor, x, split_edge, evaluator, batch_size):
    """Evaluate ``predictor`` with Hits@K for K in 10, 50 and 100.

    Positive edges of the train/valid/test splits and the negative edges of
    the valid/test splits (``'edge'`` / ``'edge_neg'`` keys of ``split_edge``)
    are scored in mini-batches. Before each ``evaluator.eval`` call,
    ``evaluator.K`` is set to the current K, and the ``f'hits@{K}'`` entry of
    the returned dict is used.

    Returns:
        dict mapping ``'Hits@K'`` to a ``(train, valid, test)`` tuple.
    """
    predictor.eval()

    pos_train_pred = _predict(predictor, x, split_edge['train']['edge'],
                              batch_size)
    pos_valid_pred = _predict(predictor, x, split_edge['valid']['edge'],
                              batch_size)
    neg_valid_pred = _predict(predictor, x, split_edge['valid']['edge_neg'],
                              batch_size)
    pos_test_pred = _predict(predictor, x, split_edge['test']['edge'],
                             batch_size)
    neg_test_pred = _predict(predictor, x, split_edge['test']['edge_neg'],
                             batch_size)

    results = {}
    for K in [10, 50, 100]:
        evaluator.K = K
        # No train negatives are read from split_edge, so train Hits@K is
        # measured against the validation negatives.
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results


# The name ``test`` matches pytest's default collection pattern; opt out so
# importing this module from a test file does not turn it into a test case.
test.__test__ = False
def main():
    """Command-line entry point for the ogbl-collab MLP baseline.

    Parses the CLI options and builds node features: ``data.x``, optionally
    concatenated with vectors loaded from ``embedding.pt``. For ``--runs``
    runs it re-initializes the predictor and trains it for ``--epochs``
    epochs. Every ``--eval_steps`` epochs it records Hits@10/50/100 into one
    ``Logger`` per metric, and prints them when the epoch is also a multiple
    of ``--log_steps``. Per-run statistics are printed after each run and
    overall statistics at the end.
    """
    parser = argparse.ArgumentParser(description='OGBL-COLLAB (MLP)')
    add = parser.add_argument
    add('--device', type=int, default=0)
    add('--log_steps', type=int, default=1)
    add('--use_node_embedding', action='store_true')
    add('--num_layers', type=int, default=3)
    add('--hidden_channels', type=int, default=256)
    add('--dropout', type=float, default=0.0)
    add('--batch_size', type=int, default=64 * 1024)
    add('--lr', type=float, default=0.01)
    add('--epochs', type=int, default=200)
    add('--eval_steps', type=int, default=1)
    add('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{args.device}' if use_cuda else 'cpu')

    dataset = PygLinkPropPredDataset(name='ogbl-collab')
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    features = data.x
    if args.use_node_embedding:
        # NOTE(review): torch.load unpickles the file, so only point this at
        # an embedding.pt you generated yourself.
        extra = torch.load('embedding.pt', map_location='cpu')
        features = torch.cat([features, extra], dim=-1)
    features = features.to(device)

    predictor = LinkPredictor(features.size(-1), args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-collab')
    loggers = {f'Hits@{k}': Logger(args.runs, args) for k in (10, 50, 100)}

    for run in range(args.runs):
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(predictor.parameters(), lr=args.lr)

        for epoch in range(1, args.epochs + 1):
            loss = train(predictor, features, split_edge, optimizer,
                         args.batch_size)
            if epoch % args.eval_steps != 0:
                continue

            results = test(predictor, features, split_edge, evaluator,
                           args.batch_size)
            for key, result in results.items():
                loggers[key].add_result(run, result)

            if epoch % args.log_steps != 0:
                continue
            for key, (train_hits, valid_hits, test_hits) in results.items():
                print(key)
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {100 * train_hits:.2f}%, '
                      f'Valid: {100 * valid_hits:.2f}%, '
                      f'Test: {100 * test_hits:.2f}%')
            print('---')

        # Per-run summary for each metric.
        for key, logger in loggers.items():
            print(key)
            logger.print_statistics(run)

    # Aggregate summary across all runs.
    for key, logger in loggers.items():
        print(key)
        logger.print_statistics()


if __name__ == "__main__":
    main()