-
Notifications
You must be signed in to change notification settings - Fork 12
/
Plotter.py
121 lines (88 loc) · 3.43 KB
/
Plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Plotter.py
Visualization functions.
"""
import torch
from torch.autograd.variable import Variable
import networkx as nx
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import datasets
from options import Options
import models
from utils import load_checkpoint
import pdb
__author__ = "Pau Riba"
__email__ = "priba@cvc.uab.cat"
def plot_graph(v, am, outname):
    """Draw a graph with its nodes placed at the given positions and save it.

    Args:
        v: Tensor indexed along dimension 0; row ``k`` is used as the drawing
            position of node ``k``. Components 0 and 1 of the mean position
            are used as the x and y centre of the view, so rows are presumably
            2-D coordinates -- TODO confirm for network outputs.
        am: Tensor with at least three dimensions. Two nodes are connected
            when the absolute values summed over dimension 2 are non-zero.
        outname: Destination passed to ``plt.savefig``.
    """
    fig = plt.figure()
    try:
        # Collapse the edge dimension into a boolean adjacency matrix.
        A = (am.abs().sum(2) > 0).cpu().numpy()

        # ``from_numpy_matrix`` was removed in networkx 3.0; prefer its
        # replacement and fall back to the old name on legacy installs.
        build_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
        g = build_graph(A)

        position = {k: v[k].cpu().numpy() for k in range(v.size(0))}
        coords = np.array(list(position.values()))

        # Square viewport centred on the mean position, sized by the largest
        # deviation along any axis, plus a fixed 0.5 margin.
        center = np.mean(coords, axis=0)
        max_pos = np.max(np.abs(coords - center))

        nx.draw(g, pos=position)
        plt.ylim([center[1] - max_pos - 0.5, center[1] + max_pos + 0.5])
        plt.xlim([center[0] - max_pos - 0.5, center[0] + max_pos + 0.5])
        plt.savefig(outname)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call leaks one figure.
        plt.close(fig)
def plot_dataset(data, net, cuda):
    """Plot each sample of a dataset before and after passing it through ``net``.

    For every sample the identifier is printed, the raw graph is written to
    'original.png', the network is evaluated with ``output='nodes'`` and the
    result is written to 'processed.png'. Both files are overwritten on each
    iteration, so the function waits for Enter before moving on.

    Args:
        data: Dataset supporting ``len()``, ``getId(i)`` and indexing, where
            ``data[i]`` yields a 3-tuple whose first two items are the node
            tensor and the edge tensor.
        net: Callable invoked as ``net(v, am, g_size, output='nodes')``.
        cuda: When truthy, inputs are moved to the GPU before the forward
            pass.
    """
    # ``raw_input`` only exists on Python 2; on Python 3 the equivalent
    # builtin is ``input``.
    try:
        wait_for_user = raw_input
    except NameError:
        wait_for_user = input

    for i in range(len(data)):
        print(data.getId(i))
        v, am, _ = data[i]
        g_size = torch.LongTensor([v.size(0)])

        plot_graph(v, am, 'original.png')

        # Add a leading batch dimension of size 1 for the network.
        v, am = v.unsqueeze(0), am.unsqueeze(0)

        if cuda:
            v, am, g_size = v.cuda(), am.cuda(), g_size.cuda()

        # Compute features without building an autograd graph.
        if hasattr(torch, 'no_grad'):
            # PyTorch >= 0.4 ignores ``volatile``; inference mode has to be
            # requested through the ``no_grad`` context manager instead.
            v, am = Variable(v), Variable(am)
            with torch.no_grad():
                v = net(v, am, g_size, output='nodes')
        else:
            # Legacy PyTorch: ``volatile`` is the only way to disable autograd.
            v, am = Variable(v, volatile=True), Variable(am, volatile=True)
            v = net(v, am, g_size, output='nodes')

        # Drop the batch dimension and unwrap to plain tensors for plotting.
        v, am = v.squeeze(0).data, am.squeeze(0).data

        plot_graph(v, am, 'processed.png')

        wait_for_user("Press Enter to continue...")
def main():
    """Load the datasets and a trained model, then plot every split.

    Reads its configuration from the module-level ``args`` object (dataset,
    data_path, representation, normalization, hidden_size, nlayers, cuda,
    ngpu, load).

    Raises:
        NameError: If ``args.representation`` is neither 'adj' nor 'feat', or
            if ``args.load`` is None.
    """
    print('Prepare dataset')
    # Dataset
    data_train, data_valid, data_test = datasets.load_data(
        args.dataset, args.data_path, args.representation, args.normalization)

    print('Create model')
    if args.representation == 'adj':
        print('\t* Discrete Edges')
        net = models.MpnnGGNN(in_size=2, e=[1],
                              hidden_state_size=args.hidden_size,
                              message_size=args.hidden_size,
                              n_layers=args.nlayers,
                              discrete_edge=True,
                              out_type='regression',
                              target_size=data_train.getTargetSize())
    elif args.representation == 'feat':
        print('\t* Feature Edges')
        net = models.MpnnGGNN(in_size=2, e=2,
                              hidden_state_size=args.hidden_size,
                              message_size=args.hidden_size,
                              n_layers=args.nlayers,
                              discrete_edge=False,
                              out_type='regression',
                              target_size=data_train.getTargetSize())
    else:
        raise NameError('Representation ' + args.representation + ' not implemented!')

    print('Check CUDA')
    if args.cuda and args.ngpu > 1:
        print('\t* Data Parallel **NOT TESTED**')
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    if args.cuda:
        print('\t* CUDA')
        net = net.cuda()

    if args.load is not None:
        print('Loading model')
        checkpoint = load_checkpoint(args.load)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        raise NameError('Load path must be set!')

    # Train, validation and test splits, in that order.
    # The inputs must live on the same device as the model, so the CUDA flag
    # passed on is ``args.cuda`` (the one that decided ``net.cuda()`` above)
    # rather than ``args.ngpu > 0``, which is also true when GPUs were
    # requested but CUDA is unavailable.
    for split in (data_train, data_valid, data_test):
        plot_dataset(split, net, args.cuda)
if __name__ == '__main__':
    # Command-line configuration; kept at module level because main() reads
    # ``args`` as a global.
    args = Options().parse()

    # The GPU is used only when GPUs were requested and CUDA is usable.
    gpu_requested = args.ngpu > 0
    args.cuda = gpu_requested and torch.cuda.is_available()

    # Plotting needs trained weights, so a checkpoint path is mandatory.
    if args.load is not None:
        main()
    else:
        raise Exception('Cannot plot without loading a model.')