From bd74acc10ec3982d8718810cd1fde807b3eb1c60 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 16 Apr 2021 14:03:27 +0200 Subject: [PATCH] ONNX Simplifier (#2815) * ONNX Simplifier Add ONNX Simplifier to ONNX export pipeline in export.py. Will auto-install onnx-simplifier if onnx is installed but onnx-simplifier is not. * Update general.py (cherry picked from commit 1f3e482bce89a348bcdace91dfc89c5e47862066) --- models/export.py | 45 ++++++++++++++++++++++++++++++--------------- utils/general.py | 2 +- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/models/export.py b/models/export.py index 0bb5398e4841..bec9194319c1 100644 --- a/models/export.py +++ b/models/export.py @@ -1,7 +1,7 @@ """Exports a YOLOv5 *.pt model to ONNX and TorchScript formats Usage: - $ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1 + $ export PYTHONPATH="$PWD" && python models/export.py --weights yolov5s.pt --img 640 --batch 1 """ import argparse @@ -16,7 +16,7 @@ import models from models.experimental import attempt_load from utils.activations import Hardswish, SiLU -from utils.general import set_logging, check_img_size +from utils.general import colorstr, check_img_size, check_requirements, set_logging from utils.torch_utils import select_device if __name__ == '__main__': @@ -59,20 +59,22 @@ y = model(img) # dry run # TorchScript export + prefix = colorstr('TorchScript:') try: - print('\nStarting TorchScript export with torch %s...' % torch.__version__) + print(f'\n{prefix} starting export with torch {torch.__version__}...') f = opt.weights.replace('.pt', '.torchscript.pt') # filename ts = torch.jit.trace(model, img, strict=False) ts.save(f) - print('TorchScript export success, saved as %s' % f) + print(f'{prefix} export success, saved as {f}') except Exception as e: - print('TorchScript export failure: %s' % e) + print(f'{prefix} export failure: {e}') # ONNX export + prefix = colorstr('ONNX:') try: import onnx - print('\nStarting ONNX export with onnx %s...' % onnx.__version__) + print(f'{prefix} starting export with onnx {onnx.__version__}...') f = opt.weights.replace('.pt', '.onnx') # filename torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], output_names=['classes', 'boxes'] if y is None else ['output'], @@ -80,25 +82,38 @@ 'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None) # Checks - onnx_model = onnx.load(f) # load onnx model - onnx.checker.check_model(onnx_model) # check onnx model - # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model - print('ONNX export success, saved as %s' % f) + model_onnx = onnx.load(f) # load onnx model + onnx.checker.check_model(model_onnx) # check onnx model + # print(onnx.helper.printable_graph(model_onnx.graph)) # print + + # Simplify + try: + check_requirements(['onnx-simplifier']) + import onnxsim + + print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, 'assert check failed' + onnx.save(model_onnx, f) + except Exception as e: + print(f'{prefix} simplifier failure: {e}') + print(f'{prefix} export success, saved as {f}') except Exception as e: - print('ONNX export failure: %s' % e) + print(f'{prefix} export failure: {e}') # CoreML export + prefix = colorstr('CoreML:') try: import coremltools as ct - print('\nStarting CoreML export with coremltools %s...' % ct.__version__) + print(f'{prefix} starting export with coremltools {onnx.__version__}...') # convert model from torchscript and apply pixel scaling as per detect.py model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])]) f = opt.weights.replace('.pt', '.mlmodel') # filename model.save(f) - print('CoreML export success, saved as %s' % f) + print(f'{prefix} export success, saved as {f}') except Exception as e: - print('CoreML export failure: %s' % e) + print(f'{prefix} export failure: {e}') # Finish - print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t)) + print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.') diff --git a/utils/general.py b/utils/general.py index 726c1bf9cac7..d8c8b504d311 100644 --- a/utils/general.py +++ b/utils/general.py @@ -114,7 +114,7 @@ def check_requirements(requirements='requirements.txt', exclude=()): except Exception as e: # DistributionNotFound or VersionConflict if requirements not met n += 1 print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...") - print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode()) + print(subprocess.check_output(f"pip install {e.req}", shell=True).decode()) if n: # if packages updated source = file.resolve() if 'file' in locals() else requirements