This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-533] MXNet-ONNX export #11213

Merged
merged 83 commits into from
Jun 25, 2018
Merged

Conversation

Roshrini
Copy link
Member

@Roshrini Roshrini commented Jun 8, 2018

Description

This PR adds MXNet-to-ONNX exporter APIs for exporting trained MXNet models to ONNX protobuf, so that those models can be imported into other frameworks for inference.

Test framework:
Currently, we import ONNX models into MXNet, export them back to ONNX, and import them into MXNet again to verify that inference results match.

Working models:

  • Alexnet, Densenet, resnet50, squeezenet, vgg16, vgg19, inception_v1, inception_v2, Googlenet, caffenet, R-CNN

@spidydev @anirudhacharya @piiswrong @sandeep-krishnamurthy @nswamy @anirudh2290

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

@Roshrini Roshrini requested a review from szha as a code owner June 8, 2018 23:05
@szha szha requested review from piiswrong and zhreshold and removed request for szha June 11, 2018 18:11
op = str(node["op"])
if op not in MXNetGraph.registry_:
raise AttributeError("No conversion function registered for op type %s yet." % op)
convert_fun = MXNetGraph.registry_[op]
Member

change the name to convert_func


Member Author

fixed
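
The lookup in the snippet above relies on a decorator-based registry. A minimal, self-contained sketch of that pattern (the class body and example op name here are simplified illustrations, not the PR's exact code):

```python
class MXNetGraph:
    """Minimal sketch of the converter registry (simplified)."""
    registry_ = {}

    @staticmethod
    def register(op_name):
        """Decorator that maps an MXNet op name to its conversion function."""
        def wrapper(func):
            MXNetGraph.registry_[op_name] = func
            return func
        return wrapper

# hypothetical converter, for illustration only
@MXNetGraph.register("Activation")
def convert_activation(node, **kwargs):
    return "Activation converted"

# lookup mirrors the reviewed snippet
op = "Activation"
if op not in MXNetGraph.registry_:
    raise AttributeError("No conversion function registered for op type %s yet." % op)
convert_func = MXNetGraph.registry_[op]
```

Registering converters by name keeps the graph walker generic: adding a new operator means adding one decorated function, with no change to the traversal code.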

Contributor

@sandeep-krishnamurthy sandeep-krishnamurthy left a comment

Thanks for the awesome work @Roshrini @spidydev @anirudhacharya

Some comments below.

import mxnet as mx

def load_module(json_path, params_path, input_shape):
"""Loads the MXNet model file, retrieves symbol and parameters and returns.
Contributor

nit: and returns MXNet symbol and params (weights).

Member Author

Done

import logging
import mxnet as mx

def load_module(json_path, params_path, input_shape):
Contributor

nit: json_path is too generic a name for this function and will be hard to maintain later. Can we be more specific? sym_filepath, params_filepath, or something like that?

Member Author

Done

Model weights including both arg and aux params.
"""
if not (os.path.isfile(json_path) and os.path.isfile(params_path)):
raise ValueError("Provide valid path to the json and params file")
Contributor

nit: It is always useful to have specific Error/Warnings message on what is wrong and why.

Member Author

Done
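
A sketch of the more specific validation the reviewer asked for, raising a distinct error per missing file (the helper name `validate_paths` is hypothetical, not from the PR):

```python
import os

def validate_paths(sym_filepath, params_filepath):
    """Raise a specific error per missing file instead of one generic message."""
    if not os.path.isfile(sym_filepath):
        raise ValueError("Symbol file does not exist: %s" % sym_filepath)
    if not os.path.isfile(params_filepath):
        raise ValueError("Params file does not exist: %s" % params_filepath)
```

Naming the exact missing file in the message tells the user immediately which of the two paths to fix.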

raise ValueError("Provide valid path to the json and params file")
else:
try:
model_name = json_path.rsplit('.', 1)[0].rsplit('-', 1)[0]
Contributor

nit: I understand this logic reads symbol and epochs from sym.json file. But, please add code comment for this logic for future bug fixes.

Member Author

Done

model_name = json_path.rsplit('.', 1)[0].rsplit('-', 1)[0]
num_epochs = int(params_path.rsplit('.', 1)[0].rsplit('-', 1)[1])
except IndexError:
logging.info("Model and params name should be in format: "
Contributor

Is epoch necessary? Only for retraining the loaded model?
As a standard, a saved model need not have an epoch number; it is probably only necessary for checkpoint models, though MXNet mandates it today. But if we introduce a new API to save models without epochs attached, do we have any issue here?

Member Author

Keeping the epoch at 0 if it is not provided with the model name
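
The filename-parsing logic under discussion can be sketched as follows, with the epoch defaulting to 0 as described (the helper name `parse_checkpoint_paths` is hypothetical):

```python
def parse_checkpoint_paths(sym_filepath, params_filepath):
    """Derive the checkpoint prefix and epoch from MXNet file names.

    MXNet checkpoints follow "<prefix>-symbol.json" / "<prefix>-<epoch>.params".
    """
    # "resnet-symbol.json" -> "resnet-symbol" -> "resnet"
    model_name = sym_filepath.rsplit('.', 1)[0].rsplit('-', 1)[0]
    try:
        # "resnet-0007.params" -> "resnet-0007" -> 7
        num_epochs = int(params_filepath.rsplit('.', 1)[0].rsplit('-', 1)[1])
    except (IndexError, ValueError):
        num_epochs = 0  # default when no epoch suffix is present
    return model_name, num_epochs
```

With this fallback, a params file saved without an epoch suffix still loads instead of raising.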

name=name,
epsilon=eps,
momentum=momentum,
spatial=1
Contributor

always 1?

Contributor

MXNet doesn't have spatial BatchNorm, so this should actually be set to 0. While importing an ONNX model we ignore this attribute, but it might be an issue when exporting to caffe2 or other frameworks that support spatial BN. Thanks for pointing this out.

# Creating a dictionary here, but if this titlecase pattern
# mxnet_name.title()
act_types = {
"tanh": "Tanh",
Contributor

only tanh and relu supported?

Contributor

Done !!
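
The commit history notes that softsign, sigmoid, and softrelu support was added in response to this comment. A sketch of the resulting mapping; treat the exact table as an assumption rather than the PR's code (softrelu pairs with ONNX Softplus since MXNet's softrelu is log(1 + exp(x))):

```python
# Sketch of the MXNet -> ONNX activation name mapping after the fix.
act_types = {
    "tanh": "Tanh",
    "relu": "Relu",
    "sigmoid": "Sigmoid",
    "softrelu": "Softplus",
    "softsign": "Softsign",
}
```

An explicit dict is used rather than `mxnet_name.title()` because not every pair is a pure titlecase transform (e.g. softrelu).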

onnx_pad_width = [0]*num_pad_values

start_index = 0
end_index = int(num_pad_values/2)
Contributor

nit: floor?

Contributor

MXNet pad values in the pad op (https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.pad) are always a multiple of two. Will add a comment to clarify.
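
The conversion the snippet above sets up can be sketched in full: MXNet interleaves (begin, end) pad pairs per axis, while ONNX expects all begins followed by all ends (the helper name is hypothetical):

```python
def mxnet_pad_to_onnx_pads(mxnet_pad_width):
    """Reorder MXNet's interleaved (begin, end) pad pairs into ONNX layout.

    MXNet: (b0, e0, b1, e1, ...); ONNX: [b0, b1, ..., e0, e1, ...].
    """
    num_pad_values = len(mxnet_pad_width)
    onnx_pad_width = [0] * num_pad_values
    end_index = num_pad_values // 2  # pad values always come in pairs
    for i, val in enumerate(mxnet_pad_width):
        if i % 2 == 0:
            onnx_pad_width[i // 2] = val              # begin values
        else:
            onnx_pad_width[end_index + i // 2] = val  # end values
    return onnx_pad_width
```

Because the pad count is always even, the `// 2` split is exact and the reviewer's floor question is moot.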



@mx_op.register("slice_axis")
def convert_slice_axis(node, **kwargs):
Contributor

just slice operator?

Contributor

The Slice operator will be added later; it is not used by any of the models tested yet :)

@@ -114,7 +114,7 @@ def maximum(attrs, inputs, proto_obj):
for op_input in inputs[2:]:
mxnet_op = symbol.maximum(mxnet_op, op_input)
else:
mxnet_op = inputs[0]
mxnet_op = symbol.maximum(inputs[0], inputs[0])
Contributor

maximum of same element?

Member Author

Yes, ONNX has a case where, if there is only one input, it returns that input itself as the output. MXNet always needs two inputs.
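
The fix in the diff above can be sketched end to end; here Python's built-in `max` stands in for `mxnet.symbol.maximum`, so this is an illustration of the control flow rather than the PR's code:

```python
def convert_maximum(inputs, mx_maximum=max):
    """Fold ONNX's variadic Max into MXNet's two-input maximum.

    `mx_maximum` is a stand-in for mxnet.symbol.maximum in this sketch.
    """
    if len(inputs) > 1:
        mxnet_op = mx_maximum(inputs[0], inputs[1])
        for op_input in inputs[2:]:
            mxnet_op = mx_maximum(mxnet_op, op_input)
    else:
        # ONNX allows a single input (identity); MXNet always needs two
        # operands, so pass the same input twice.
        mxnet_op = mx_maximum(inputs[0], inputs[0])
    return mxnet_op
```

`maximum(x, x) == x`, so the single-input branch reproduces ONNX's identity behavior exactly.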

import logging
import mxnet as mx

def load_module(sym_filepath, params_filepath, input_shape):
Contributor

Question: what is the purpose of this function? Why couldn't it be replaced by a simple:

sym = mx.sym.load(sym_filepath)
params = mx.nd.load(params_filepath)
return sym, params

Member Author

sym.load and nd.load work for getting model and params objects from files, but if the model was trained using an old version of MXNet, they won't upgrade the model, so there is a compatibility issue.
For example, some models have "param" or "attr" instead of "attrs" in the json file.

Contributor

Nice corner case that is hard to think of 👍

model : str or symbol object
Path to the json file or Symbol object
weights : str or symbol object
Path to the params file or Params object. (Including both arg_params and aux_params)
Contributor

is it a dictionary of Parameters or something else?

Contributor

It can be both; changed the description to be more explicit.

from .export_helper import load_module


def export_model(model, weights, input_shape, input_type=np.float32,
Contributor

weights -> params, to be consistent with the rest of the file

Contributor

changed.

return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
for k, v in weights_dict.items()])

def create_onnx_graph_proto(self, sym, params, in_shape, in_type, log=False):
Member

verbose=False

Contributor

Done
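
The weights-conversion snippet quoted earlier strips the `arg:`/`aux:` key prefixes. A pure-Python sketch of that step; the real converter also calls `.asnumpy()` on each value, which is omitted here so the sketch has no MXNet dependency:

```python
def strip_param_prefixes(weights_dict):
    """Strip the "arg:"/"aux:" prefixes MXNet writes into saved .params keys."""
    return {k.replace("arg:", "").replace("aux:", ""): v
            for k, v in weights_dict.items()}
```

After this step the keys match the node names in the symbol graph, so parameters can be looked up directly during conversion.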

@@ -18,3 +18,4 @@

from ._import.import_model import import_model, get_model_metadata
Member

this does not make sense to me.
why do you want to put public functions in a private module folder (_import or _export) and include them later?

Member

import is a reserved keyword, so we can't have a folder called import. We can probably rename the two folders to onnx_import and onnx_export and make their member files private, except for the modules that we are exposing to the user.

Contributor

Folder names changed to be public: _import --> onnx2mx, _export --> mx2onnx. Also changed the files in the folders as per their usage.

from .export_helper import load_module


def export_model(model, weights, input_shape, input_type=np.float32,
Member

use verbose=False

Contributor

done

# create module, passing cpu context
ctx = context.cpu()
test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
Member

label_shapes may not always be None?

Contributor

True, but the purpose of this function is just to get the shape of the output after a forward pass.

self.output_tensors = []

@staticmethod
def register(op_name):
Member

there is no doc for inputs and outputs throughout the static methods in this class

Contributor

Usually, detailed info is only added for public APIs.

op = node["op"]
name = node["name"]
if log:
print("Converting idx: %d, op: %s, name: %s" % (idx, op, name))
Member

use logging.xx

Contributor

done
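
The switch from `print` to `logging` that this thread settles on can be sketched as follows (the logger name and helper are illustrative assumptions, not the PR's code):

```python
import logging

logger = logging.getLogger("mx2onnx")  # hypothetical logger name

def log_conversion(idx, op, name, verbose=False):
    """Report conversion progress via logging rather than print()."""
    msg = "Converting idx: %d, op: %s, name: %s" % (idx, op, name)
    if verbose:
        logger.info(msg)
    return msg
```

Routing through `logging` lets callers control verbosity and destination globally instead of sprinkling `if log: print(...)` checks.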

if log:
print("Converting idx: %d, op: %s, name: %s" % (idx, op, name))

if op == "null" and name not in params:
Member

Agreed, the logic is a bit confusing here; better to simplify it.



@classmethod
def prepare(cls, model, device='CPU', **kwargs):
Member

why not directly use mx.cpu()? Also, 'CPU' is in capital letters without careful handling.

Contributor

This method is declared by ONNX. Backends using the ONNX test framework derive from the "backend" class and implement these functions.

@@ -0,0 +1,98 @@
# Licensed to the Apache Software Foundation (ASF) under one
Member

It seems to have been added already, but why the name python-pytest?
There's already a folder tests/python/unittest

Contributor

MXNet unit tests use nosetests, but the ONNX backend test framework uses pytest, so to keep them separate we created another folder for the pytest tests.

Member

just remove the python-pytest folder and use onnx; the name is pretty confusing and meaningless for a folder containing only onnx.

Member

in the future, if there is another component built into MXNet that uses pytest instead of nosetests, then what will we do?

The point of naming it python-pytest is to separate these tests from the others, which use the nosetests framework.

And this naming was part of a previous PR #9963 and was suggested by @marcoabreu during the review process.

params.update(arg_params)
params.update(aux_params)

onnx_file = model_path.rsplit('/', 1)[0] + "/exported_"+model_name+".onnx"
Member

using + to concatenate paths is not portable

Contributor

fixed.
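
The portable fix can be sketched with `os.path.join`, which handles separators per platform (the helper name is hypothetical):

```python
import os

def exported_onnx_path(model_path, model_name):
    """Build the output path portably instead of concatenating with '+'."""
    return os.path.join(os.path.dirname(model_path),
                        "exported_" + model_name + ".onnx")
```

Unlike `model_path.rsplit('/', 1)[0] + "/..."`, this works with backslash separators on Windows as well.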

@rajanksin rajanksin force-pushed the onnx_export branch 7 times, most recently from f28076a to 9922a00 Compare June 14, 2018 15:24
Roshrini and others added 15 commits June 14, 2018 11:54
2. Refactored test framework to support ONNX backend tests.
2. Added Operator support:
   - Convolution2D
   - BatchNorm
   - Add
- Add, Sub, Mul, Div, Sum
- sigmoid, relu, pad( constant, edge, reflect), tanh
- enabled corresponding ONNX backend tests.
Added Operators :
Ceil, Floor
MaxPool, AvgPool, GlobalMaxPool, GlobalAvgPool, matmul
ArgMax, ArgMin, maximum, minimum
…dded only for these.

Fixed logic error with convert_string_to_list()
@rajanksin rajanksin force-pushed the onnx_export branch 2 times, most recently from a723644 to 43788cf Compare June 14, 2018 21:55
Changed underline files public or private as per usage

Resolved conflicts with the latest
Added some error checking
@Roshrini
Member Author

@aaronmarkham Can you review docs part of this PR?

@Roshrini
Member Author

@sandeep-krishnamurthy @zhreshold Thank you for reviewing the code. Addressed all the comments now.

Contributor

@sandeep-krishnamurthy sandeep-krishnamurthy left a comment

Thank you @Roshrini @spidydev. Great work! Will be very useful for users in combination with ONNX model zoo.

Will wait for other reviewers approval and doc approval.
@zhreshold @aaronmarkham @ThomasDelteil

@sandeep-krishnamurthy
Contributor

LGTM. Merging the changes.

@sandeep-krishnamurthy sandeep-krishnamurthy merged commit 7d91602 into apache:master Jun 25, 2018
@szha
Member

szha commented Jun 28, 2018

@sandeep-krishnamurthy There are vetos in effect from @zhreshold. Per agreement among committers you are not supposed to merge it.

@szha
Member

szha commented Jun 28, 2018

@zhreshold could you take a look at this change again and see if your concerns are sufficiently addressed?

@zhreshold
Member

Sorry about the late update; there's one minor issue that needs to be addressed.

@zhreshold
Member

Since the conversation is really long, I might have missed some updates, please ping me directly if I am not responsive. Thanks!

@anirudhacharya
Member

@zhreshold also please let me know how to ping you directly; do I do it on the Slack channel?

@szha
Member

szha commented Jun 28, 2018

@zhreshold if this is a small issue to address then let's request a patch from the author. Could you create an issue?

@marcoabreu
Contributor

Since there is so much conversation in this thread, could you please list the open issues?

@sandeep-krishnamurthy
Contributor

@szha - I explicitly pinged all the reviewers and waited for 6 days before merging the PR. I also tried, to the best of my ability, to confirm with contributors that the changes suggested by other reviewers were addressed before merging the PR.

@zhreshold
Member

I opened a new issue regarding my concerns in #11475

@szha
Member

szha commented Jun 28, 2018

@sandeep-krishnamurthy thanks for the efforts. Please respect "request changes" as vetos nonetheless and try and reach @zhreshold, especially given that you sit in the same office. Much appreciated.

XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* Resolve conflicts

* Export module Test Framework

* refactoring export to work with pretrained models

* comments added

* 1. Refactored export module.
2. Refactored test framework to support ONNX backend tests.
2. Added Operator support:
   - Convolution2D
   - BatchNorm
   - Add

* Added Arithmetic operators:
- Add, Sub, Mul, Div, Sum

* Added operator support:
- sigmoid, relu, pad( constant, edge, reflect), tanh
- enabled corresponding ONNX backend tests.

* Enabled ONNX tests: test_conv, test_basic_conv

Added Operators :
Ceil, Floor

* Added support for:
MaxPool, AvgPool, GlobalMaxPool, GlobalAvgPool, matmul

* adding more operators

* Added Operator support:
ArgMax, ArgMin, maximum, minimum

* Enabled more BASIC_MODEL tests

* Added power operator tests

* Added support for reshape. ONNX only supports 0, -1  special values. Added only for these.
Fixed logic error with convert_string_to_list()

* some tests enabled

* enabling squeezenet

* LRN Op support

* mul_scalar modified to take scalar input

* cleaning some code

* Resolving conlicts on rebase

* Resolving rebase conflicts

* id mapping updated for all operators

* save onnx models added, some code cleanup

* enabled more tests

* conv pad calc fixed

* reshape op fix

* Added support for elu, leakyRelu, prelu

* Cleanup
- Removed run_node, not needed anymore.
- Used correct get_metadata api

* valueinfoproto fix, googlenet test added

* Removed redundant code.
- run_node
- Using correct get_metadata_api

* dilation added

* Lint fixes

* lint fixes

* some fixes to make export work with onx1.2.1

* enabled more tests

* mxnet_export_test file added

* duplicate file deleted

* reduce ops added

* some small fixes

* some lint fixes

* Add tests for inception_v1 and inception_v2

* Add CI runs for export module

* docstring added

* lint fixes, pooling attr fix

* fix

* fix global_pool

* CI  run fix

* code cleanup

* lint fix

* some code cleanup

* pad in pooling added

* slicechannel notimplementederror raised

* Added required license comments

* Lint fixes

* lint fix

* lint fix

* lint fix

* lint fix

* Correct license statement

* Adding onnx a runtime dependency

* Fix import module error for string_types

* Making ONNX runtime dependency

* fixing some comments

* addressing some comments

* params rename

* lint fixes

* fixes

* spatial disabled, path fixed

* fixing some comments

* Added support for remaining act_type(softsign, sigmoid, softrelu) in Activation operator

* changing import

* adding some comments

* Add squeeze op

* Refactored logic to handle extra node(output label node) for saved mxnet model
Added comments

* minor fix for squeeze operator.
Also, added error handling

* identity operator added

* scalar ops added

* Renamed onnx support folders to mark it public folders
Changed underline files public or private as per usage

Resolved conflicts with the latest

* Added support L2Normalization op
Added some error checking

* added comments and warning

* added comments and warning

* doc API ref added
@ciyongch ciyongch mentioned this pull request Jun 4, 2020