-
Notifications
You must be signed in to change notification settings - Fork 19
/
main.py
49 lines (38 loc) · 1.37 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Between-class Learning for Image Classification.
Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada
"""
import sys
import os
import chainer
import opts
import models
import dataset
from train import Trainer
def main():
    """Parse command-line options, activate the chosen GPU, and run all trials.

    Trials are numbered 1 through ``opt.nTrials`` inclusive; for each one a
    banner is printed and :func:`train` is called with the options and the
    trial number.
    """
    opt = opts.parse()
    # Make the GPU with id ``opt.gpu`` the current device before any model
    # is created or moved to the GPU.
    chainer.cuda.get_device_from_id(opt.gpu).use()

    for trial_no in range(1, opt.nTrials + 1):
        banner = '+-- Trial {} --+'.format(trial_no)
        print(banner)
        train(opt, trial_no)
def train(opt, trial):
    """Train one model from scratch and optionally save its final weights.

    Builds the network class named by ``opt.netType``, optimizes it with
    Chainer's Nesterov accelerated gradient (NesterovAG) plus a weight-decay
    hook, and runs ``opt.nEpochs`` epochs. After each epoch it prints one line
    to stdout with the current learning rate, training loss / top-1 and
    validation top-1.

    Args:
        opt: Parsed options object (as returned by ``opts.parse()``).
            Attributes read here: ``netType``, ``nClasses``, ``LR``,
            ``momentum``, ``weightDecay``, ``nEpochs`` and ``save``. The whole
            object is also passed to ``dataset.setup`` and ``Trainer``.
        trial: Trial number; used only to name the saved weights file
            ``model_trial{trial}.npz``.
    """
    model = getattr(models, opt.netType)(opt.nClasses)
    model.to_gpu()

    optimizer = chainer.optimizers.NesterovAG(lr=opt.LR, momentum=opt.momentum)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(opt.weightDecay))

    train_iter, val_iter = dataset.setup(opt)
    trainer = Trainer(model, optimizer, train_iter, val_iter, opt)

    for epoch in range(1, opt.nEpochs + 1):
        train_loss, train_top1 = trainer.train(epoch)
        val_top1 = trainer.val()
        # Carriage return + ANSI "erase line" on stderr. Presumably this wipes
        # an in-place progress display written by Trainer (TODO confirm).
        sys.stderr.write('\r\033[K')
        sys.stdout.write(
            '| Epoch: {}/{} | Train: LR {} Loss {:.3f} top1 {:.2f} | Val: top1 {:.2f}\n'.format(
                epoch, opt.nEpochs, trainer.optimizer.lr, train_loss, train_top1, val_top1))
        sys.stdout.flush()

    # The option parser uses the string 'None' as the "don't save" sentinel.
    # Also accept a real None so a programmatically built ``opt`` doesn't
    # crash inside os.path.join.
    if opt.save is not None and opt.save != 'None':
        chainer.serializers.save_npz(
            os.path.join(opt.save, 'model_trial{}.npz'.format(trial)), model)
if __name__ == '__main__':
    # Run training only when this file is executed as a script, not when it
    # is imported as a module.
    main()