-
Notifications
You must be signed in to change notification settings - Fork 45
/
exporter.py
378 lines (319 loc) · 14.7 KB
/
exporter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:13:27 2018
@author: shirhe-lyh
"""
"""Functions to export inference graph.
Modified from: TensorFlow models/research/object_detection/export.py
"""
import logging
import os
import tempfile
import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib
slim = tf.contrib.slim
# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when
# newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        clear_devices,
        initializer_nodes,
        variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph_def: The GraphDef to freeze. When `clear_devices` is
            truthy, the `device` field of each of its nodes is cleared in
            place.
        input_saver_def: SaverDef used to build the saver that restores the
            checkpoint. When falsy, a saver is built instead from every
            checkpoint variable that has a matching '<name>:0' tensor in the
            imported graph.
        input_checkpoint: Checkpoint path (may be a prefix for Saver V2
            checkpoints).
        output_node_names: Comma-separated string of output node names.
        restore_op_name: Unused; accepted for signature compatibility.
        filename_tensor_name: Unused; accepted for signature compatibility.
        clear_devices: Whether to strip explicit device placements from the
            nodes of `input_graph_def`.
        initializer_nodes: If truthy, passed as-is to `sess.run` after the
            variables were restored without a SaverDef.
        variable_names_blacklist: Comma-separated string of variable names
            forwarded (as a list) to `convert_variables_to_constants`; an
            empty value is forwarded as None.

    Returns:
        The GraphDef returned by `graph_util.convert_variables_to_constants`.

    Raises:
        ValueError: If `input_checkpoint` does not exist or
            `output_node_names` is empty.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        # The checkpoint path is interpolated so the error names the file
        # that could not be found.
        raise ValueError(
            "Input checkpoint '{}' does not exist!".format(input_checkpoint))

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')
        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping
                        # element) so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)

            # An empty blacklist string is forwarded as None rather than [''].
            variable_names_blacklist = (
                variable_names_blacklist.split(',')
                if variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)
    return output_graph_def
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Writes a checkpoint whose variables hold their moving averages.

    The variables of `graph` are restored from `current_checkpoint_file`
    using the name mapping produced by
    `tf.train.ExponentialMovingAverage.variables_to_restore()`, and then all
    variables are saved to `new_checkpoint_file` with a default saver.

    Args:
        graph: A tf.Graph object.
        current_checkpoint_file: A checkpoint containing both the original
            variables and their moving averages.
        new_checkpoint_file: File path to write the new checkpoint to.
    """
    with graph.as_default():
        ema = tf.train.ExponentialMovingAverage(0.0)
        restore_mapping = ema.variables_to_restore()
        with tf.Session() as ema_session:
            # Load values through the moving-average name mapping ...
            tf.train.Saver(restore_mapping).restore(
                ema_session, current_checkpoint_file)
            # ... and write them back out with a default saver.
            tf.train.Saver().save(ema_session, new_checkpoint_file)
def _image_tensor_input_placeholder(input_shape=None):
    """Creates the uint8 'image_tensor' placeholder.

    Args:
        input_shape: Shape of the placeholder. Defaults to
            (None, None, None, 3) when not given.

    Returns:
        A tuple holding the same placeholder twice: once as the input
        placeholder and once as the image tensor fed to the model.
    """
    shape = (None, None, None, 3) if input_shape is None else input_shape
    image_placeholder = tf.placeholder(
        dtype=tf.uint8, shape=shape, name='image_tensor')
    return image_placeholder, image_placeholder
def _encoded_image_string_tensor_input_placeholder():
    """Creates a placeholder for a batch of encoded image strings.

    Returns:
        A tuple of the 1-D string placeholder and the uint8 tensor obtained
        by decoding every element with `tf.image.decode_image` (3 channels).
    """
    def _decode_single(encoded):
        # decode_image does not fix the static shape, so pin it here.
        decoded = tf.image.decode_image(encoded, channels=3)
        decoded.set_shape((None, None, 3))
        return decoded

    encoded_batch = tf.placeholder(
        dtype=tf.string,
        shape=[None],
        name='encoded_image_string_tensor')
    decoded_batch = tf.map_fn(
        _decode_single,
        elems=encoded_batch,
        dtype=tf.uint8,
        parallel_iterations=32,
        back_prop=False)
    return encoded_batch, decoded_batch
