-
Notifications
You must be signed in to change notification settings - Fork 69
/
run_DBLP.py
203 lines (184 loc) · 10.3 KB
/
run_DBLP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import argparse
import os
import time

import numpy as np
import torch
import torch.nn.functional as F

from utils.pytorchtools import EarlyStopping
from utils.data import load_DBLP_data
from utils.tools import index_generator, evaluate_results_nc, parse_minibatch
from model import MAGNN_nc_mb
# Params
# Hyperparameters shared by every training run in this script.
out_dim = 4  # classifier output size; also used as num_classes when evaluating
dropout_rate = 0.5
lr = 5e-3  # Adam learning rate
weight_decay = 1e-3  # Adam weight decay

# Edge-type id sequences, one per metapath, handed to MAGNN_nc_mb.
# NOTE(review): presumably these encode the DBLP metapaths A-P-A, A-P-T-P-A and
# A-P-V-P-A; confirm against the edge-type ids produced by load_DBLP_data.
etypes_list = [
    [0, 1],
    [0, 2, 3, 1],
    [0, 4, 5, 1],
]
# Accepted values of ``feats_type`` (see ``_build_features`` for their meaning).
_FEATS_TYPES = (0, 1, 2, 3)


def _sparse_identity(dim, device):
    """Return a ``dim`` x ``dim`` float32 sparse COO identity matrix on ``device``.

    Row ``i`` is the one-hot "id vector" feature of node ``i``.
    """
    diag = np.arange(dim)
    indices = torch.LongTensor(np.vstack((diag, diag)))
    values = torch.FloatTensor(np.ones(dim))
    # torch.sparse_coo_tensor replaces the deprecated legacy torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(indices, values, torch.Size([dim, dim])).to(device)


def _build_features(features_list, feats_type, device):
    """Turn the loaded per-node-type feature arrays into model input tensors.

    Args:
        features_list: one dense feature array per node type, as returned by
            ``load_DBLP_data``. Index 0 is the target node type (the one being
            classified); the other indices are auxiliary node types.
        feats_type: 0 - loaded features for every type;
            1 - loaded features for the target type, 10-dim zero vectors for the others;
            2 - loaded features for the target type, sparse identity (id-vector)
                features for the others;
            3 - sparse identity features for every type.
        device: torch device the tensors are placed on.

    Returns:
        ``(features_list, in_dims)``: the per-type feature tensors and the input
        dimensionality of each node type.

    Raises:
        ValueError: if ``feats_type`` is not one of 0, 1, 2, 3.
    """
    features_list = [torch.FloatTensor(features).to(device) for features in features_list]
    if feats_type == 0:
        in_dims = [features.shape[1] for features in features_list]
    elif feats_type == 1:
        in_dims = [features_list[0].shape[1]] + [10] * (len(features_list) - 1)
        for i in range(1, len(features_list)):
            features_list[i] = torch.zeros((features_list[i].shape[0], 10)).to(device)
    elif feats_type == 2:
        # An id-vector feature has one dimension per node of that type.
        in_dims = [features.shape[0] for features in features_list]
        in_dims[0] = features_list[0].shape[1]
        for i in range(1, len(features_list)):
            features_list[i] = _sparse_identity(features_list[i].shape[0], device)
    elif feats_type == 3:
        in_dims = [features.shape[0] for features in features_list]
        features_list = [_sparse_identity(features.shape[0], device) for features in features_list]
    else:
        raise ValueError('Unsupported feats_type {!r}; expected one of {}'.format(feats_type, _FEATS_TYPES))
    return features_list, in_dims


def _print_summary(svm_macro_f1_lists, svm_micro_f1_lists, nmi_mean_list, nmi_std_list,
                   ari_mean_list, ari_std_list):
    """Print the SVM and K-means test metrics averaged over all repeats.

    Each argument holds one entry per repeat, as collected from ``evaluate_results_nc``.
    """
    # Swap the first two axes so each row of the result groups one train-size setting
    # across all repeats.
    # NOTE(review): assumes each per-repeat SVM result has one row per train size
    # (0.8, 0.6, 0.4, 0.2) and two columns, presumably (mean, std); confirm in utils.tools.
    svm_macro_f1_lists = np.transpose(np.array(svm_macro_f1_lists), (1, 0, 2))
    svm_micro_f1_lists = np.transpose(np.array(svm_micro_f1_lists), (1, 0, 2))
    nmi_mean_list = np.array(nmi_mean_list)
    nmi_std_list = np.array(nmi_std_list)
    ari_mean_list = np.array(ari_mean_list)
    ari_std_list = np.array(ari_std_list)
    print('----------------------------------------------------------------')
    print('SVM tests summary')
    print('Macro-F1: ' + ', '.join(['{:.6f}~{:.6f} ({:.1f})'.format(
        macro_f1[:, 0].mean(), macro_f1[:, 1].mean(), train_size) for macro_f1, train_size in
        zip(svm_macro_f1_lists, [0.8, 0.6, 0.4, 0.2])]))
    print('Micro-F1: ' + ', '.join(['{:.6f}~{:.6f} ({:.1f})'.format(
        micro_f1[:, 0].mean(), micro_f1[:, 1].mean(), train_size) for micro_f1, train_size in
        zip(svm_micro_f1_lists, [0.8, 0.6, 0.4, 0.2])]))
    print('K-means tests summary')
    print('NMI: {:.6f}~{:.6f}'.format(nmi_mean_list.mean(), nmi_std_list.mean()))
    print('ARI: {:.6f}~{:.6f}'.format(ari_mean_list.mean(), ari_std_list.mean()))


