Skip to content

Commit

Permalink
Speed up network latency benchmark.
Browse files: browse the repository at this point in the history.
Also fix a simple data_format problem.
  • Loading branch information
mingxingtan committed Apr 24, 2020
1 parent 73bcc22 commit 006668f
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions efficientdet/model_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,14 @@ def __init__(self,
self.batch_size = batch_size
self.num_classes = num_classes
self.data_format = data_format
self.inputs_shape = [self.batch_size, image_size[0], image_size[1], 3]
self.labels_shape = [self.batch_size, self.num_classes]
self.image_size = image_size

if data_format == 'channels_first':
self.inputs_shape = [self.batch_size, 3, image_size[0], image_size[1]]
else:
self.inputs_shape = [self.batch_size, image_size[0], image_size[1], 3]

def build_model(self, inputs: tf.Tensor,
is_training: bool = False) -> List[tf.Tensor]:
"""Build model with inputs and labels and print out model stats."""
Expand Down Expand Up @@ -376,15 +380,23 @@ def benchmark_model(self, warmup_runs, bm_runs, num_threads,
# Don't use tf.group because XLA removes the whole graph for tf.group.
output = tf.group(*output)

output_name = [output.name]
input_name = inputs.name
graphdef = tf.graph_util.convert_variables_to_constants(
sess, sess.graph_def, output_name)

with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
tf.import_graph_def(graphdef, name='')

for i in range(warmup_runs):
start_time = time.time()
sess.run(output, feed_dict={inputs: img})
sess.run(output_name, feed_dict={input_name: img})
print('Warm up: {} {:.4f}s'.format(i, time.time() - start_time))

print('Start benchmark runs total={}'.format(bm_runs))
start = time.perf_counter()
for i in range(bm_runs):
sess.run(output, feed_dict={inputs: img})
sess.run(output_name, feed_dict={input_name: img})
end = time.perf_counter()
inference_time = (end - start) / 10
print('Per batch inference time: ', inference_time)
Expand All @@ -394,7 +406,7 @@ def benchmark_model(self, warmup_runs, bm_runs, num_threads,
run_options = tf.RunOptions()
run_options.trace_level = tf.RunOptions.FULL_TRACE
run_metadata = tf.RunMetadata()
sess.run(output, feed_dict={inputs: img},
sess.run(output_name, feed_dict={input_name: img},
options=run_options, run_metadata=run_metadata)
logging.info('Dumping trace to %s', trace_filename)
trace_dir = os.path.dirname(trace_filename)
Expand Down

0 comments on commit 006668f

Please sign in to comment.