# Registry mapping each supported `input_type` name to the function that
# builds its input placeholder. Each function returns a tuple of
# (placeholder, image tensor).
input_placeholder_fn_map = {
    'image_tensor':
        _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
    # Not registered in this file:
    # 'tf_example': _tf_example_input_placeholder,
}
def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Wraps every postprocessed tensor in a named identity output node.

    Adjust according to specified implementations.

    Args:
        postprocessed_tensors: A dictionary containing the postprocessed
            results; each key becomes the name of its identity op.
        output_collection_name: Name of the collection the output tensors
            are added to.

    Returns:
        A dictionary with the same keys mapping to the added identity
        tensors.
    """
    named_outputs = {
        name: tf.identity(tensor, name=name)
        for name, tensor in postprocessed_tensors.items()
    }
    for identity_tensor in named_outputs.values():
        tf.add_to_collection(output_collection_name, identity_tensor)
    return named_outputs
def write_frozen_graph(frozen_graph_path, frozen_graph_def):
    """Serializes a frozen graph to disk and logs its node count.

    Args:
        frozen_graph_path: Path to write the inference graph to.
        frozen_graph_def: tf.GraphDef holding the frozen graph.
    """
    with gfile.GFile(frozen_graph_path, 'wb') as graph_file:
        graph_file.write(frozen_graph_def.SerializeToString())
        node_count = len(frozen_graph_def.node)
        logging.info('%d ops in the final graph.', node_count)
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
    """Writes a SavedModel to disk.

    The frozen graph is imported into a fresh graph and exported with a
    single predict signature registered under the default serving key and
    tagged for serving.

    Args:
        saved_model_path: Path to write the SavedModel to.
        frozen_graph_def: tf.GraphDef holding the frozen graph.
        inputs: The input image tensor; exposed in the signature under the
            key 'inputs'.
        outputs: A tensor dictionary containing the outputs of a slim model;
            each entry is exposed in the signature under its own key.
    """
    with tf.Graph().as_default(), session.Session() as sess:
        tf.import_graph_def(frozen_graph_def, name='')
        model_builder = tf.saved_model.builder.SavedModelBuilder(
            saved_model_path)

        build_info = tf.saved_model.utils.build_tensor_info
        input_infos = {'inputs': build_info(inputs)}
        output_infos = {
            name: build_info(tensor) for name, tensor in outputs.items()
        }

        predict_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs=input_infos,
                outputs=output_infos,
                method_name=signature_constants.PREDICT_METHOD_NAME))
        serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

        model_builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={serving_key: predict_signature},
        )
        model_builder.save()
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Writes the graph and the checkpoint to disk.

    Device placements are cleared from `inference_graph_def` in place, the
    graph is imported into a fresh graph, variables are restored from
    `trained_checkpoint_prefix` and then saved again under `model_path`.

    Args:
        inference_graph_def: GraphDef of the inference graph.
        model_path: Path (prefix) the new checkpoint is saved to.
        input_saver_def: SaverDef used to build the saver.
        trained_checkpoint_prefix: Checkpoint to restore the variables from.
    """
    for graph_node in inference_graph_def.node:
        graph_node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            checkpoint_saver = saver_lib.Saver(
                saver_def=input_saver_def, save_relative_paths=True)
            checkpoint_saver.restore(sess, trained_checkpoint_prefix)
            checkpoint_saver.save(sess, model_path)
def _get_outputs_from_inputs(input_tensors, model,
                             output_collection_name):
    """Runs the model pipeline on the inputs and adds the output nodes.

    The inputs are cast to float and passed through `model.preprocess`,
    `model.predict` and `model.postprocess`, in that order.

    Args:
        input_tensors: Image tensor(s) fed to the model.
        model: Object providing `preprocess`, `predict` and `postprocess`.
        output_collection_name: Name of collection to add output tensors to.

    Returns:
        The tensor dictionary returned by `_add_output_tensor_nodes`.
    """
    float_images = tf.to_float(input_tensors)
    predictions = model.predict(model.preprocess(float_images))
    return _add_output_tensor_nodes(
        model.postprocess(predictions), output_collection_name)
