-
Notifications
You must be signed in to change notification settings - Fork 2
/
run.py
118 lines (99 loc) · 4.06 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from __future__ import print_function
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from lenet import base_softmax, dropmax
from accumulator import Accumulator
from mnist import mnist_input
import time
import os
import argparse
# Command-line interface. Each entry is (flag, type, default); options are
# registered in this order, which is also the order shown by --help.
_CLI_OPTIONS = (
    ('--mnist_path', str, './mnist'),
    ('--model', str, 'softmax'),      # 'softmax' or 'dropmax'
    ('--N', int, 100),
    ('--batch_size', int, 100),
    ('--n_epochs', int, 200),
    ('--save_freq', int, 20),
    ('--savedir', str, None),         # None -> ./results/<model>
    ('--mode', str, 'train'),         # 'train' or 'test'
    ('--gpu_num', int, 0),
)
parser = argparse.ArgumentParser()
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()
# Process-wide settings for TensorFlow/CUDA. Presumably these need to be in
# place before the first tf.Session touches the GPU -- TODO confirm.
os.environ.update({
    'TF_ENABLE_WINOGRAD_NONFUSED': '1',
    'CUDA_VISIBLE_DEVICES': str(args.gpu_num),
})

# Directory for train.log / test.log and the saved model checkpoint.
if args.savedir is None:
    savedir = './results/%s' % args.model
else:
    savedir = args.savedir
if not os.path.isdir(savedir):
    os.makedirs(savedir)
bs, N = args.batch_size, args.N

# [N]*10 presumably requests N examples for each of the 10 digit classes
# (verify against mnist_input).
xtr, ytr, xva, yva, xte, yte = mnist_input(args.mnist_path, [N] * 10)

# Number of full batches per split; any remainder is dropped. The
# validation count reuses the training formula, i.e. it assumes the
# validation split is also N*10 examples -- NOTE(review): confirm.
n_train_batches = int(N * 10 / bs)
n_val_batches = n_train_batches
n_test_batches = int(10000 / bs)

# Inputs: presumably flattened 28x28 images and one-hot labels over 10 classes.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# Map the --model flag to the network-building function from lenet.
model = {'softmax': base_softmax, 'dropmax': dropmax}.get(args.model)
if model is None:
    raise ValueError('Invalid model %s' % args.model)

# Two graphs on the same placeholders: `net` is built with True (training)
# and `tnet` with False and reuse=True, presumably sharing net's variables.
net = model(x, y, True)
tnet = model(x, y, False, reuse=True)
def train():
    """Train the selected model and save its weights under ``savedir``.

    The loss is cross-entropy plus weight decay. For any model other than
    'softmax' (i.e. dropmax) the 'kl', 'aux' and 'neg_ent' terms are added.
    The loss is minimized with Adam on a three-stage piecewise-constant
    learning rate (1e-3, 1e-4, 1e-5, switching at 1/3 and 2/3 of all steps).

    After every epoch, training and validation cross-entropy/accuracy are
    reported through ``Accumulator.print_`` with ``<savedir>/train.log`` as
    the log file.

    ``net['weights']`` is saved to ``<savedir>/model`` every ``--save_freq``
    epochs (a non-positive value disables this) and once more at the end.
    """
    if args.model == 'softmax':
        loss = net['cent'] + net['wd']
    else:
        loss = net['cent'] + net['wd'] + net['kl'] + net['aux'] + net['neg_ent']

    global_step = tf.train.get_or_create_global_step()
    # Split the total number of optimizer steps into thirds.
    lr_step = int(n_train_batches * args.n_epochs / 3)
    lr = tf.train.piecewise_constant(tf.cast(global_step, tf.int32),
                                     [lr_step, lr_step * 2],
                                     [1e-3, 1e-4, 1e-5])
    train_op = tf.train.AdamOptimizer(lr).minimize(loss, global_step=global_step)
    saver = tf.train.Saver(net['weights'])
    model_path = os.path.join(savedir, 'model')

    train_logger = Accumulator('cent', 'acc')
    train_to_run = [train_op, net['cent'], net['acc']]
    val_logger = Accumulator('cent', 'acc')
    val_to_run = [tnet['cent'], tnet['acc']]

    # Context managers release the session and the (unbuffered, binary) log
    # file even if training raises or is interrupted.
    with tf.Session() as sess:
        with open(os.path.join(savedir, 'train.log'), 'wb', 0) as logfile:
            sess.run(tf.global_variables_initializer())
            for i in range(args.n_epochs):
                epoch = i + 1

                # Shuffle images and labels together so pairs stay aligned.
                xytr = np.concatenate((xtr, ytr), axis=1)
                np.random.shuffle(xytr)
                xtr_, ytr_ = xytr[:, :784], xytr[:, 784:]

                line = 'Epoch %d start, learning rate %f' % (epoch, sess.run(lr))
                print(line)
                logfile.write((line + '\n').encode())

                train_logger.clear()
                start = time.time()
                for j in range(n_train_batches):
                    bx, by = xtr_[j*bs:(j+1)*bs, :], ytr_[j*bs:(j+1)*bs, :]
                    train_logger.accum(sess.run(train_to_run, {x: bx, y: by}))
                train_logger.print_(header='train', epoch=epoch,
                                    time=time.time() - start, logfile=logfile)

                val_logger.clear()
                for j in range(n_val_batches):
                    bx, by = xva[j*bs:(j+1)*bs, :], yva[j*bs:(j+1)*bs, :]
                    val_logger.accum(sess.run(val_to_run, {x: bx, y: by}))
                # Reported time is measured from the start of the epoch, so it
                # includes the training phase as well.
                val_logger.print_(header='val', epoch=epoch,
                                  time=time.time() - start, logfile=logfile)
                print()
                logfile.write(b'\n')

                # Periodic checkpoint, so that --save_freq has an effect. The
                # last epoch is skipped here because the final save below
                # covers it.
                if (args.save_freq > 0 and epoch % args.save_freq == 0
                        and epoch < args.n_epochs):
                    saver.save(sess, model_path)

        saver.save(sess, model_path)
def test():
    """Restore the saved weights into ``tnet`` and evaluate on the test split.

    The whole test set (``xte``/``yte``) is fed in a single ``sess.run``.
    The result is reported through ``Accumulator.print_`` with
    ``<savedir>/test.log`` as the log file.
    """
    saver = tf.train.Saver(tnet['weights'])
    # Context managers guarantee the session and log file are released even
    # if restore or evaluation raises. The log file is opened only after a
    # successful restore, matching the original ordering.
    with tf.Session() as sess:
        saver.restore(sess, os.path.join(savedir, 'model'))
        with open(os.path.join(savedir, 'test.log'), 'wb', 0) as logfile:
            logger = Accumulator('cent', 'acc')
            logger.accum(sess.run([tnet['cent'], tnet['acc']],
                                  {x: xte, y: yte}))
            logger.print_(header='test', logfile=logfile)
if __name__ == '__main__':
    # Dispatch on --mode; anything other than 'train' or 'test' is rejected.
    _modes = {'train': train, 'test': test}
    if args.mode not in _modes:
        raise ValueError('Invalid mode %s' % args.mode)
    _modes[args.mode]()