import os
import time

import tensorflow as tf

from functions import *
from parameters import *
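
# The wildcard imports are assumed to supply everything not defined here:
# the helpers get_batches and build_graph, the sorted datasets
# training_sorted and testing_sorted, and the hyperparameters (batch_size,
# threshold, epochs, rnn_size, learning_rate, embedding_size, direction).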


# Train the model with the desired tuning parameters
def train(model, epochs, log_string, saver):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Used to determine when to stop the training early
        testing_loss_summary = []

        # Keep track of which batch iteration is being trained
        iteration = 0
        display_step = 30  # The progress of the training will be displayed after every 30 batches
        stop_early = 0
        stop = 3  # If the batch_loss_testing does not decrease in 3 consecutive checks, stop training
        per_epoch = 3  # Test the model 3 times per epoch
        testing_check = (len(training_sorted) // batch_size // per_epoch) - 1
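        # Illustrative cadence (hypothetical numbers, not from this repo):
        # with 38,400 training pairs and batch_size = 128 there are 300
        # batches per epoch, so testing_check = 300 // 3 - 1 = 99 and the
        # model is evaluated roughly every 99 batches, 3 times per epoch.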

        print()
        print("Training Model: {}".format(log_string))

        train_writer = tf.summary.FileWriter('./logs/1/train/{}'.format(log_string), sess.graph)
        test_writer = tf.summary.FileWriter('./logs/1/test/{}'.format(log_string))
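        # Train and test summaries go to separate log directories so their
        # curves can be compared in one view: tensorboard --logdir ./logs/1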

        for epoch_i in range(1, epochs + 1):
            batch_loss = 0
            batch_time = 0

            for batch_i, (input_batch, target_batch, input_length, target_length) in enumerate(
                    get_batches(training_sorted, batch_size, threshold)):
                start_time = time.time()

                summary, loss, _ = sess.run([model.merged,
                                             model.cost,
                                             model.train_op],
                                            {model.inputs: input_batch,
                                             model.targets: target_batch,
                                             model.inputs_length: input_length,
                                             model.targets_length: target_length,
                                             model.keep_prob: keep_probability})

                batch_loss += loss
                end_time = time.time()
                batch_time += end_time - start_time

                # Record the progress of training
                train_writer.add_summary(summary, iteration)
                iteration += 1

                if batch_i % display_step == 0 and batch_i > 0:
                    print('Epoch {:>3}/{} Batch {:>4}/{} - Loss: {:>6.3f}, Seconds: {:>4.2f}'
                          .format(epoch_i,
                                  epochs,
                                  batch_i,
                                  len(training_sorted) // batch_size,
                                  batch_loss / display_step,
                                  batch_time))
                    batch_loss = 0
                    batch_time = 0

                #### Testing ####
                if batch_i % testing_check == 0 and batch_i > 0:
                    batch_loss_testing = 0
                    batch_time_testing = 0

                    # Use a separate loop variable here: reusing batch_i would
                    # clobber the outer training loop's counter
                    for batch_i_testing, (input_batch, target_batch, input_length, target_length) in enumerate(
                            get_batches(testing_sorted, batch_size, threshold)):
                        start_time_testing = time.time()

                        summary, loss = sess.run([model.merged,
                                                  model.cost],
                                                 {model.inputs: input_batch,
                                                  model.targets: target_batch,
                                                  model.inputs_length: input_length,
                                                  model.targets_length: target_length,
                                                  model.keep_prob: 1})  # disable dropout when evaluating

                        batch_loss_testing += loss
                        end_time_testing = time.time()
                        batch_time_testing += end_time_testing - start_time_testing

                        # Record the progress of testing
                        test_writer.add_summary(summary, iteration)

                    n_batches_testing = batch_i_testing + 1
                    print('Testing Loss: {:>6.3f}, Seconds: {:>4.2f}'
                          .format(batch_loss_testing / n_batches_testing,
                                  batch_time_testing))

                    batch_time_testing = 0

                    # If the batch_loss_testing is at a new minimum, save the model
                    testing_loss_summary.append(batch_loss_testing)
                    if batch_loss_testing <= min(testing_loss_summary):
                        print('New Record!')
                        stop_early = 0
                        checkpoint = "./{}.ckpt".format(log_string)
                        print(checkpoint)
                        # Saving the model, e.g. at ./kp=0.75,nl=3,th=0.75.ckpt
                        save_path = saver.save(sess, checkpoint)
                        print("Done Saving at ", save_path)
                        print(os.getcwd())
                    else:
                        print("No Improvement.")
                        stop_early += 1
                        if stop_early == stop:
                            break

            if stop_early == stop:
                print("Stopping Training.")
                break
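

# Grid search over the tuning parameters. Each list currently holds a single
# value; adding more values trains one model per combination, each logged and
# checkpointed under its own log_string.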
for keep_probability in [0.75]:
    for num_layers in [3]:
        for threshold in [0.75]:
            log_string = 'kp={},nl={},th={}'.format(keep_probability,
                                                    num_layers,
                                                    threshold)
            model, saver = build_graph(keep_probability, rnn_size, num_layers, batch_size,
                                       learning_rate,
                                       embedding_size,
                                       direction)
            train(model, epochs, log_string, saver)
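
# A minimal sketch (not part of the original script) of restoring the best
# checkpoint later for inference. It assumes the same parameters module is in
# scope so the graph can be rebuilt identically, and that the grid above was
# left at its single setting, giving the path "./kp=0.75,nl=3,th=0.75.ckpt":
#
#   model, saver = build_graph(keep_probability, rnn_size, num_layers,
#                              batch_size, learning_rate, embedding_size,
#                              direction)
#   with tf.Session() as sess:
#       saver.restore(sess, "./kp=0.75,nl=3,th=0.75.ckpt")
#       # sess is now ready to run the restored model's ops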