-
Notifications
You must be signed in to change notification settings - Fork 0
/
freeze.py
74 lines (58 loc) · 2.74 KB
/
freeze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os, shutil
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
def _resolve_checkpoint(model_folder):
    """Return the checkpoint prefix to restore from for *model_folder*.

    If ``model_folder`` holds TensorFlow checkpoint state, the checkpoint path
    recorded there is used. Otherwise ``model_folder`` itself is treated as a
    checkpoint prefix (so ``model_folder + '.meta'`` must exist).
    """
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    # get_checkpoint_state returns None when no checkpoint state is found.
    # Test for that explicitly instead of using a bare ``except``, which would
    # also hide unrelated errors (typos, KeyboardInterrupt, ...).
    if checkpoint is not None and checkpoint.model_checkpoint_path:
        input_checkpoint = checkpoint.model_checkpoint_path
        print("[INFO] input_checkpoint:", input_checkpoint)
        return input_checkpoint
    print("[INFO] Model folder", model_folder)
    return model_folder


def freeze_graph(model_folder, output_folder):
    """Export a trained checkpoint as a TensorFlow SavedModel.

    Despite the name, this does not write a frozen GraphDef. It imports the
    meta graph, restores the variables from the checkpoint, and saves both
    with ``SavedModelBuilder`` under the SERVING tag. Two signatures are
    exposed:

    * ``"predict_emotions"``: a predict signature with input key ``"mfcc"``
      and output key ``"emotions"``.
    * the default serving signature: a classify signature.

    Args:
        model_folder: Directory containing checkpoint state, or a checkpoint
            prefix whose ``<prefix>.meta`` file exists.
        output_folder: Export directory for the SavedModel. It should not
            already exist; SavedModelBuilder refuses to overwrite an export.
    """
    input_checkpoint = _resolve_checkpoint(model_folder)

    # clear_devices=True drops device placements recorded at training time so
    # TensorFlow can choose where to place ops when the graph is loaded.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        # Restore the trained weights into the imported graph.
        saver.restore(sess, input_checkpoint)

        # Tensor names come from the training graph. This assumes the model
        # was built with these exact layer names; verify against the training
        # code if the export fails here.
        input_tensor = sess.graph.get_tensor_by_name("InputData/X:0")
        output_tensor = sess.graph.get_tensor_by_name("FullyConnected/Softmax:0")

        # Build each TensorInfo once and share it between both signatures.
        tensor_info_x = tf.saved_model.utils.build_tensor_info(input_tensor)
        tensor_info_y = tf.saved_model.utils.build_tensor_info(output_tensor)

        builder = tf.saved_model.builder.SavedModelBuilder(output_folder)

        # Predict signature with descriptive keys.
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={"mfcc": tensor_info_x},
            outputs={"emotions": tensor_info_y},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
        )

        # NOTE(review): the output tensor (".../Softmax") presumably holds float
        # probabilities, yet it is published under CLASSIFY_OUTPUT_CLASSES.
        # The Classify API expects string classes there, and expects
        # serialized tf.Example inputs. Kept unchanged so existing clients of
        # the default signature keep working. Confirm before switching the
        # key to CLASSIFY_OUTPUT_SCORES.
        classification_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={tf.saved_model.signature_constants.CLASSIFY_INPUTS: tensor_info_x},
            outputs={
                tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES: tensor_info_y
            },
            method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME,
        )

        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                "predict_emotions": prediction_signature,
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: classification_signature,
            },
            # Op executed when the SavedModel is loaded: initialize any tables.
            main_op=tf.tables_initializer(),
        )
        builder.save()

    # Report where the SavedModel was written.
    print("[INFO] output_graph:", output_folder)
    print("[INFO] all done")
if __name__ == '__main__':
    # Both folders are resolved relative to the current working directory:
    # 'model' holds the checkpoint, and 'output' receives the SavedModel.
    source_dir, export_dir = (os.path.abspath(name) for name in ('model', 'output'))

    # Remove any previous export so the builder starts from a fresh directory.
    if os.path.exists(export_dir):
        shutil.rmtree(export_dir)

    freeze_graph(source_dir, export_dir)