Skip to content

Commit

Permalink
Fix ONNX dynamic axes export support with onnx simplifier, make onnx …
Browse files Browse the repository at this point in the history
…simplifier optional (ultralytics#2856)

* Ensure dynamic export works succesfully, onnx simplifier optional

* Update export.py

* add dashes

Co-authored-by: Tim <tim.stokman@hal24k.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
3 people authored Apr 20, 2021
1 parent 8324d73 commit 1a1bc8a
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions models/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path') # from yolov5/models/
parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)
Expand Down Expand Up @@ -58,7 +59,7 @@
model.model[-1].export = not opt.grid # set Detect() layer grid export
y = model(img) # dry run

# TorchScript export
# TorchScript export -----------------------------------------------------------------------------------------------
prefix = colorstr('TorchScript:')
try:
print(f'\n{prefix} starting export with torch {torch.__version__}...')
Expand All @@ -69,7 +70,7 @@
except Exception as e:
print(f'{prefix} export failure: {e}')

# ONNX export
# ONNX export ------------------------------------------------------------------------------------------------------
prefix = colorstr('ONNX:')
try:
import onnx
Expand All @@ -87,21 +88,24 @@
# print(onnx.helper.printable_graph(model_onnx.graph)) # print

# Simplify
try:
check_requirements(['onnx-simplifier'])
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
if opt.simplify:
try:
check_requirements(['onnx-simplifier'])
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx,
dynamic_input_shape=opt.dynamic,
input_shapes={'images': list(img.shape)} if opt.dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
print(f'{prefix} export success, saved as {f}')
except Exception as e:
print(f'{prefix} export failure: {e}')

# CoreML export
# CoreML export ----------------------------------------------------------------------------------------------------
prefix = colorstr('CoreML:')
try:
import coremltools as ct
Expand Down

0 comments on commit 1a1bc8a

Please sign in to comment.