forked from sirius-ai/MobileFaceNet_TF
freeze_graph.py · 100 lines (81 loc) · 3.98 KB
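# freeze_graph.py: load a trained MobileFaceNet checkpoint, rewrite the
# batch-norm assign ops into inference-safe equivalents, convert all variables
# to constants, and serialize the resulting GraphDef to a single .pb file.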
from nets.MobileFaceNet import inference
# from losses.face_losses import cos_loss
from verification import evaluate
from scipy.optimize import brentq
from utils.common import train
from scipy import interpolate
from datetime import datetime
from sklearn import metrics
import tensorflow as tf
import numpy as np
import argparse
import time
import os
from tensorflow.python.framework import graph_util
slim = tf.contrib.slim
def get_parser():
    parser = argparse.ArgumentParser(description='parameters to train net')
    parser.add_argument('--pretrained_model', type=str, default='', help='Load a pretrained model before training starts.')
    parser.add_argument('--output_file', type=str, help='Filename for the exported graphdef protobuf (.pb)')
    args = parser.parse_args()
    return args
def freeze_graph_def(sess, input_graph_def, output_node_names):
    # Fix batch norm nodes: replace ref-based ops with their stateless
    # equivalents so the graph can be frozen and run without variables
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Get the list of important nodes
    whitelist_names = []
    for node in input_graph_def.node:
        if node.name.startswith('MobileFaceNet') or node.name.startswith('embeddings'):
            whitelist_names.append(node.name)

    # Replace all the variables in the graph with constants of the same values
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
if __name__ == '__main__':
    with tf.Graph().as_default():
        args = get_parser()

        # define placeholders
        inputs = tf.placeholder(name='img_inputs', shape=[None, 112, 112, 3], dtype=tf.float32)
        labels = tf.placeholder(name='img_labels', shape=[None, ], dtype=tf.int64)
        phase_train_placeholder = tf.placeholder_with_default(tf.constant(False, dtype=tf.bool), shape=None, name='phase_train')

        # pretrained model path
        pretrained_model = None
        if args.pretrained_model:
            pretrained_model = os.path.expanduser(args.pretrained_model)
            print('Pre-trained model: %s' % pretrained_model)

        # wrap the input in an identity op so it can be fetched by name at inference time
        inputs = tf.identity(inputs, 'input')

        prelogits, net_points = inference(inputs, bottleneck_layer_size=128, phase_train=False, weight_decay=5e-5)
        embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')

        sess = tf.Session()

        # saver to load the pretrained model
        # MobileFaceNet_vars = [v for v in tf.trainable_variables() if v.name.startswith('MobileFaceNet')]
        saver = tf.train.Saver(tf.trainable_variables())

        # init all variables
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # load pretrained model
        if pretrained_model:
            print('Restoring pretrained model: %s' % pretrained_model)
            # ckpt = tf.train.get_checkpoint_state(pretrained_model)
            # print(ckpt)
            saver.restore(sess, pretrained_model)

        # Retrieve the protobuf graph definition and fix the batch norm nodes
        input_graph_def = sess.graph.as_graph_def()

        # Freeze the graph def
        output_graph_def = freeze_graph_def(sess, input_graph_def, 'embeddings')

        # Serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(args.output_file, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph: %s" % (len(output_graph_def.node), args.output_file))