This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[1.x] Backporting TensorRT-Gluon Partition API (and TensorRT 7 support) #18916

Merged · 3 commits · Aug 19, 2020
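For context, the headline feature being backported is Gluon-level TensorRT partitioning. Below is a minimal usage sketch; the backend name 'TensorRT' and the HybridBlock.optimize_for signature (sample input plus static_alloc/static_shape) are assumed from the corresponding master-branch API, not taken from this diff:

# Sketch: partition a Gluon model with the TensorRT backend (assumed API).
import mxnet as mx
from mxnet.gluon.model_zoo import vision

ctx = mx.gpu(0)
net = vision.resnet18_v2(pretrained=True, ctx=ctx)

x = mx.nd.ones((1, 3, 224, 224), ctx=ctx)
# Replace supported subgraphs with TensorRT engines (runs one forward pass).
net.optimize_for(x, backend='TensorRT', static_alloc=True, static_shape=True)
out = net(x)  # subsequent calls execute through the TensorRT engines
print(out.shape)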
2 changes: 1 addition & 1 deletion 3rdparty/onnx-tensorrt
Submodule onnx-tensorrt updated 63 files
+1 −0 .gitmodules
+46 −177 CMakeLists.txt
+0 −91 Dockerfile
+0 −101 FancyActivation.cu
+0 −142 FancyActivation.hpp
+187 −70 ImporterContext.hpp
+0 −185 InstanceNormalization.cpp
+0 −133 InstanceNormalization.hpp
+607 −517 ModelImporter.cpp
+60 −46 ModelImporter.hpp
+8 −6 NvOnnxParser.cpp
+46 −21 NvOnnxParser.h
+0 −29 NvOnnxParserRuntime.cpp
+0 −85 NvOnnxParserRuntime.h
+0 −30 NvOnnxParserTypedefs.h
+306 −56 OnnxAttrs.cpp
+37 −21 OnnxAttrs.hpp
+0 −57 PluginFactory.cpp
+0 −59 PluginFactory.hpp
+56 −35 README.md
+0 −120 ResizeNearest.cu
+0 −108 ResizeNearest.hpp
+361 −0 ShapeTensor.cpp
+155 −0 ShapeTensor.hpp
+85 −99 ShapedWeights.cpp
+17 −19 ShapedWeights.hpp
+0 −133 Split.cu
+0 −112 Split.hpp
+175 −78 Status.hpp
+78 −40 TensorOrWeights.hpp
+3,646 −1,808 builtin_op_importers.cpp
+2 −1 builtin_op_importers.hpp
+0 −38 builtin_plugins.cpp
+0 −32 builtin_plugins.hpp
+25 −1 common.hpp
+1 −7 contributing.md
+70 −0 docker/onnx-tensorrt-deb.Dockerfile
+80 −0 docker/onnx-tensorrt-tar.Dockerfile
+6 −1 getSupportedAPITest.cpp
+0 −9 libnvonnxparser_runtime.version
+66 −3 main.cpp
+0 −60 nv_onnx_runtime_bindings.i
+32 −17 onnx2trt.hpp
+43 −30 onnx2trt_common.hpp
+3 −3 onnx2trt_runtime.hpp
+1,684 −54 onnx2trt_utils.cpp
+236 −375 onnx2trt_utils.hpp
+155 −150 onnx_backend_test.py
+27 −49 onnx_tensorrt/backend.py
+30 −0 onnx_tensorrt/config.py
+64 −78 onnx_tensorrt/tensorrt_engine.py
+53 −10 onnx_trt_backend.cpp
+130 −44 onnx_utils.hpp
+162 −138 operators.md
+0 −175 plugin.cpp
+0 −183 plugin.hpp
+0 −27 plugin_common.hpp
+0 −125 serialize.hpp
+30 −14 setup.py
+1 −1 third_party/onnx
+73 −56 toposort.hpp
+149 −198 trt_utils.hpp
+1 −1 utils.hpp
8 changes: 3 additions & 5 deletions CMakeLists.txt
@@ -239,6 +239,7 @@ if(USE_TENSORRT)
include_directories(3rdparty/onnx-tensorrt/third_party/onnx/)
add_definitions(-DMXNET_USE_TENSORRT=1)
add_definitions(-DONNX_NAMESPACE=onnx)
add_definitions(-DONNX_ML=1)

find_package(Protobuf REQUIRED)

Expand All @@ -248,14 +249,11 @@ if(USE_TENSORRT)
find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED
PATHS ${ONNX_PATH}
DOC "Path to onnx_proto library.")
find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED
PATHS ${ONNX_TRT_PATH}
DOC "Path to onnx_proto library.")
find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED
PATHS ${ONNX_TRT_PATH}
DOC "Path to onnx_proto library.")
DOC "Path to onnx_proto parser library.")

list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_TRT_RUNTIME_LIBRARY}
list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY}
${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY})
endif()

8 changes: 2 additions & 6 deletions ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -18,7 +18,7 @@
#
# Dockerfile to build and run MXNet on Ubuntu 18.04 for GPU with TensorRT

FROM nvidia/cuda:10.0-devel
FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04

WORKDIR /work/deps

@@ -36,12 +36,8 @@ ARG USER_ID=0
COPY install/ubuntu_adduser.sh /work/
RUN /work/ubuntu_adduser.sh

ENV CUDNN_VERSION=7.5.0.56
COPY install/ubuntu_cudnn.sh /work/
RUN /work/ubuntu_cudnn.sh

COPY runtime_functions.sh /work/

WORKDIR /work/mxnet
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
ENV CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:/usr/local/cuda-10.0/targets/x86_64-linux/include/
ENV CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:/usr/local/cuda-10.2/targets/x86_64-linux/include/
15 changes: 8 additions & 7 deletions ci/docker/install/tensorrt.sh
@@ -18,7 +18,7 @@
# under the License.

# Install gluoncv since we're testing Gluon models as well
pip3 install gluoncv==0.2.0
pip3 install gluoncv==0.4.0

# Install Protobuf
# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
@@ -40,10 +40,11 @@ popd

# Install TensorRT
echo "TensorRT build enabled. Installing TensorRT."
wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb
dpkg -i tensorrt.deb
apt-get update
apt-get install -y --allow-downgrades libnvinfer5=5.1.5-1+cuda10.0
apt-get install -y --allow-downgrades libnvinfer-dev=5.1.5-1+cuda10.0
apt-mark hold libnvinfer5 libnvinfer-dev
rm tensorrt.deb
TRT_VERSION="7.0.0-1+cuda10.2"
TRT_MAJOR_VERSION=7
apt-get install -y libnvinfer${TRT_MAJOR_VERSION}=${TRT_VERSION} \
libnvinfer-dev=${TRT_VERSION} \
libnvinfer-plugin${TRT_MAJOR_VERSION}=${TRT_VERSION} \
libnvinfer-plugin-dev=${TRT_VERSION}
apt-mark hold libnvinfer${TRT_MAJOR_VERSION} libnvinfer-dev
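After installing TensorRT 7 and rebuilding, a quick way to confirm the resulting binary was compiled with TensorRT support — a sketch; 'TENSORRT' is the runtime feature flag corresponding to the MXNET_USE_TENSORRT compile definition:

# Sanity check (sketch): was MXNet compiled with MXNET_USE_TENSORRT=1?
from mxnet.runtime import Features

features = Features()
assert features.is_enabled('TENSORRT'), 'MXNet was built without TensorRT support'
print('TensorRT enabled:', features.is_enabled('TENSORRT'))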
29 changes: 7 additions & 22 deletions ci/docker/runtime_functions.sh
@@ -708,37 +708,35 @@ build_ubuntu_gpu_tensorrt() {

build_ccache_wrappers

export ONNX_NAMESPACE=onnx

# Build ONNX
pushd .
echo "Installing ONNX."
cd 3rdparty/onnx-tensorrt/third_party/onnx
rm -rf build
mkdir -p build
cd build
cmake \
-DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\
-DBUILD_SHARED_LIBS=ON ..\
-G Ninja
ninja -j 1 -v onnx/onnx.proto
ninja -j 1 -v
cmake -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER} -DBUILD_SHARED_LIBS=ON ..
make -j$(nproc)
export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH
export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH
export CXXFLAGS=-I`pwd`
popd

# Build ONNX-TensorRT
pushd .
cd 3rdparty/onnx-tensorrt/
mkdir -p build
cd build
cmake ..
cmake -DONNX_NAMESPACE=$ONNX_NAMESPACE ..
make -j$(nproc)
export LIBRARY_PATH=`pwd`:$LIBRARY_PATH
popd

mkdir -p /work/mxnet/lib/
cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so* /work/mxnet/lib/

cd /work/build
cmake -DUSE_CUDA=1 \
@@ -1071,19 +1069,6 @@ unittest_ubuntu_python3_gpu_nocudnn() {
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
}

unittest_ubuntu_tensorrt_gpu() {
set -ex
export PYTHONPATH=./python/
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export MXNET_SUBGRAPH_VERBOSE=0
export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
export CUDNN_VERSION=${CUDNN_VERSION:-7.0.3}
export MXNET_ENABLE_CYTHON=0
export DMLC_LOG_STACK_TRACE_DEPTH=10
tests/python/tensorrt/lenet5_train.py
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/
}

# quantization gpu currently only runs on P3 instances
# need to separte it from unittest_ubuntu_python3_gpu()
unittest_ubuntu_python3_quantization_gpu() {
20 changes: 1 addition & 19 deletions ci/jenkins/Jenkins_steps.groovy
@@ -34,7 +34,7 @@ mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/l
mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser.so*, lib/libonnx_proto.so, lib/libonnx.so'
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so'
mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so, build/tests/cpp/mxnet_unit_tests'
mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so'
@@ -853,24 +853,6 @@ def test_unix_python3_mkldnn_nocudnn_gpu() {
}]
}

def test_unix_python3_tensorrt_gpu() {
[Review comments on this deleted function]

Contributor: @mseth10 I guess this is the cause of the CI job issue.

samskalicky (Contributor), Aug 18, 2020: @Kh4L you'd need to get rid of this too if you wanna remove this test:
https://github.com/apache/incubator-mxnet/blob/9981e847fff9270268385068c5b7d0c3929e46f9/ci/jenkins/Jenkinsfile_unix_gpu#L53
This is where it's being called from, but since you deleted the function it doesn't exist. Hence the error in the log:
https://jenkins.mxnet-ci.amazon-ml.com/job/mxnet-validation/job/unix-gpu/view/change-requests/job/PR-18916/8/consoleText
[2020-08-18T07:58:59.886Z] java.lang.NoSuchMethodError: No such DSL method 'test_unix_python3_tensorrt_gpu'

Contributor: That's right, if we remove this test we should also remove it from the Jenkins pipeline.

Contributor: In master, it's removed in this commit: c1098aa#diff-0c2b22569b87fa79153336ea3ac94f33

Author (Contributor): Thanks for investigating, I missed that file...
return ['Python3: TensorRT GPU': {
node(NODE_LINUX_GPU_P3) {
ws('workspace/build-tensorrt') {
timeout(time: max_time, unit: 'MINUTES') {
try {
utils.unpack_and_init('tensorrt', mx_tensorrt_lib)
utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true)
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml')
}
}
}
}
}]
}

def test_unix_python3_integration_gpu() {
return ['Python Integration GPU': {
node(NODE_LINUX_GPU_G4) {
1 change: 0 additions & 1 deletion ci/jenkins/Jenkinsfile_unix_gpu
@@ -50,7 +50,6 @@ core_logic: {
custom_steps.test_unix_python3_quantize_gpu(),
custom_steps.test_unix_python3_mkldnn_gpu(),
custom_steps.test_unix_python3_mkldnn_nocudnn_gpu(),
custom_steps.test_unix_python3_tensorrt_gpu(),
custom_steps.test_unix_perl_gpu(),
custom_steps.test_unix_r_gpu(),
custom_steps.test_unix_cpp_gpu(),
28 changes: 5 additions & 23 deletions example/extensions/lib_pass/test_pass.py
@@ -48,34 +48,16 @@
sym = mx.sym.log(d)

def test_model(pass_name):
args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
# execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')

exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
out = exe.forward()
inputs = [a,b]
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.initialize()
out = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
print(out)

# Symbol optimize_for
# with propagating shapes/types
print('-------------------------------')
print('Testing pass "%s" with shapes/types' % pass_name)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
aux = []
mysym2 = sym.optimize_for(pass_name,arg_array,aux)
print(mysym2.tojson())
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
out2 = exe2.forward()
print(out2)

# without propagating shapes/types
print('-------------------------------')
print('Testing pass "%s" without shapes/types' % pass_name)
mysym3 = sym.optimize_for(pass_name, myOpt='yello')
exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
out3 = exe3.forward()
print(out3)

# Gluon Hybridize
print('-------------------------------')
print('Testing pass "%s" Gluon Hybridize with shapes/types' % pass_name)
25 changes: 12 additions & 13 deletions example/extensions/lib_subgraph/test_subgraph.py
@@ -49,40 +49,39 @@
sym2 = mx.sym.log(d2)

def test(backend):
args = {'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
###############################################
# Test with subgraph not consuming params
###############################################
# execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe = sym.bind(ctx=mx.cpu(), args=args)
out = exe.forward()
print(out)

# with propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning with shapes/types' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
mysym2 = sym.optimize_for(backend,arg_array)
mysym2 = sym.optimize_for(backend,args)
print(mysym2.tojson())
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# with propagating shapes/types, rejecting subgraph
print('-------------------------------')
print('Testing %s partitioning with shapes/types - rejecting subgraph' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
mysym2 = sym.optimize_for(backend, arg_array, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
mysym2 = sym.optimize_for(backend, args, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# without propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning without shapes/types' % backend)
mysym3 = sym.optimize_for(backend, myOpt='yello')
exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
out3 = exe3.forward()
print(out3)

@@ -115,28 +114,28 @@ def test(backend):
###############################################
# Test with subgraph directly consuming params
###############################################
args = {'a':mx.nd.ones((3,2))}
# execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe5 = sym2.bind(ctx=mx.cpu(), args=args)
out5 = exe5.forward()
print(out5)

# with propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning with shapes/types' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32')]
mysym6 = sym2.optimize_for(backend, arg_array, reqArgs=True)
mysym6 = sym2.optimize_for(backend, args, reqArgs=True)
print(mysym6.tojson())
exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe6 = mysym6.bind(ctx=mx.cpu(), args=args)
out6 = exe6.forward()
print(out6)

# without propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning without shapes/types' % backend)
mysym7 = sym2.optimize_for(backend, reqArgs=True)
exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe7 = mysym7.bind(ctx=mx.cpu(), args=args)
out7 = exe7.forward()
print(out7)

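These test changes switch optimize_for from a positional list of NDArrays to a name-to-NDArray dict, matching the backported partition API. A self-contained sketch of the new calling convention — the toy graph mirrors the tests, and 'myBackend' is a placeholder rather than a backend from the example library:

# Sketch: the args-dict calling convention used by the updated tests.
import mxnet as mx

a = mx.sym.var('a')
b = mx.sym.var('b')
sym = mx.sym.log(mx.sym.exp(a + b))  # toy graph, same shapes as the examples

args = {'a': mx.nd.ones((3, 2)), 'b': mx.nd.ones((3, 2))}
# 'myBackend' is hypothetical; partitioning requires that name to be
# registered first, e.g. by a library loaded via mx.library.load(...).
partitioned = sym.optimize_for('myBackend', args)
print(partitioned.tojson())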
30 changes: 30 additions & 0 deletions include/mxnet/c_api.h
@@ -2166,6 +2166,25 @@ MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle
* \param num_options number of key value pairs
* \param keys keys for options
* \param vals values corresponding to keys
* \param num_input_shapes number of input shapes
* \param input_shape_names names of the input shapes
* \param input_shape_data pointer to the contiguous shape data
* \param input_shape_idx array of per-shape starting indices; the length of the i-th input shape
* is calculated as input_shape_idx[i+1] - input_shape_idx[i]
* \param num_input_dtypes number of input data types
* \param input_dtype_names array of names of the input data types
* \param input_dtypes array of values of the input data types
* \param num_input_stypes number of input storage types
* \param input_stype_names array of names of the input storage types
* \param input_stypes array of values of the input storage types
* \param skip_infer whether the optimization should skip attribute inference
* (use when the backend does not require shape inference)
* \param new_args_cnt pointer to a number to store the count of new args
* \param new_args_handle pointer to an array to store the new arg handles
* \param new_arg_names_handle pointer to an array to store the new arg names
* \param new_aux_cnt pointer to a number to store the count of new aux
* \param new_aux_handle pointer to an array to store the new aux handles
* \param new_aux_names_handle pointer to an array to store the new aux names
*/
MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
const char* backend_name,
@@ -2178,6 +2197,17 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
const mx_uint num_options,
const char** keys,
const char** vals,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,
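The new shape parameters use a flattened, CSR-style layout: every input shape is concatenated into input_shape_data, and input_shape_idx records where each shape starts, so the i-th shape spans input_shape_data[input_shape_idx[i]:input_shape_idx[i+1]]. A small Python illustration of that encoding (the tensor names are hypothetical):

# Illustration: flattened shape encoding consumed by MXOptimizeForBackend.
shapes = {'data': (1, 3, 224, 224), 'softmax_label': (1,)}  # hypothetical inputs

input_shape_names = list(shapes)                            # ['data', 'softmax_label']
input_shape_data = [d for s in shapes.values() for d in s]  # [1, 3, 224, 224, 1]

input_shape_idx = [0]                                       # per-shape start offsets
for s in shapes.values():
    input_shape_idx.append(input_shape_idx[-1] + len(s))    # [0, 4, 5]

# Recover the i-th shape from the flat arrays:
for i, name in enumerate(input_shape_names):
    start, end = input_shape_idx[i], input_shape_idx[i + 1]
    print(name, tuple(input_shape_data[start:end]))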
11 changes: 11 additions & 0 deletions perl-package/AI-MXNetCAPI/mxnet.i
@@ -1637,6 +1637,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
const mx_uint in,
const char** keys,
const char** vals,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,