Skip to content

Commit

Permalink
TensorFlow.js export enhancements (#4905)
Browse files Browse the repository at this point in the history
* Add arguments to TensorFlow NMS call

* Add regex substitution to reorder Identity_*

* Delete reorder in docstring

* Cleanup

* Cleanup2

* Removed `+ \` on string ends (not needed)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
zldrobit and glenn-jocher authored Sep 24, 2021
1 parent ce7fa81 commit 2c2ef25
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
29 changes: 27 additions & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
yolov5s.tflite
TensorFlow.js:
$ # Edit yolov5s_web_model/model.json to sort Identity* in ascending order
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
$ npm install
$ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
Expand Down Expand Up @@ -213,16 +212,32 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
# YOLOv5 TensorFlow.js export
try:
check_requirements(('tensorflowjs',))
import re
import tensorflowjs as tfjs

print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
f = str(file).replace('.pt', '_web_model') # js dir
f_pb = file.with_suffix('.pb') # *.pb path
f_json = f + '/model.json' # *.json path

cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
subprocess.run(cmd, shell=True)

json = open(f_json).read()
with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
subst = re.sub(
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}}}',
r'{"outputs": {"Identity": {"name": "Identity"}, '
r'"Identity_1": {"name": "Identity_1"}, '
r'"Identity_2": {"name": "Identity_2"}, '
r'"Identity_3": {"name": "Identity_3"}}}',
json)
j.write(subst)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
Expand All @@ -243,6 +258,10 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
dynamic=False, # ONNX/TF: dynamic axes
simplify=False, # ONNX: simplify model
opset=12, # ONNX: opset version
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
conf_thres=0.25 # TF.js NMS: confidence threshold
):
t = time.time()
include = [x.lower() for x in include]
Expand Down Expand Up @@ -290,7 +309,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
if any(tf_exports):
pb, tflite, tfjs = tf_exports[1:]
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs) # keras model
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs,
topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres,
iou_thres=iou_thres) # keras model
if pb or tfjs: # pb prerequisite to tfjs
export_pb(model, im, file)
if tflite:
Expand Down Expand Up @@ -319,6 +340,10 @@ def parse_opt():
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--include', nargs='+',
default=['torchscript', 'onnx'],
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
Expand Down
2 changes: 1 addition & 1 deletion models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class AgnosticNMS(keras.layers.Layer):
# TF Agnostic NMS
def call(self, input, topk_all, iou_thres, conf_thres):
# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
return tf.map_fn(self._nms, input,
return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input,
fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
name='agnostic_nms')

Expand Down

0 comments on commit 2c2ef25

Please sign in to comment.