Skip to content

Commit

Permalink
ONNX Simplifier (ultralytics#2815)
Browse files Browse the repository at this point in the history
* ONNX Simplifier

Add ONNX Simplifier to ONNX export pipeline in export.py. Will auto-install onnx-simplifier if onnx is installed but onnx-simplifier is not.

* Update general.py
  • Loading branch information
glenn-jocher authored Apr 16, 2021
1 parent e1e4245 commit b82cab5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
45 changes: 30 additions & 15 deletions models/export.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Exports a YOLOv5 *.pt model to ONNX and TorchScript formats
Usage:
$ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
$ export PYTHONPATH="$PWD" && python models/export.py --weights yolov5s.pt --img 640 --batch 1
"""

import argparse
Expand All @@ -16,7 +16,7 @@
import models
from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size
from utils.general import colorstr, check_img_size, check_requirements, set_logging
from utils.torch_utils import select_device

if __name__ == '__main__':
Expand Down Expand Up @@ -59,46 +59,61 @@
y = model(img) # dry run

# TorchScript export
prefix = colorstr('TorchScript:')
try:
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
print(f'\n{prefix} starting export with torch {torch.__version__}...')
f = opt.weights.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, img, strict=False)
ts.save(f)
print('TorchScript export success, saved as %s' % f)
print(f'{prefix} export success, saved as {f}')
except Exception as e:
print('TorchScript export failure: %s' % e)
print(f'{prefix} export failure: {e}')

# ONNX export
prefix = colorstr('ONNX:')
try:
import onnx

print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
print(f'{prefix} starting export with onnx {onnx.__version__}...')
f = opt.weights.replace('.pt', '.onnx') # filename
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=['classes', 'boxes'] if y is None else ['output'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)

# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
print('ONNX export success, saved as %s' % f)
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph)) # print

# Simplify
try:
check_requirements(['onnx-simplifier'])
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
print(f'{prefix} export success, saved as {f}')
except Exception as e:
print('ONNX export failure: %s' % e)
print(f'{prefix} export failure: {e}')

# CoreML export
prefix = colorstr('CoreML:')
try:
import coremltools as ct

print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
print(f'{prefix} starting export with coremltools {onnx.__version__}...')
# convert model from torchscript and apply pixel scaling as per detect.py
model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
f = opt.weights.replace('.pt', '.mlmodel') # filename
model.save(f)
print('CoreML export success, saved as %s' % f)
print(f'{prefix} export success, saved as {f}')
except Exception as e:
print('CoreML export failure: %s' % e)
print(f'{prefix} export failure: {e}')

# Finish
print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))
print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')
2 changes: 1 addition & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def check_requirements(requirements='requirements.txt', exclude=()):
except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
n += 1
print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())
print(subprocess.check_output(f"pip install {e.req}", shell=True).decode())

if n: # if packages updated
source = file.resolve() if 'file' in locals() else requirements
Expand Down

0 comments on commit b82cab5

Please sign in to comment.