def run_model_DBLP(feats_type, hidden_dim, num_heads, attn_vec_dim, rnn_type,
                   num_epochs, patience, batch_size, neighbor_samples, repeat, save_postfix):
    """Train and evaluate MAGNN_nc_mb on DBLP node classification ``repeat`` times.

    Each repeat trains a fresh model with minibatches and early stopping on the
    validation loss. It then reloads the checkpoint written by ``EarlyStopping``
    and embeds the test nodes. The embeddings are scored with
    ``evaluate_results_nc`` (SVM macro/micro-F1, K-means NMI/ARI). A summary
    over all repeats is printed at the end.

    Args:
        feats_type: node-feature scheme, one of 0-3 (see ``_build_features``).
        hidden_dim, num_heads, attn_vec_dim, rnn_type: model hyperparameters
            forwarded to ``MAGNN_nc_mb``.
        num_epochs: maximum number of training epochs per repeat.
        patience: patience forwarded to ``EarlyStopping``.
        batch_size: number of target nodes per minibatch.
        neighbor_samples: neighbour sample count forwarded to ``parse_minibatch``.
        repeat: number of independent train/test runs.
        save_postfix: suffix of the checkpoint file ``checkpoint/checkpoint_<save_postfix>.pt``.

    Raises:
        ValueError: if ``feats_type`` is not 0, 1, 2 or 3. This is checked
            before any data is loaded.
    """
    # Fail fast: otherwise a bad value only surfaces after the dataset is loaded.
    if feats_type not in _FEATS_TYPES:
        raise ValueError('Unsupported feats_type {!r}; expected one of {}'.format(feats_type, _FEATS_TYPES))

    # adjM is returned by the loader but not used by this minibatch pipeline.
    adjlists, edge_metapath_indices_list, features_list, adjM, type_mask, labels, train_val_test_idx = load_DBLP_data()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    features_list, in_dims = _build_features(features_list, feats_type, device)
    labels = torch.LongTensor(labels).to(device)
    train_idx = np.sort(train_val_test_idx['train_idx'])
    val_idx = np.sort(train_val_test_idx['val_idx'])
    test_idx = np.sort(train_val_test_idx['test_idx'])

    checkpoint_path = 'checkpoint/checkpoint_{}.pt'.format(save_postfix)
    # torch.save cannot create missing parent directories, so make sure it exists.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    svm_macro_f1_lists = []
    svm_micro_f1_lists = []
    nmi_mean_list = []
    nmi_std_list = []
    ari_mean_list = []
    ari_std_list = []
    for _ in range(repeat):
        # NOTE(review): 3 and 6 presumably are the number of metapaths (len(etypes_list))
        # and of edge types (ids 0-5 in etypes_list); confirm in model.MAGNN_nc_mb.
        net = MAGNN_nc_mb(3, 6, etypes_list, in_dims, hidden_dim, out_dim, num_heads, attn_vec_dim,
                          rnn_type, dropout_rate)
        net.to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

        # training loop
        early_stopping = EarlyStopping(patience=patience, verbose=True, save_path=checkpoint_path)
        dur1 = []  # minibatch sampling time per iteration
        dur2 = []  # forward + loss time per iteration
        dur3 = []  # backward + optimizer step time per iteration
        train_idx_generator = index_generator(batch_size=batch_size, indices=train_idx)
        # Unshuffled so the concatenated validation outputs line up with labels[val_idx].
        val_idx_generator = index_generator(batch_size=batch_size, indices=val_idx, shuffle=False)
        for epoch in range(num_epochs):
            t_start = time.time()
            # training
            net.train()
            for iteration in range(train_idx_generator.num_iterations()):
                # forward
                t0 = time.time()
                train_idx_batch = train_idx_generator.next()
                train_idx_batch.sort()
                train_g_list, train_indices_list, train_idx_batch_mapped_list = parse_minibatch(
                    adjlists, edge_metapath_indices_list, train_idx_batch, device, neighbor_samples)
                t1 = time.time()
                dur1.append(t1 - t0)

                logits, _ = net(
                    (train_g_list, features_list, type_mask, train_indices_list, train_idx_batch_mapped_list))
                logp = F.log_softmax(logits, 1)
                train_loss = F.nll_loss(logp, labels[train_idx_batch])
                t2 = time.time()
                dur2.append(t2 - t1)

                # autograd
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()
                t3 = time.time()
                dur3.append(t3 - t2)

                # print training info
                if iteration % 50 == 0:
                    print(
                        'Epoch {:05d} | Iteration {:05d} | Train_Loss {:.4f} | Time1(s) {:.4f} | Time2(s) {:.4f} | Time3(s) {:.4f}'.format(
                            epoch, iteration, train_loss.item(), np.mean(dur1), np.mean(dur2), np.mean(dur3)))

            # validation
            net.eval()
            val_logp = []
            with torch.no_grad():
                for iteration in range(val_idx_generator.num_iterations()):
                    val_idx_batch = val_idx_generator.next()
                    val_g_list, val_indices_list, val_idx_batch_mapped_list = parse_minibatch(
                        adjlists, edge_metapath_indices_list, val_idx_batch, device, neighbor_samples)
                    logits, _ = net(
                        (val_g_list, features_list, type_mask, val_indices_list, val_idx_batch_mapped_list))
                    val_logp.append(F.log_softmax(logits, 1))
                val_loss = F.nll_loss(torch.cat(val_logp, 0), labels[val_idx])
            t_end = time.time()
            # print validation info
            print('Epoch {:05d} | Val_Loss {:.4f} | Time(s) {:.4f}'.format(
                epoch, val_loss.item(), t_end - t_start))
            # early stopping
            early_stopping(val_loss, net)
            if early_stopping.early_stop:
                print('Early stopping!')
                break

        # testing with evaluate_results_nc
        test_idx_generator = index_generator(batch_size=batch_size, indices=test_idx, shuffle=False)
        # Reload the weights EarlyStopping saved (presumably the best-validation model).
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        net.eval()
        test_embeddings = []
        with torch.no_grad():
            for iteration in range(test_idx_generator.num_iterations()):
                test_idx_batch = test_idx_generator.next()
                test_g_list, test_indices_list, test_idx_batch_mapped_list = parse_minibatch(
                    adjlists, edge_metapath_indices_list, test_idx_batch, device, neighbor_samples)
                _, embeddings = net(
                    (test_g_list, features_list, type_mask, test_indices_list, test_idx_batch_mapped_list))
                test_embeddings.append(embeddings)
            test_embeddings = torch.cat(test_embeddings, 0)
            svm_macro_f1_list, svm_micro_f1_list, nmi_mean, nmi_std, ari_mean, ari_std = evaluate_results_nc(
                test_embeddings.cpu().numpy(), labels[test_idx].cpu().numpy(), num_classes=out_dim)
        svm_macro_f1_lists.append(svm_macro_f1_list)
        svm_micro_f1_lists.append(svm_micro_f1_list)
        nmi_mean_list.append(nmi_mean)
        nmi_std_list.append(nmi_std)
        ari_mean_list.append(ari_mean)
        ari_std_list.append(ari_std)

    # print out a summary of the evaluations
    _print_summary(svm_macro_f1_lists, svm_micro_f1_lists, nmi_mean_list, nmi_std_list,
                   ari_mean_list, ari_std_list)
