-
Notifications
You must be signed in to change notification settings - Fork 0
/
flow_train.py
124 lines (87 loc) · 3.47 KB
/
flow_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from utils.train_utils import *
from utils.plotting.distributions import *
from utils.plotting.plots import *
from load_data import *
from flow_model import INN
import sys, os
import config_flow as c
import opts
opts.parse(sys.argv)

# Dump every public option of the config module to stdout so a run is
# reproducible from its log alone.
config_str = ""
config_str += "===" * 30 + "\n"
config_str += "Config options:\n\n"
for v in dir(c):
    if v[0] == '_':
        continue  # skip dunder/private attributes of the module
    # getattr replaces the original eval('c.%s' % v): same value, no
    # string execution.
    s = getattr(c, v)
    config_str += " {:25}\t{}\n".format(v, s)
config_str += "===" * 30 + "\n"
print(config_str)
# ---- device, data and model construction --------------------------------
# Prefer the first CUDA device when one is available.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

# Loader returns the train/validation iterators plus dataset metadata.
train_loader, validate_loader, dataset_size, data_shape, scales = Loader(
    c.dataset, c.batch_size, c.test, c.scaler, c.weighted)

# For weighted datasets the final column carries per-event weights rather
# than a feature, so the flow operates on one dimension fewer.
if c.weighted:
    data_shape -= 1

# Build the invertible network and its optimizer.
Flow = INN(num_coupling_layers=c.n_blocks, in_dim=data_shape,
           num_layers=c.n_layers, internal_size=c.n_units)
Flow.define_model_architecture()
Flow.set_optimizer()

# Model summary banner.
banner = "\n" + "===" * 30 + "\n"
print(banner)
print(Flow.model)
n_params = sum(np.prod(p.size()) for p in Flow.params_trainable)
print('Total parameters: %d' % n_params)
print(banner)
# ---- training loop ------------------------------------------------------
# NOTE(review): indentation of this section was reconstructed from a
# whitespace-stripped paste; the nesting of the sampling/plotting part
# under the save-interval check should be confirmed against the original.
try:
    log_dir = c.save_dir
    # One run directory per dataset/epoch-count; os.path.join replaces the
    # original string concatenation, which produced an inconsistent
    # doubled '/' (same directory after normalization on POSIX).
    run_dir = os.path.join(log_dir, c.dataset, 'n_epochs_' + str(c.n_epochs))
    os.makedirs(run_dir, exist_ok=True)

    F_loss_meter = AverageMeter()

    for epoch in range(c.n_epochs):
        for iteration in range(c.n_its_per_epoch):
            i = 0
            for data in train_loader:
                Flow.model.train()
                Flow.optim.zero_grad()

                if c.weighted:
                    # Last column holds per-event weights; the rest are
                    # the features the flow maps to a standard Gaussian.
                    events = data[:, :-1]
                    weights = data[:, -1]
                    gauss_output = Flow.model(events.double())
                    # Weighted negative log-likelihood of the flow:
                    # 0.5*||z||^2 minus the log-Jacobian, both weighted.
                    temp = torch.sum(gauss_output ** 2 / 2, 1)
                    f_loss = torch.mean(weights * temp) - torch.mean(
                        weights * Flow.model.log_jacobian(run_forward=False))
                else:
                    # Unweighted case: rescale inputs to the trained range.
                    events = data / scales
                    gauss_output = Flow.model(events.double())
                    f_loss = torch.mean(gauss_output ** 2 / 2) - torch.mean(
                        Flow.model.log_jacobian(run_forward=False)) / gauss_output.shape[1]

                F_loss_meter.update(f_loss.item())

                f_loss.backward()
                Flow.optim.step()
                i += 1

                if epoch == 0 or epoch % c.show_interval == 0:
                    print_log(epoch, c.n_epochs, i + 1, len(train_loader),
                              Flow.scheduler.optimizer.param_groups[0]['lr'],
                              c.show_interval, F_loss_meter, F_loss_meter,
                              Flow=True)
                # NOTE(review): comparing (epoch + 1) against the number
                # of batches looks like it was meant to be (i + 1) —
                # preserved as-is, confirm before changing.
                elif (epoch + 1) == len(train_loader):
                    print_log(epoch, c.n_epochs, i + 1, len(train_loader),
                              Flow.scheduler.optimizer.param_groups[0]['lr'],
                              (i + 1) % c.show_interval, F_loss_meter,
                              F_loss_meter, Flow=True)

            F_loss_meter.reset()

        # Periodic checkpointing and diagnostic plots.
        if epoch % c.save_interval == 0 or epoch + 1 == c.n_epochs:
            if c.save_model:
                checkpoint_F = {
                    'epoch': epoch,
                    'model': Flow.model.state_dict(),
                    'optimizer': Flow.optim.state_dict(),
                }
                save_checkpoint(checkpoint_F, run_dir,
                                'checkpoint_F_epoch_%03d' % (epoch))

            size = 1000 if c.test else 100000
            with torch.no_grad():
                real = get_real_data(c.dataset, c.test, size)
                noise = torch.randn(size, data_shape).double().to(device)
                # .cpu() before .numpy(): the original called .numpy()
                # directly, which raises when the model runs on CUDA.
                inv = Flow.model(noise, rev=True).detach().cpu().numpy() \
                    .reshape(size, data_shape)
                if not c.weighted:
                    # Undo the input rescaling applied during training.
                    inv = inv * scales
            distributions = Distribution(
                real, inv, 'epoch_%03d' % (epoch) + '_target',
                run_dir, c.dataset)
            distributions.plot()

        Flow.scheduler.step()

# BaseException (not a bare except) so Ctrl-C still triggers the
# emergency checkpoint; the error is always re-raised afterwards.
except BaseException:
    if c.checkpoint_on_error:
        # The original called model.save(...) on an undefined name
        # (NameError on the error path); save through the same helper
        # used for regular checkpoints instead.
        save_checkpoint({'model': Flow.model.state_dict(),
                         'optimizer': Flow.optim.state_dict()},
                        c.save_dir, c.filename + '_ABORT')
    raise