Skip to content

Commit

Permalink
Multithreaded Inference Support (apache#16654)
Browse files Browse the repository at this point in the history
* Add cached op threadsafe version with corresponding C APIs, CPP Package changes, CI changes and tests

* Fix download cmd in runtime_functions

* Add CI changes

* Add stage

Fix indentation

* Fix lint

* Change to DEFAULT for C API

* Fix mxnet_unit_tests path

* export correct LD_LIBRARY_PATH

* Add cpp include dirs

* Build test with USE_CPP_PACKAGE

* Add cached op threadsafe version with corresponding C APIs, CPP Package changes, CI changes and tests

* Fix download cmd in runtime_functions

* Merge

* change mkldnn lib name

* Add static_alloc, static_Shape support

* Address review comments

* Make GetCachedOpThreadSafeState similar to cached_op

* Address review comments: comments for locking strategy

* multithreaded inference tutorial

* [Estimator] handle composite metrics in estimator (apache#16676)

* handle composite metrics in estimator

* fix composite metric case in handlers

* remove unused import

* [Estimator] refactor estimator to allow overriding evaluate/fit of a batch (apache#16678)

* refactor estimator to allow overriding evaluate/fit of a batch

* add doc to explain call structure and how to override

* fix and doc

* Pointwise fusion for GPU (apache#15167)

* Beginning of RTC of pointwise ops

* Code generation from the given JSON

* add initial simple_partition_pass and use it for pointwise fusion

* fix the fusion, use a symbol.Copy() at the beginning of binding function, use the name of input nodes in the cuda code

* Fixes

* Adding support for attribute inference for backward nodes when fusing

* keep proper input ordering for fused Op

* instantiate the indexed_graph before starting the subgraph replacement, return a new graph to reset the indexed_graph

* Fuse backward

* fix ordering of subgraph node inputs using subgraph topological ordering instead of main graph topological ordering, add tvm.patch

* excluse forward node fusion during the fusion of the nodes in the backward graph

* Dealing with fused backward nodes inferattr

* use subgraph.indexed_graph() instead of main for _FusedOpHelper nodes node_id, invert control_deps loop to modify topology of subgraph before calling its indexed_graph(), check that all node of the first DFSVisit are actually in the subgraph

* Adding support for other reqs in codegen

* Fix

* Cleaning

* Change the TVM submodule

* More cleaning

* Making linter happy

* Do fusion only if default context is GPU

* Fixes for tests
Add powerscalar and rpowerscalar, fix return type of zero and one
Cleaning, fixing lint
Go back to proper TVM submodule

* Fix the TVM commit

* Fix lint

* Guard fusion with MXNET_USE_CUDA

* Fix

* Fix clang-tidy

* Add erf and erfinv backward

* Gluon support for fusion

* Cleaning

* Cleaning and allow shape/type change in FusedOp

* Fixing Gluon bugs

* Fixing after rebase

* Fixing race condition and guarding against races when using NVRTC

* Cleaning and renaming FusedOp to _FusedOp

* Going easy on Windows compiler

* Disable fusion on Windows for now

* Refactor InferAttr and InferShapeAttr

* Added slice and half2 support to FusedOp

* Fix lint errors

* Added multiple types support for vector loading/storing

* add slice fusion when it's at the beginning of subgraphs

* Removed constant ndim assumption in fused op

* Fix memory alignment issue in slice for FusedOp

* Fixes

* Fix lint errors

* Do not include cuda_fp16.h

* Refactor fused op op lists

* Make linter happy

* Changes from review

* Fixes after rebase

* Expand FusedOp support for slice

* Fix for fp16 _zeros and _ones

* Fix

* Moving aux functions to unnamed namespace and detail namespace -> fusion
namespace

* Disabling fusion if it alters topological order of inputs

* Print code only when env variable is set

* Fix

* Fix lint and 2 tests that specify the same names for multiple inputs

* Fixes from review and disabling fusion of slice with non-default step

* Add amp_cast to fusion, fixes

* Add amp_multicast and its backward to the list of support ops

* Apply wording suggestions from code review

Co-Authored-By: Aaron Markham <markhama@amazon.com>

* Apply wording suggestions from code review

Co-Authored-By: Aaron Markham <markhama@amazon.com>

* Make clearer comment

* Adding punctuation and capitalization to \brief descriptions

* Fix

* Fix

* Add backward_cast to fusion

* Adding unittests for fusion. Fix for erfinv_grad

* Adding slice ops and add_n to tests

* Fixes from review

* Setting inplace option

* Fix lint

* Storing double in half

* Retrigger CI

* Slight relaxing of the relative tolerance in the test

* Move the env variable check to the end

* Fix a race condition between InferShape and scheduled Forward

* Fix flakey test_fusion test involving fp32 erfinv op.

* Fix from review

* Added broadcast_like and slice_like to fused op

* Minor fix and cleanup

* Added negative axis support in slice_axis, temporarily disabled fusion of slice_like and broadcast_like

* Added axes support to slice_like

* Added axis support to broadcast_like

* Add fast_load_slice function to fused op code

* Added runtime switch for choosing fast and slow slice kernel

* Fix lint and warning

* Going easy on Windows compiler (again)

* Fix slice_like

* Debug broadcast_like fusion

* Fix lint

* Fix lint

* Trigger CI

* Get rid of the initializer list

* Fix backward calls with different gradient type

* avoid cycle when adding node specific for inputs of subgraph for pointwise fusion

* Fix lint

* Add namespace to the fusion implementations

* Set launch bounds on the fused kernel

* Fix NumPy tests

* Test showcasing an issue fixed in PR apache#16553

* Cast scalarts to FP32 and perform (a*1.0/b) instead of (a/b)

Fix lint errors

Fix lint

* Fix a bug in cycle detection for inputs only op in pointwise fusion

* Add comments to simple_partition_pass.h file

* fix install dir (apache#16690)

* [numpy] add numpy operator : append (apache#16564)

* add operator : append ; fix op concatenate when axis = None

* pylint disable

remove mistake

disable pylint

* Initializer.__eq__ (apache#16680)

* fix binary dependencies in CD and nightly (apache#16693)

* [MKL-DNN] Add mxnet mkldnn cmake tutorial (apache#16688)

* add mxnet mkldnn cmake instruction

* imporve doc

* OMP->OpenMP

* Revert "[MKLDNN]Fix reorder2default (apache#16602)" (apache#16697)

This reverts commit dd4eaf5.

* [Estimator] refactor estimator and clarify docs (apache#16694)

* refactor estimator and clarify docs

* fix info message and test

* clean up after releasing logging handler

* Eliminate common expressions (apache#15657)

* Eliminate common expressions from a graph

* Guarding against optimizing out stateful ops and ops that require
resource

* Fix lint

* Added THasDeterministicOutput to multiple ops

* DDebug eliminate common expr

* Added test

* Expose get_optimized_symbol

* Fix

* Fix 2

* Add doc to the Python call

* Add env var MXNET_ELIMINATE_COMMON_EXPR, default true

* Add comments, improve readability of eliminate_common_expr_pass.cc

* Expand testing

* Lower priority of THasDeterministicOutput attr for equal Node test

* Change mx.gpu() to mx.cpu() in tests

* Skip CSE test on Windows (as env variable setting during test does not work there)

* Add missing import sys

* Add missing import logging

* Backport of apache#16711, apache#16737, apache#16408 to 1.6 branch (apache#16763)

* support mixed-precision true_divide (apache#16711)

* [MKLDNN] use dim_t instead of int in slice/transpose operators (apache#16737)

* use dim_t instead of int

* fix same issue in pooling

* rebase code

* trigger CI

* Add MXNet Ops for fast multihead attention (apache#16408)

* add MXNet Ops for fast multihead attention

* add cutlass as 3rdparty dependency

* add cutlass to compilation flags

* remove all cutlass stuff

* add better error message and description and remove cutlass from compilation flags

* change credit for the approach since the code have changed

* fix typos

* correct another typo

* Add all the cuda/cublas helper functions

* remove tests using kAddTo

* only use cublasStridedBatchedGemm if CUDA >= 9.1

* add equivalent mxnet code in description of mha ops

* remove a wrong copy-paste

* add _contrib for namespace and add GPU only on description

* add warning in bwd_ignore_zero_init description, also test with fp32

* add error return if bwd_ignore_zero_init is used without MXNET_EXEC_ENABLE_ADDTO

* remove std::move for clang

* remove bwd_ignore_zero_init flag

* remove bwd_ignore_zero_init in test_operator_gpu.py

* fix typo

* fix another typo

* Removed unrelated test

* Add example and documentation for multi threaded inference

* Add LICENSE

* Add get_model.py

* Add license for README

* Refactor cached op and cached op threadsafe

* Add limitation

* Add tests for naive engine

* Add latest test changes

* Thread Safety tests in NaiveEngine mode

* Thread Safety tests update

* Update thread safety tests, add unsupported use cases

* Changes to doc and refactor

* Fix todo owner, indentation and mx_float->float

* Refactor cached op code, remove num_threads arg from example

* Fix lint

* Fix warning

* Add back cython, required for unix-gpu build

* Fix for windows

* Add bulking support for thread safe cached op version

* Add support for subgraph testing

* import mxnet before calling get_backend_symbol

* Fix symbol json name

* Refactor DynamicForward

* Add comments

* Add DMLC_ATTRIBUTE_UNUSED

* Fix use_naive_run issue

* Fix lint

* Revert unittest_cpp to old test since it doesnt test thread safety

* Fix doc

Co-authored-by: Sheng Zha <szha@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: Tao Lv <tao.a.lv@intel.com>
Co-authored-by: JiangZhaoh <54654391+JiangZhaoh@users.noreply.github.com>
Co-authored-by: Leonard Lausen <leonard@lausen.nl>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Zhennan Qin <zhennan.qin@intel.com>
  • Loading branch information
8 people committed Feb 1, 2020
1 parent a726c40 commit b1e4911
Show file tree
Hide file tree
Showing 26 changed files with 2,361 additions and 330 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ if(USE_MKLDNN)
set(INSTALL_MKLDNN ON)
endif()

if(USE_CPP_PACKAGE)
add_definitions(-DMXNET_USE_CPP_PACKAGE=1)
endif()

# Allow Cuda compiles outside of src tree to find things in 'src' and 'include'
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
Expand Down Expand Up @@ -853,7 +857,6 @@ if(MSVC AND USE_MXNET_LIB_NAMING)
set_target_properties(mxnet PROPERTIES OUTPUT_NAME "libmxnet")
endif()

add_subdirectory(tests)

include(GNUInstallDirs)
install(TARGETS ${MXNET_INSTALL_TARGETS}
Expand Down Expand Up @@ -915,6 +918,7 @@ endif()
if(BUILD_CPP_EXAMPLES)
add_subdirectory(example/image-classification/predict-cpp)
endif()
add_subdirectory(tests)

# ---[ Linter target
if(MSVC)
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ $(BIN) :
# CPP Package
ifeq ($(USE_CPP_PACKAGE), 1)
include cpp-package/cpp-package.mk
CFLAGS += -DMXNET_USE_CPP_PACKAGE=1
endif

include mkldnn.mk
Expand Down
38 changes: 38 additions & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,27 @@ build_ubuntu_gpu_cuda101_cudnn7() {
CUDA_ARCH="$CI_CUDA_COMPUTE_CAPABILITIES" \
USE_SIGNAL_HANDLER=1 \
-j$(nproc)
make cython PYTHON=python2
make cython PYTHON=python3
}

build_ubuntu_gpu_cuda101_cudnn7_mkldnn_cpp_test() {
set -ex
build_ccache_wrappers
make \
DEV=1 \
USE_BLAS=openblas \
USE_MKLDNN=1 \
USE_CUDA=1 \
USE_CUDA_PATH=/usr/local/cuda \
USE_CUDNN=1 \
USE_TVM_OP=0 \
USE_CPP_PACKAGE=1 \
USE_DIST_KVSTORE=1 \
CUDA_ARCH="$CI_CUDA_COMPUTE_CAPABILITIES" \
USE_SIGNAL_HANDLER=1 \
-j$(nproc)
make test USE_CPP_PACKAGE=1 -j$(nproc)
make cython PYTHON=python2
make cython PYTHON=python3
}
Expand Down Expand Up @@ -1323,6 +1343,24 @@ integrationtest_ubuntu_gpu_cpp_package() {
cpp-package/tests/ci_test.sh
}

integrationtest_ubuntu_gpu_capi_cpp_package() {
set -ex
export PYTHONPATH=./python/
export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
python3 -c "import mxnet as mx; mx.test_utils.download_model(\"imagenet1k-resnet-18\"); mx.test_utils.download_model(\"imagenet1k-resnet-152\"); mx.test_utils.download_model(\"imagenet1k-resnet-50\");"
# Load symbol, convert symbol to leverage fusion with subgraphs, save the model
python3 -c "import mxnet as mx; x = mx.sym.load(\"imagenet1k-resnet-152-symbol.json\"); x.get_backend_symbol(\"MKLDNN\"); x.save(\"imagenet1k-resnet-152-subgraph-symbol.json\");"
# Copy params file with a different name, used in subgraph symbol testing
cp imagenet1k-resnet-152-0000.params imagenet1k-resnet-152-subgraph-0000.params
build/tests/cpp/mxnet_unit_tests --gtest_filter="ThreadSafety.*"
build/tests/cpp/mxnet_unit_tests --gtest_filter="ThreadSafety.*" --thread-safety-with-cpu
# Also run thread safety tests in NaiveEngine mode
export MXNET_ENGINE_TYPE=NaiveEngine
build/tests/cpp/mxnet_unit_tests --gtest_filter="ThreadSafety.*"
build/tests/cpp/mxnet_unit_tests --gtest_filter="ThreadSafety.*" --thread-safety-with-cpu
unset MXNET_ENGINE_TYPE
}

integrationtest_ubuntu_cpu_dist_kvstore() {
set -ex
pushd .
Expand Down
29 changes: 29 additions & 0 deletions ci/jenkins/Jenkins_steps.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/l
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so, build/tests/cpp/mxnet_unit_tests'
mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*'

Expand Down Expand Up @@ -261,6 +262,20 @@ def compile_unix_full_gpu() {
}]
}

def compile_unix_full_gpu_mkldnn_cpp_test() {
return ['GPU: CUDA10.1+cuDNN7+MKLDNN+CPPTEST': {
node(NODE_LINUX_CPU) {
ws('workspace/build-gpu-mkldnn-cpp') {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
utils.docker_run('ubuntu_build_cuda', 'build_ubuntu_gpu_cuda101_cudnn7_mkldnn_cpp_test', false)
utils.pack_lib('gpu_mkldnn_cpp_test', mx_lib_cpp_capi)
}
}
}
}]
}

def compile_unix_full_gpu_no_tvm_op() {
return ['GPU: CUDA10.1+cuDNN7 TVM_OP OFF': {
node(NODE_LINUX_CPU) {
Expand Down Expand Up @@ -1010,6 +1025,20 @@ def test_unix_cpp_package_gpu() {
}]
}

def test_unix_capi_cpp_package() {
return ['capi-cpp-package GPU': {
node(NODE_LINUX_GPU) {
ws('workspace/it-capi-cpp-package') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('gpu_mkldnn_cpp_test', mx_lib_cpp_capi)
utils.docker_run('ubuntu_gpu_cu101', 'integrationtest_ubuntu_gpu_capi_cpp_package', true)
utils.publish_test_coverage()
}
}
}
}]
}

def test_unix_scala_cpu() {
return ['Scala: CPU': {
node(NODE_LINUX_CPU) {
Expand Down
2 changes: 2 additions & 0 deletions ci/jenkins/Jenkinsfile_unix_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ core_logic: {
custom_steps.compile_unix_int64_gpu(),
custom_steps.compile_unix_full_gpu_no_tvm_op(),
custom_steps.compile_unix_cmake_gpu_no_tvm_op(),
custom_steps.compile_unix_full_gpu_mkldnn_cpp_test()
])

utils.parallel_stage('Tests', [
Expand All @@ -64,6 +65,7 @@ core_logic: {
custom_steps.test_unix_distributed_kvstore_gpu(),
custom_steps.test_static_python_gpu(),
custom_steps.test_unix_python3_gpu_no_tvm_op(),
custom_steps.test_unix_capi_cpp_package(),

// Disabled due to: https://github.com/apache/incubator-mxnet/issues/11407
//custom_steps.test_unix_caffe_gpu()
Expand Down
2 changes: 1 addition & 1 deletion cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ inline NDArray::NDArray(const mx_float *data, const Shape &shape,
CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(),
context.GetDeviceId(), false, &handle),
0);
MXNDArraySyncCopyFromCPU(handle, data, shape.Size());
CHECK_EQ(MXNDArraySyncCopyFromCPU(handle, data, shape.Size()), 0);
blob_ptr_ = std::make_shared<NDBlob>(handle);
}
inline NDArray::NDArray(const std::vector<mx_float> &data, const Shape &shape,
Expand Down
2 changes: 2 additions & 0 deletions cpp-package/include/mxnet-cpp/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ class Symbol {
*unnamed (empty string).
*/
std::vector<std::string> ListArguments() const;
/*! \return lists all argument names and aux states of the symbol */
std::vector<std::string> ListInputs() const;
/*! \return get the descriptions of outputs for this symbol */
std::vector<std::string> ListOutputs() const;
/*! \return get the descriptions of auxiliary data for this symbol */
Expand Down
12 changes: 12 additions & 0 deletions cpp-package/include/mxnet-cpp/symbol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ inline std::vector<std::string> Symbol::ListArguments() const {
}
return ret;
}

inline std::vector<std::string> Symbol::ListInputs() const {
std::vector<std::string> ret;
mx_uint size;
const char **sarr;
NNSymbolListInputNames(GetHandle(), 0, &size, &sarr);
for (mx_uint i = 0; i < size; ++i) {
ret.push_back(std::string(sarr[i]));
}
return ret;
}

inline std::vector<std::string> Symbol::ListOutputs() const {
std::vector<std::string> ret;
mx_uint size;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
---
layout: page_api
title: Multi Threaded Inference
action: Get Started
action_url: /get_started
permalink: /api/cpp/docs/tutorials/multi_threaded_inference
is_tutorial: true
tag: cpp
---
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# Multi Threaded Inference API

A long standing request from MXNet users has been to invoke parallel inference on a model from multiple threads while sharing the parameters.
With this use case in mind, the threadsafe version of CachedOp was added to provide a way for customers to do multi-threaded inference for MXNet users.
This doc attempts to do the following:
1. Discuss the current state of thread safety in MXNet
2. Explain how one can use C API and thread safe version of cached op, along with CPP package to achieve iultithreaded inference. This will be useful for end users as well as frontend developers of different language bindings
3. Discuss the limitations of the above approach
4. Future Work

## Current state of Thread Safety in MXNet

Examining the current state of thread safety in MXNet we can arrive to the following conclusion:

1. MXNet Dependency Engine is thread safe (except for WaitToRead invoked inside a spawned thread. Please see Limitations section)
2. Graph Executor which is Module/Symbolic/C Predict API backend is not thread safe
3. Cached Op (Gluon Backend) is not thread safe

The CachedOpThreadSafe and corresponding C APIs were added to address point 3 above and provide a way
for MXNet users to do multi-threaded inference.

```
/*!
* \brief create cached operator, allows to choose thread_safe version
* of cachedop
*/
MXNET_DLL int MXCreateCachedOpEX(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
CachedOpHandle *out,
bool thread_safe DEFAULT(false));
```

## Multithreaded inference in MXNet with C API and CPP Package

### Prerequisites
To complete this tutorial you need to:
- Learn the basics about [MXNet C++ API](/api/cpp)
- Build MXNet from source with make/cmake
- Build the multi-threaded inference example

### Setup the MXNet C++ API
To use the C++ API in MXNet, you need to build MXNet from source with C++ package. Please follow the [built from source guide](/get_started/ubuntu_setup.html), and [C++ Package documentation](/api/cpp)
The summary of those two documents is that you need to build MXNet from source with `USE_CPP_PACKAGE` flag set to 1. For example: `make -j USE_CPP_PACKAGE=1 USE_CUDA=1 USE_CUDNN=1`.
This example requires a build with CUDA and CUDNN.

### Build the example
If you have built mxnet from source with make, then do the following:

```bash
$ cd example/multi_threaded_inference
$ make
```

If you have built mxnet from source with cmake, please uncomment the specific lines for cmake build or set the following environment variables: `MKLDNN_BUILD_DIR (default is $(MXNET_ROOT)/3rdparty/mkldnn/build)`, `MKLDNN_INCLUDE_DIR (default is $(MXNET_ROOT)/3rdparty/mkldnn/include)`, `MXNET_LIB_DIR (default is $(MXNET_ROOT)/lib)`.

### Download the model and run multi threaded inference example
To download a model use the `get_model.py` script. This downloads a model to run inference.

```python
python3 get_model.py --model <model_name>
```
e.g.
```python
python3 get_model.py --model imagenet1k-inception-bn
```
Only the supported models with `get_model.py` work with multi threaded inference.

To run the multi threaded inference example:

First export `LD_LIBRARY_PATH`:

```bash
$ export LD_LIBRARY_PATH=<MXNET_LIB_DIR>:$LD_LIBRARY_PATH
```

```bash
$ ./multi_threaded_inference [model_name] [is_gpu] [file_names]
```
e.g.

```bash
./multi_threaded_inference imagenet1k-inception-bn 2 1 grace_hopper.jpg dog.jpg
```

The above script spawns 2 threads, shares the same cachedop and params among two threads, and runs inference on GPU. It returns the inference results in the order in which files are provided.

NOTE: This example is to demonstrate the multi-threaded-inference with cached op. The inference results work well only with specific models (e.g. imagenet1k-inception-bn). The results may not necessarily be very accurate because of different preprocessing step required etc.

### Code walkthrough multi-threaded inference with CachedOp

The multi threaded inference example (`multi_threaded_inference.cc`) involves the following steps:

1. Parse arguments and load input image into ndarray
2. Prepare input data and load parameters, copying data to a specific context
3. Preparing arguments to pass to the CachedOp and calling C API to **create cached op**
4. Prepare lambda function which will run in spawned threads. Call C API to **invoke cached op** within the lambda function.
5. Spawn multiple threads and wait for all threads to complete.
6. Post process data to obtain inference results and cleanup.

### Step 1: Parse arguments and load input image into ndarray

[https://github.com/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L299-L341](multi_threaded_inference.cc#L299-L341)

The above code parses arguments, loads the image file into a ndarray with a specific shape. There are a few things that are set by default and not configurable. For example, `static_alloc` and `static_shape` are by default set to true.


### Step 2: Prepare input data and load parameters, copying data to a specific context

[https://github.com/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L147-L205](multi_threaded_inference.cc#L147-L205)

The above code loads params and copies input data and params to specific context.

### Step 3: Preparing arguments to pass to the CachedOp and calling C API to create cached op

[https://github.com/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L207-L233](multi_threaded_inference.cc#L207-233)

The above code prepares `flag_key_cstrs` and `flag_val_cstrs` to be passed the Cached op.
The C API call is made with `MXCreateCachedOpEX`. This will lead to creation of thread safe cached
op since the `thread_safe` (which is the last parameter to `MXCreateCachedOpEX`) is set to
true. When this is set to false, it will invoke CachedOp instead of CachedOpThreadSafe.


### Step 4: Prepare lambda function which will run in spawned threads

[https://github.com/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L248-L262](multi_threaded_inference.cc#L248-262)

The above creates the lambda function taking the thread number as the argument.
If `random_sleep` is set it will sleep for a random number (secs) generated between 0 to 5 seconds.
Following this, it invokes `MXInvokeCachedOpEx`(from the hdl it determines whether to invoke cached op threadsafe version or not).
When this is set to false, it will invoke CachedOp instead of CachedOpThreadSafe.

### Step 5: Spawn multiple threads and wait for all threads to complete

[https://github.com/anirudh2290/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L264-L276](multi_threaded_inference.cc#L264-L276)

Spawns multiple threads, joins and waits to wait for all ops to complete.
The other alternative is to wait in the thread on the output ndarray and remove the WaitAll after join.

### Step 6: Post process data to obtain inference results and cleanup

[https://github.com/apache/incubator-/mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L286-L293](multi_threaded_inference.cc#L286-293)

The above code outputs results for different threads and cleans up the thread safe cached op.

## Current Limitations

1. Only operators tested with the existing model coverage are supported. Other operators and operator types (stateful operators, custom operators are not supported. Existing model coverage is as follows (this list will keep growing as we test more models with different model types):

|Models Tested|MKLDNN|CUDNN|NO-CUDNN|
| --- | --- | --- | --- |
| imagenet1k-resnet-18 | Yes | Yes | Yes |
| imagenet1k-resnet-152 | Yes | Yes | Yes |
| imagenet1k-resnet-50 | Yes | Yes | Yes |

2. Only dense storage types are supported currently.
3. Multi GPU Inference not supported currently.
4. Instantiating multiple instances of SymbolBlockThreadSafe is not supported. Can run parallel inference only on one model per process.
5. dynamic shapes not supported in thread safe cached op.
6. Bulking of ops is not supported.
7. This only supports inference use cases currently, training use cases are not supported.
8. Graph rewrites with subgraph API currently not supported.
9. There is currently no frontend API support to run multi threaded inference. Users can use CreateCachedOpEX and InvokeCachedOp in combination with
the CPP frontend to run multi-threaded inference as of today.
10. Multi threaded inference with threaded engine with Module/Symbolic API and C Predict API are not currently supported.
11. Exception thrown with `wait_to_read` in individual threads can cause issues. Calling invoke from each thread and calling WaitAll after thread joins should still work fine.
12. Tested only on environments supported by CI. This means that MacOS is not supported.

## Future Work

Future work includes Increasing model coverage and addressing most of the limitations mentioned under Current Limitations except the training use case.
For more updates, please subscribe to discussion activity on RFC: https://github.com/apache/incubator-mxnet/issues/16431.
Loading

0 comments on commit b1e4911

Please sign in to comment.