def _build_model_graph(input_type, model, input_shape,
                       output_collection_name, graph_hook_fn):
    """Builds the inference graph in the current default graph.

    Args:
        input_type: Key into `input_placeholder_fn_map` selecting the input
            placeholder to create.
        model: Model object handed to `_get_outputs_from_inputs`.
        input_shape: Optional fixed shape; only allowed for 'image_tensor'.
        output_collection_name: Name of collection to add output tensors to.
        graph_hook_fn: Optional callable invoked with no arguments after the
            graph (including the global step) has been built.

    Returns:
        A tuple of (output tensor dictionary, input placeholder tensor).

    Raises:
        ValueError: If `input_type` is unknown, or `input_shape` is given
            for an input type other than 'image_tensor'.
    """
    if input_type not in input_placeholder_fn_map:
        raise ValueError('Unknown input type: {}'.format(input_type))

    if input_shape is None:
        placeholder_kwargs = {}
    elif input_type == 'image_tensor':
        placeholder_kwargs = {'input_shape': input_shape}
    else:
        raise ValueError("Can only specify input shape for 'image_tensor' "
                         'inputs.')

    make_placeholder = input_placeholder_fn_map[input_type]
    placeholder_tensor, image_tensors = make_placeholder(**placeholder_kwargs)
    output_tensors = _get_outputs_from_inputs(
        input_tensors=image_tensors,
        model=model,
        output_collection_name=output_collection_name)

    # Add global step to the graph
    slim.get_or_create_global_step()

    if graph_hook_fn:
        graph_hook_fn()

    return output_tensors, placeholder_tensor
def export_inference_graph(input_type,
                           model,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           use_moving_averages=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           graph_hook_fn=None):
    """Exports inference graph for the desired graph.

    Builds the model graph in the current default graph and writes three
    artifacts into `output_directory`: a re-saved checkpoint ('model.ckpt'),
    a frozen graph ('frozen_inference_graph.pb') and a SavedModel
    ('saved_model').

    Args:
        input_type: Type of input for the graph. Must be a key of
            `input_placeholder_fn_map`, i.e. 'image_tensor' or
            'encoded_image_string_tensor' ('tf_example' is not registered
            in this file).
        model: A model defined by model.py.
        trained_checkpoint_prefix: Path to the trained checkpoint file.
        output_directory: Path to write outputs. Created if missing.
        input_shape: Sets a fixed shape for an 'image_tensor' input. If not
            specified, will default to [None, None, None, 3].
        use_moving_averages: A boolean indicating whether the
            tf.train.ExponentialMovingAverage should be used or not. When
            truthy, a temporary checkpoint holding the moving averages is
            written and used in place of the trained checkpoint.
        output_collection_name: Name of collection to add output tensors to.
        additional_output_tensor_names: List of additional output tensors to
            include in the frozen graph.
        graph_hook_fn: Optional callable invoked with no arguments once the
            model graph has been built.
    """
    tf.gfile.MakeDirs(output_directory)
    frozen_graph_path = os.path.join(output_directory,
                                     'frozen_inference_graph.pb')
    saved_model_path = os.path.join(output_directory, 'saved_model')
    model_path = os.path.join(output_directory, 'model.ckpt')

    outputs, placeholder_tensor = _build_model_graph(
        input_type=input_type,
        model=model,
        input_shape=input_shape,
        output_collection_name=output_collection_name,
        graph_hook_fn=graph_hook_fn)

    saver_kwargs = {}
    if use_moving_averages:
        # This check is to be compatible with both version of SaverDef.
        if os.path.isfile(trained_checkpoint_prefix):
            saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
            # Only the generated name is used as the checkpoint prefix; the
            # temporary file object itself is not kept.
            temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
        else:
            temp_checkpoint_prefix = tempfile.mkdtemp()
        replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            temp_checkpoint_prefix)
        checkpoint_to_use = temp_checkpoint_prefix
    else:
        checkpoint_to_use = trained_checkpoint_prefix

    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()

    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)

    # Build a real list first: on Python 3 `dict.keys()` is a view and
    # cannot be concatenated to a list with `+`.
    node_names = list(outputs.keys())
    if additional_output_tensor_names is not None:
        node_names.extend(additional_output_tensor_names)
    output_node_names = ','.join(node_names)

    frozen_graph_def = freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes='')
    write_frozen_graph(frozen_graph_path, frozen_graph_def)
    write_saved_model(saved_model_path, frozen_graph_def,
                      placeholder_tensor, outputs)