# train.py — training script for the ARC (Attentive Recurrent Comparators)
# binary classifier. (Web-scrape artifacts — page chrome and the rendered
# line-number gutter — removed.)
import os
import argparse
import torch
from torch.autograd import Variable
from datetime import datetime, timedelta
import batcher
from batcher import Batcher
import models
from models import ArcBinaryClassifier
# Command-line configuration for ARC training, defined as a data table so
# every option reads uniformly: (flag, keyword-arguments).
parser = argparse.ArgumentParser()

_CLI_OPTIONS = [
    ('--batchSize', dict(type=int, default=128, help='input batch size')),
    ('--imageSize', dict(type=int, default=32, help='the height / width of the input image to ARC')),
    ('--glimpseSize', dict(type=int, default=8, help='the height / width of glimpse seen by ARC')),
    ('--numStates', dict(type=int, default=128, help='number of hidden states in ARC controller')),
    ('--numGlimpses', dict(type=int, default=6, help='the number glimpses of each image in pair seen by ARC')),
    ('--lr', dict(type=float, default=0.0002, help='learning rate, default=0.0002')),
    ('--cuda', dict(action='store_true', help='enables cuda')),
    ('--name', dict(default=None, help='Custom name for this configuration. Needed for saving'
                                       ' model checkpoints in a separate folder.')),
    ('--load', dict(default=None, help='the model to load from. Start fresh if not specified.')),
]

for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)
def get_pct_accuracy(pred: Variable, target) -> int:
    """Return the percentage (integer 0-100) of correct binary predictions.

    ``pred`` holds per-sample scores in [0, 1] (sigmoid outputs); each is
    thresholded at 0.5 and compared elementwise against ``target`` labels.

    :param pred: 1-D tensor of predicted probabilities.
    :param target: 1-D tensor of ground-truth labels (0/1).
    :return: accuracy as a truncated integer percentage.
    """
    hard_pred = (pred > 0.5).int()
    # .item() extracts the Python scalar from the 0-dim sum; the original
    # `.sum().data[0]` raises on PyTorch >= 0.4 where sums are 0-dimensional.
    correct = (hard_pred == target).sum().item()
    accuracy = float(correct) / target.size()[0]
    accuracy = int(accuracy * 100)
    return accuracy
def train():
    """Run the (unbounded) ARC training loop.

    Parses CLI options, builds an ``ArcBinaryClassifier``, then trains it with
    BCE loss and Adam forever, validating every 10 iterations and checkpointing
    whenever validation loss improves significantly or 10 minutes have passed
    since the last save. Runs until externally interrupted.
    """
    opt = parser.parse_args()

    # Propagate the CUDA flag to the helper modules via their module-level
    # switches (batcher/models read `use_cuda` globally).
    if opt.cuda:
        batcher.use_cuda = True
        models.use_cuda = True

    if opt.name is None:
        # if no name is given, we generate a name from the parameters.
        # only those parameters are taken, which if changed break torch.load compatibility.
        opt.name = "{}_{}_{}_{}".format(opt.numGlimpses, opt.glimpseSize, opt.numStates,
                                        "cuda" if opt.cuda else "cpu")

    print("Will start training {} with parameters:\n{}\n\n".format(opt.name, opt))

    # make directory for storing models.
    models_path = os.path.join("saved_models", opt.name)
    os.makedirs(models_path, exist_ok=True)

    # initialise the model
    discriminator = ArcBinaryClassifier(num_glimpses=opt.numGlimpses,
                                        glimpse_h=opt.glimpseSize,
                                        glimpse_w=opt.glimpseSize,
                                        controller_out=opt.numStates)
    if opt.cuda:
        discriminator.cuda()

    # load from a previous checkpoint, if specified.
    # NOTE(review): `opt.load` is resolved relative to models_path, so only
    # checkpoints saved under the same configuration name can be loaded.
    if opt.load is not None:
        discriminator.load_state_dict(torch.load(os.path.join(models_path, opt.load)))

    # set up the optimizer.
    bce = torch.nn.BCELoss()
    if opt.cuda:
        bce = bce.cuda()
    optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=opt.lr)

    # load the dataset in memory.
    loader = Batcher(batch_size=opt.batchSize, image_size=opt.imageSize)

    # ready to train ...
    best_validation_loss = None
    saving_threshold = 1.02  # save only on a >2% relative improvement
    last_saved = datetime.utcnow()
    save_every = timedelta(minutes=10)  # also save at least every 10 minutes

    i = -1
    while True:
        i += 1

        # X: batch of image pairs, Y: same/different labels — exact shapes
        # are defined by Batcher.fetch_batch; presumably (B, 2, H, W) and (B,).
        # TODO confirm against batcher module.
        X, Y = loader.fetch_batch("train")
        pred = discriminator(X)
        loss = bce(pred, Y.float())

        if i % 10 == 0:
            # validate your model
            X_val, Y_val = loader.fetch_batch("val")
            pred_val = discriminator(X_val)
            loss_val = bce(pred_val, Y_val.float())

            # NOTE(review): `.data[0]` is pre-0.4 PyTorch; on newer versions
            # this would be `.item()`.
            training_loss = loss.data[0]
            validation_loss = loss_val.data[0]

            print("Iteration: {} \t Train: Acc={}%, Loss={} \t\t Validation: Acc={}%, Loss={}".format(
                i, get_pct_accuracy(pred, Y), training_loss, get_pct_accuracy(pred_val, Y_val), validation_loss
            ))

            if best_validation_loss is None:
                best_validation_loss = validation_loss

            # Checkpoint filename is the loss value itself, so each save is a
            # distinct file and the best checkpoint is human-identifiable.
            if best_validation_loss > (saving_threshold * validation_loss):
                print("Significantly improved validation loss from {} --> {}. Saving...".format(
                    best_validation_loss, validation_loss
                ))
                discriminator.save_to_file(os.path.join(models_path, str(validation_loss)))
                best_validation_loss = validation_loss
                last_saved = datetime.utcnow()

            # Periodic fallback save, even without improvement.
            if last_saved + save_every < datetime.utcnow():
                print("It's been too long since we last saved the model. Saving...")
                discriminator.save_to_file(os.path.join(models_path, str(validation_loss)))
                last_saved = datetime.utcnow()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def main() -> None:
    """Script entry point: delegate to the training loop."""
    train()


if __name__ == "__main__":
    main()