-
Notifications
You must be signed in to change notification settings - Fork 129
/
save_model_mobilenetv2.py
executable file
·52 lines (38 loc) · 1.98 KB
/
save_model_mobilenetv2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#! /usr/bin/env python
# coding=utf-8
import tensorflow as tf
from core.yolov3_mobilenetv2 import YOLOV3
from tensorflow.saved_model import signature_def_utils, signature_constants, tag_constants
from tensorflow.saved_model import utils as save_model_utils
img_size = 608
num_channels = 3

# Build the inference graph at import time. The input is a float32 batch of
# img_size x img_size images with num_channels channels last; under the
# 'input' name scope the placeholder's full op name is 'input/input_data',
# which is the name model_transfer looks up.
with tf.name_scope('input'):
    input_data = tf.placeholder(
        name='input_data',
        dtype=tf.float32,
        shape=[None, img_size, img_size, num_channels],
    )

# `trainable` is handed to YOLOV3 as a boolean tensor rather than a Python
# bool. NOTE(review): presumably this switches off training-only behaviour
# inside the network; confirm in core.yolov3_mobilenetv2.
model = YOLOV3(input_data, trainable=tf.cast(False, dtype=tf.bool))

# Diagnostics: the three raw per-scale conv heads and the variable count.
print(model.conv_sbbox, model.conv_mbbox, model.conv_lbbox)
print("%d trainable variables" % len(tf.trainable_variables()))
def model_transfer(savemodel_file_path, ckpt_file,
                   input_name='input/input_data',
                   output_name='pred_multi_scale/concat'):
    """Restore a checkpoint into the default graph and export it as a SavedModel.

    The exported SavedModel carries the SERVING tag and a single signature
    named "predict". That signature maps "input" to the first output of op
    ``input_name`` and "output" to the first output of op ``output_name``.

    Args:
        savemodel_file_path: Export directory passed to ``SavedModelBuilder``.
            NOTE(review): TF 1.x refuses to export into an existing, non-empty
            directory, so use a fresh path (e.g. a new version number) per
            export.
        ckpt_file: Checkpoint prefix passed to ``tf.train.Saver.restore``.
        input_name: Op whose first output is exported as the model input.
            Defaults to the placeholder created at module level.
        output_name: Op whose first output is exported as the prediction.
    """
    with tf.Session() as sess:
        # A Saver built with no var_list covers all saveable variables in the
        # default graph, i.e. the network constructed at import time.
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_file)

        graph = sess.graph
        x = graph.get_operation_by_name(input_name).outputs[0]
        pred = graph.get_operation_by_name(output_name).outputs[0]

        print("prediction signature")
        prediction_signature = signature_def_utils.build_signature_def(
            inputs={"input": save_model_utils.build_tensor_info(x)},
            outputs={"output": save_model_utils.build_tensor_info(pred)},
            method_name=signature_constants.PREDICT_METHOD_NAME)

        builder = tf.saved_model.builder.SavedModelBuilder(savemodel_file_path)
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.SERVING],
            signature_def_map={"predict": prediction_signature})
        # Write the SavedModel before reporting success, so a failed save is
        # never announced as a success. No explicit sess.close() is needed:
        # the `with` block closes the session.
        builder.save()
        # Message reads: "saved model exported successfully..."
        print("saved model 已经导出成功...")
if __name__ == "__main__":
    # Export target; the trailing numeric path component is presumably the
    # model version directory expected by TF Serving (confirm with deployment).
    savemodel_file_path = "./savemodel/yolov3/2"
    ckpt_file = "./checkpoint/yolov3_train_loss_199.6682.ckpt_10"

    # Graph node names of interest, kept for reference only (this script does
    # not use them; presumably meant for a frozen-graph export):
    #   input/input_data, pred_sbbox/concat_2, pred_mbbox/concat_2,
    #   pred_lbbox/concat_2, pred_multi_scale/concat
    model_transfer(savemodel_file_path, ckpt_file)