Skip to content

Commit

Permalink
add trt models and converter
Browse files Browse the repository at this point in the history
  • Loading branch information
SippieCup committed Apr 20, 2020
1 parent 2caa569 commit bb9061e
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions models/trt/.gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.trt filter=lfs diff=lfs merge=lfs -text
1 change: 1 addition & 0 deletions models/trt/dmonitoring.metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_names": ["input_img"], "output_names": ["face_descs/BiasAdd", "face_prob/Sigmoid", "left_blink_prob/Sigmoid", "left_eye_descs/BiasAdd", "left_eye_prob/Sigmoid", "right_blink_prob/Sigmoid", "right_eye_descs/BiasAdd", "right_eye_prob/Sigmoid"]}
3 changes: 3 additions & 0 deletions models/trt/dmonitoring.trt
Git LFS file not shown
1 change: 1 addition & 0 deletions models/trt/supercombo.metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_names": ["desire", "input_imgs", "rnn_state", "traffic_convention"], "output_names": ["add_3/add", "desire_state/Softmax", "flatten/Reshape", "lead/BiasAdd", "left_lane/BiasAdd", "long_a/BiasAdd", "long_v/BiasAdd", "long_x/BiasAdd", "meta/Sigmoid", "path/BiasAdd", "pose/pose/Identity", "right_lane/BiasAdd"]}
3 changes: 3 additions & 0 deletions models/trt/supercombo.trt
Git LFS file not shown
43 changes: 43 additions & 0 deletions tools/keras/convert_trt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python3

import os

import json
from pathlib import Path
import tensorflow as tf
import sys

in_model = os.path.expanduser(sys.argv[1])
output = os.path.expanduser(sys.argv[2])
output_path = Path(output)
output_meta = Path('%s/%s.metadata' % (output_path.parent.as_posix(), output_path.stem))


# Reset session
tf.keras.backend.clear_session()
tf.keras.backend.set_learning_phase(0)

model = tf.keras.models.load_model(in_model, compile=False)
session = tf.keras.backend.get_session()

input_names = sorted([layer.op.name for layer in model.inputs])
output_names = sorted([layer.op.name for layer in model.outputs])

# Store additional information in metadata, useful when creating a TensorRT network
meta = {'input_names': input_names, 'output_names': output_names}

graph = session.graph

# Freeze Graph
with graph.as_default():
# Convert variables to constants
graph_frozen = tf.compat.v1.graph_util.convert_variables_to_constants(session, graph.as_graph_def(), output_names)
# Remove training nodes
graph_frozen = tf.compat.v1.graph_util.remove_training_nodes(graph_frozen)
with open(output, 'wb') as output_file, open(output_meta.as_posix(), 'w') as meta_file:
output_file.write(graph_frozen.SerializeToString())
meta_file.write(json.dumps(meta))

print ('Inputs = [%s], Outputs = [%s]' % (input_names, output_names))
print ('Writing metadata to %s' % output_meta.as_posix())
print ('To convert use: \n `convert-to-uff %s`' % (output))

0 comments on commit bb9061e

Please sign in to comment.