if __name__ == '__main__':
ap = argparse.ArgumentParser(description='MRGNN testing for the DBLP dataset')
ap.add_argument('--feats-type', type=int, default=2,
help='Type of the node features used. ' +
'0 - loaded features; ' +
'1 - only target node features (zero vec for others); ' +
'2 - only target node features (id vec for others); ' +
'3 - all id vec. Default is 2.')
ap.add_argument('--hidden-dim', type=int, default=64, help='Dimension of the node hidden state. Default is 64.')
ap.add_argument('--num-heads', type=int, default=8, help='Number of the attention heads. Default is 8.')
ap.add_argument('--attn-vec-dim', type=int, default=128, help='Dimension of the attention vector. Default is 128.')
ap.add_argument('--rnn-type', default='RotatE0', help='Type of the aggregator. Default is RotatE0.')
ap.add_argument('--epoch', type=int, default=100, help='Number of epochs. Default is 100.')
ap.add_argument('--patience', type=int, default=5, help='Patience. Default is 5.')
ap.add_argument('--batch-size', type=int, default=8, help='Batch size. Default is 8.')
ap.add_argument('--samples', type=int, default=100, help='Number of neighbors sampled. Default is 100.')
ap.add_argument('--repeat', type=int, default=1, help='Repeat the training and testing for N times. Default is 1.')
ap.add_argument('--save-postfix', default='DBLP', help='Postfix for the saved model and result. Default is DBLP.')
args = ap.parse_args()
run_model_DBLP(args.feats_type, args.hidden_dim, args.num_heads, args.attn_vec_dim, args.rnn_type,
args.epoch, args.patience, args.batch_size, args.samples, args.repeat, args.save_postfix)