forked from piyush2896/ResNet50-Tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
52 lines (39 loc) · 1.55 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
import numpy as np
from model import ResNet50
from datalab import DataLabTrain, DataLabTest
import os
# Presumably meant to silence TensorFlow's native (C++) log output; '3' is the
# most restrictive level of that variable.
# NOTE(review): tensorflow is already imported above, so messages emitted
# during import may not be affected — confirm whether this should run earlier.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
def train(data_dir='./datasets/train_set/', model_dir='./models',
          learning_rate=1e-3, max_steps=5000, log_every=10, save_every=500):
    """Train the ResNet50 graph with Adam and save periodic checkpoints.

    The graph comes from ``ResNet50()``. ``model_params['input']`` is fed the
    input batches. ``model_params['out']['Z']`` is used as the logits of a mean
    softmax cross-entropy loss against a ``[None, 2]`` float32 label
    placeholder. One Adam update runs per ``(inputs, labels)`` pair yielded by
    ``DataLabTrain(data_dir).generator()``.

    Every argument is optional and defaults to the value this script
    previously hard-coded, so a bare ``train()`` behaves as before. The one
    difference is that existing checkpoint directories no longer cause a
    crash.

    Args:
        data_dir: Dataset directory passed to ``DataLabTrain``.
        model_dir: Parent directory for checkpoints. The checkpoint for batch
            ``k`` is written to ``<model_dir>/model<k>/model.ckpt``.
        learning_rate: Adam learning rate.
        max_steps: Stop after this many batches. ``None`` trains until the
            generator is exhausted. Expected to be a positive int otherwise.
        log_every: Print the batch loss every ``log_every`` batches
            (positive int).
        save_every: Write a checkpoint every ``save_every`` batches
            (positive int).
    """
    # The first return value of ResNet50() is unused: the loss is built
    # directly from the logits tensor.
    _, model_params = ResNet50()
    X = model_params['input']
    Z = model_params['out']['Z']
    # NOTE(review): softmax cross-entropy expects each label row to be a
    # probability distribution (e.g. one-hot) — confirm against DataLabTrain.
    Y_true = tf.placeholder(dtype=tf.float32, shape=[None, 2])
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=Z, labels=Y_true))
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
    saver = tf.train.Saver()

    # The `with` statement closes the session on exit (including on error),
    # so no explicit sess.close() is needed.
    with tf.Session() as sess:
        train_gen = DataLabTrain(data_dir).generator()
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        for step, (X_batch, Y_batch) in enumerate(train_gen, start=1):
            feed = {X: X_batch, Y_true: Y_batch}
            if step % log_every == 0:
                # Fetch the loss only on logging steps.
                # Other steps run just the update op.
                loss_value, _ = sess.run([loss, train_step], feed_dict=feed)
                # `step` counts batches, not epochs.
                # The 'epoch:' label is kept so existing log parsing still works.
                print('epoch:', step, 'loss:', loss_value)
            else:
                sess.run(train_step, feed_dict=feed)
            if step % save_every == 0:
                ckpt_dir = os.path.join(model_dir, 'model' + str(step))
                # exist_ok=True: a directory left by an earlier run must not
                # abort training with FileExistsError.
                os.makedirs(ckpt_dir, exist_ok=True)
                saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'))
            if max_steps is not None and step >= max_steps:
                break
if __name__ == '__main__':
    # Start training only when this file is run as a script, not when it is
    # imported.
    train()